From bc4ab0794a5e65bd2880deb76e467069ae341c6b Mon Sep 17 00:00:00 2001 From: Mark Douthwaite Date: Mon, 8 Feb 2021 18:36:30 +0000 Subject: [PATCH 1/6] Add improved tests for new type checking logic. --- tests/conftest.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++ tests/test_api.py | 13 ++++++++ 2 files changed, 92 insertions(+) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..2ee29a3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,79 @@ +import pytest +import numpy as np +import scipy.sparse as sp + +from lightfm import LightFM + +# set our default random seed +SEED = 42 + + +@pytest.fixture(scope="session") +def rng(): + """Initialise a shared random number generator for all tests.""" + + return np.random.RandomState(SEED) + + +@pytest.fixture(scope="session") +def array_int32(rng, size=10): + """Initialise an array of type np.int32 of size `size`.""" + + return rng.randint(0, 100, size=size, dtype=np.int32) + + +@pytest.fixture( + scope="session", + ids=["tuple", "list", "ndarray"], + params=[tuple, list, np.array] +) +def user_ids(array_int32, request): + """Initialise valid input user_ids for calls to the LightFM.predict method. + + Notes + ----- + On parameterized pytest fixtures: This fixture will iterate over all passed + `params`. This avoids having to apply a `pytest.mark.parameterize` decorator to + every test that needs the same `user_ids`. + + You can find out more about parameterized fixtures in the pytest docs: + https://docs.pytest.org/en/stable/parametrize.html + + """ + + _type = request.param + yield _type(array_int32) + + +@pytest.fixture( + scope="session", + ids=["tuple", "list", "ndarray"], + params=[tuple, list, np.array] +) +def item_ids(array_int32, request): + """Initialise valid input item_ids for calls to the LightFM.predict method. + + Notes + ----- + See `user_ids` fixture for a note on parameterized fixtures. + + """ + _type = request.param + yield _type(array_int32) + + +@pytest.fixture(scope="session") +def train_matrix(rng, n_users=1000, n_items=1000): + """Create a random sparse CSR matrix of shape (n_users, n_items) for training.""" + + return sp.rand(n_users, n_items, format="csr", random_state=rng) + + +@pytest.fixture(scope="session") +def lfm(train_matrix): + """Create a _trained_ LightFM model instance.""" + + model = LightFM() + model.fit(train_matrix) + + return model diff --git a/tests/test_api.py b/tests/test_api.py index 2863212..d766987 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -380,3 +380,16 @@ def test_warp_few_items(): model = LightFM(loss="warp", max_sampled=10) model.fit(train) + + +def test_predict_input_with_valid_types(lfm, user_ids, item_ids): + """Test that calls to the predict method with inputs of valid types succeed.""" + + # GIVEN user_ids of a valid type (tuple, list, ndarray) + # AND item_ids of a valid type (tuple, list, ndarray) + # WHEN trained model provided + # THEN calls to LightFM.predict succeed + + h = lfm.predict(user_ids=user_ids, item_ids=item_ids) + assert h.dtype == np.float32 + assert len(h) == len(user_ids) From 4f56d882e546afd51b4976a57a50a1644d8e4658 Mon Sep 17 00:00:00 2001 From: Mark Douthwaite Date: Mon, 8 Feb 2021 18:46:41 +0000 Subject: [PATCH 2/6] Drop explicit reference to valid types in new test --- tests/test_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_api.py b/tests/test_api.py index d766987..c775625 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -385,8 +385,8 @@ def test_warp_few_items(): def test_predict_input_with_valid_types(lfm, user_ids, item_ids): """Test that calls to the predict method with inputs of valid types succeed.""" - # GIVEN user_ids of a valid type (tuple, list, ndarray) - # AND item_ids of a valid type (tuple, list, ndarray) + # GIVEN user_ids of a valid type + # AND item_ids of a valid type # WHEN trained model provided # THEN calls to LightFM.predict succeed From 7f177b39db2d5126ea83b296e4024cd43473516d Mon Sep 17 00:00:00 2001 From: Mark Douthwaite Date: Mon, 8 Feb 2021 18:48:57 +0000 Subject: [PATCH 3/6] Update test_api.py --- tests/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_api.py b/tests/test_api.py index c775625..b30c43c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -382,7 +382,7 @@ def test_warp_few_items(): model.fit(train) -def test_predict_input_with_valid_types(lfm, user_ids, item_ids): +def test_predict_input_arrays_with_valid_types(lfm, user_ids, item_ids): """Test that calls to the predict method with inputs of valid types succeed.""" # GIVEN user_ids of a valid type From 30b4643a3bb0d7d1862f229b390ace97896bd61b Mon Sep 17 00:00:00 2001 From: Mark Douthwaite Date: Mon, 8 Feb 2021 18:49:29 +0000 Subject: [PATCH 4/6] Update test_api.py --- tests/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_api.py b/tests/test_api.py index b30c43c..54ad46f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -382,7 +382,7 @@ def test_warp_few_items(): model.fit(train) -def test_predict_input_arrays_with_valid_types(lfm, user_ids, item_ids): +def test_predict_user_item_inputs_with_valid_types(lfm, user_ids, item_ids): """Test that calls to the predict method with inputs of valid types succeed.""" # GIVEN user_ids of a valid type From 205ce80ec93758120c1482e76347fd6574a1bbd1 Mon Sep 17 00:00:00 2001 From: Mark Douthwaite Date: Mon, 8 Feb 2021 18:58:33 +0000 Subject: [PATCH 5/6] Update conftest.py --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 2ee29a3..e31b932 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,7 +28,7 @@ def array_int32(rng, size=10): params=[tuple, list, np.array] ) def user_ids(array_int32, request): - """Initialise valid input user_ids for calls to the LightFM.predict method. + """Initialise input user_ids valid for calls to the LightFM.predict method. Notes ----- @@ -51,7 +51,7 @@ def user_ids(array_int32, request): params=[tuple, list, np.array] ) def item_ids(array_int32, request): - """Initialise valid input item_ids for calls to the LightFM.predict method. + """Initialise input item_ids valid for calls to the LightFM.predict method. Notes ----- From 011384a93adfff356adaad232de31c00c13f618c Mon Sep 17 00:00:00 2001 From: Mark Douthwaite Date: Mon, 8 Feb 2021 19:07:16 +0000 Subject: [PATCH 6/6] Apply black. --- tests/conftest.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e31b932..6e3e3f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,9 +23,7 @@ def array_int32(rng, size=10): @pytest.fixture( - scope="session", - ids=["tuple", "list", "ndarray"], - params=[tuple, list, np.array] + scope="session", ids=["tuple", "list", "ndarray"], params=[tuple, list, np.array] ) def user_ids(array_int32, request): """Initialise input user_ids valid for calls to the LightFM.predict method. @@ -46,9 +44,7 @@ def user_ids(array_int32, request): @pytest.fixture( - scope="session", - ids=["tuple", "list", "ndarray"], - params=[tuple, list, np.array] + scope="session", ids=["tuple", "list", "ndarray"], params=[tuple, list, np.array] ) def item_ids(array_int32, request): """Initialise input item_ids valid for calls to the LightFM.predict method.