Skip to content

Commit f73d054

Browse files
authored
Fix/warp-kos loss for LightFMWrapperModel (#175)
- Fixed `NotImplementedError` for LightFM with warp-kos loss
1 parent 0f4034c commit f73d054

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
### Fixed
1616
- `display()` method in `MetricsApp` ([#169](https://github.com/MobileTeleSystems/RecTools/pull/169))
1717

18+
### Fixed
19+
- Allow warp-kos loss for LightFMWrapperModel ([#175](https://github.com/MobileTeleSystems/RecTools/pull/175))
20+
1821
## [0.7.0] - 29.07.2024
1922

2023
### Added

rectools/models/lightfm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,13 @@ def _fit(self, dataset: Dataset) -> None: # type: ignore
7777
ui_coo = dataset.get_user_item_matrix(include_weights=True).tocoo(copy=False)
7878
user_features = self._prepare_features(dataset.get_hot_user_features(), dataset.n_hot_users)
7979
item_features = self._prepare_features(dataset.get_hot_item_features(), dataset.n_hot_items)
80+
sample_weight = None if self._model.loss == "warp-kos" else ui_coo
8081

8182
self.model.fit(
8283
ui_coo,
8384
user_features=user_features,
8485
item_features=item_features,
85-
sample_weight=ui_coo,
86+
sample_weight=sample_weight,
8687
epochs=self.n_epochs,
8788
num_threads=self.n_threads,
8889
verbose=self.verbose > 0,

tests/models/test_lightfm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,16 @@ def test_with_weights(self, interactions_df: pd.DataFrame) -> None:
222222
actual,
223223
)
224224

225+
def test_with_warp_kos(self, dataset: Dataset) -> None:
226+
base_model = DeterministicLightFM(no_components=2, loss="warp-kos")
227+
try:
228+
LightFMWrapperModel(model=base_model, epochs=10).fit(dataset)
229+
except NotImplementedError:
230+
pytest.fail("Should not raise NotImplementedError")
231+
except ValueError:
232+
# LightFM raises ValueError with the dataset
233+
pass
234+
225235
def test_get_vectors(self, dataset_with_features: Dataset) -> None:
226236
base_model = LightFM(no_components=2, loss="logistic")
227237
model = LightFMWrapperModel(model=base_model).fit(dataset_with_features)

0 commit comments

Comments
 (0)