diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6e3e3f2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,75 @@ +import pytest +import numpy as np +import scipy.sparse as sp + +from lightfm import LightFM + +# set our default random seed +SEED = 42 + + +@pytest.fixture(scope="session") +def rng(): + """Initialise a shared random number generator for all tests.""" + + return np.random.RandomState(SEED) + + +@pytest.fixture(scope="session") +def array_int32(rng, size=10): + """Initialise an array of type np.int32 of size `size`.""" + + return rng.randint(0, 100, size=size, dtype=np.int32) + + +@pytest.fixture( + scope="session", ids=["tuple", "list", "ndarray"], params=[tuple, list, np.array] +) +def user_ids(array_int32, request): + """Initialise input user_ids valid for calls to the LightFM.predict method. + + Notes + ----- + On parameterized pytest fixtures: This fixture will iterate over all passed + `params`. This avoids having to apply a `pytest.mark.parameterize` decorator to + every test that needs the same `user_ids`. + + You can find out more about parameterized fixtures in the pytest docs: + https://docs.pytest.org/en/stable/parametrize.html + + """ + + _type = request.param + yield _type(array_int32) + + +@pytest.fixture( + scope="session", ids=["tuple", "list", "ndarray"], params=[tuple, list, np.array] +) +def item_ids(array_int32, request): + """Initialise input item_ids valid for calls to the LightFM.predict method. + + Notes + ----- + See `user_ids` fixture for a note on parameterized fixtures. + + """ + _type = request.param + yield _type(array_int32) + + +@pytest.fixture(scope="session") +def train_matrix(rng, n_users=1000, n_items=1000): + """Create a random sparse CSR matrix of shape (n_users, n_items) for training.""" + + return sp.rand(n_users, n_items, format="csr", random_state=rng) + + +@pytest.fixture(scope="session") +def lfm(train_matrix): + """Create a _trained_ LightFM model instance.""" + + model = LightFM() + model.fit(train_matrix) + + return model diff --git a/tests/test_api.py b/tests/test_api.py index 2863212..54ad46f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -380,3 +380,16 @@ def test_warp_few_items(): model = LightFM(loss="warp", max_sampled=10) model.fit(train) + + +def test_predict_user_item_inputs_with_valid_types(lfm, user_ids, item_ids): + """Test that calls to the predict method with inputs of valid types succeed.""" + + # GIVEN user_ids of a valid type + # AND item_ids of a valid type + # WHEN trained model provided + # THEN calls to LightFM.predict succeed + + h = lfm.predict(user_ids=user_ids, item_ids=item_ids) + assert h.dtype == np.float32 + assert len(h) == len(user_ids)