Skip to content

Commit 75a30c1

Browse files
authored
Support fit_partial() for LightFM (#223)
`fit_partial()` support for LightFM
1 parent de8c44a commit 75a30c1

File tree

3 files changed: 36 additions and 3 deletions

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
1919
- `load_model` function ([#213](https://github.com/MobileTeleSystems/RecTools/pull/213))
2020
- `model_from_config` function ([#214](https://github.com/MobileTeleSystems/RecTools/pull/214))
2121
- `get_cat_features` method to `SparseFeatures` ([#221](https://github.com/MobileTeleSystems/RecTools/pull/221))
22+
- Support `fit_partial()` for LightFM ([#223](https://github.com/MobileTeleSystems/RecTools/pull/223))
2223
- LightFM Python 3.12+ support ([#224](https://github.com/MobileTeleSystems/RecTools/pull/224))
2324

2425
### Removed

rectools/models/lightfm.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -164,20 +164,25 @@ def _from_config(cls, config: LightFMWrapperModelConfig) -> tpe.Self:
164164
model = model_cls(**params)
165165
return cls(model=model, epochs=config.epochs, num_threads=config.num_threads, verbose=config.verbose)
166166

167-
def _fit(self, dataset: Dataset) -> None: # type: ignore
167+
def _fit(self, dataset: Dataset) -> None:
168168
self.model = deepcopy(self._model)
169+
self._fit_partial(dataset, self.n_epochs)
170+
171+
def _fit_partial(self, dataset: Dataset, epochs: int) -> None:
172+
if not self.is_fitted:
173+
self.model = deepcopy(self._model)
169174

170175
ui_coo = dataset.get_user_item_matrix(include_weights=True).tocoo(copy=False)
171176
user_features = self._prepare_features(dataset.get_hot_user_features(), dataset.n_hot_users)
172177
item_features = self._prepare_features(dataset.get_hot_item_features(), dataset.n_hot_items)
173178
sample_weight = None if self._model.loss == "warp-kos" else ui_coo
174179

175-
self.model.fit(
180+
self.model.fit_partial(
176181
ui_coo,
177182
user_features=user_features,
178183
item_features=item_features,
179184
sample_weight=sample_weight,
180-
epochs=self.n_epochs,
185+
epochs=epochs,
181186
num_threads=self.n_threads,
182187
verbose=self.verbose > 0,
183188
)

tests/models/test_lightfm.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -323,6 +323,33 @@ def test_second_fit_refits_model(self, dataset: Dataset) -> None:
323323
model = LightFMWrapperModel(model=base_model, epochs=5, num_threads=1)
324324
assert_second_fit_refits_model(model, dataset)
325325

326+
@pytest.mark.parametrize("loss", ("logistic", "bpr", "warp"))
327+
@pytest.mark.parametrize("use_features_in_dataset", (False, True))
328+
def test_per_epoch_partial_fit_consistent_with_regular_fit(
329+
self,
330+
dataset: Dataset,
331+
dataset_with_features: Dataset,
332+
use_features_in_dataset: bool,
333+
loss: str,
334+
) -> None:
335+
if use_features_in_dataset:
336+
dataset = dataset_with_features
337+
338+
epochs = 20
339+
340+
base_model_1 = LightFM(no_components=2, loss=loss, random_state=1)
341+
model_1 = LightFMWrapperModel(model=base_model_1, epochs=epochs, num_threads=1).fit(dataset)
342+
343+
base_model_2 = LightFM(no_components=2, loss=loss, random_state=1)
344+
model_2 = LightFMWrapperModel(model=base_model_2, epochs=epochs, num_threads=1)
345+
for _ in range(epochs):
346+
model_2.fit_partial(dataset, epochs=1)
347+
348+
assert np.allclose(model_1.model.item_biases, model_2.model.item_biases)
349+
assert np.allclose(model_1.model.user_biases, model_2.model.user_biases)
350+
assert np.allclose(model_1.model.item_embeddings, model_2.model.item_embeddings)
351+
assert np.allclose(model_1.model.user_embeddings, model_2.model.user_embeddings)
352+
326353
def test_fail_when_getting_cold_reco_with_no_biases(self, dataset: Dataset) -> None:
327354
class NoBiasesLightFMWrapperModel(LightFMWrapperModel):
328355
def _get_items_factors(self, dataset: Dataset) -> Factors:

0 commit comments

Comments
 (0)