@@ -323,6 +323,33 @@ def test_second_fit_refits_model(self, dataset: Dataset) -> None:
     model = LightFMWrapperModel(model=base_model, epochs=5, num_threads=1)
     assert_second_fit_refits_model(model, dataset)
 
+    @pytest.mark.parametrize("loss", ("logistic", "bpr", "warp"))
+    @pytest.mark.parametrize("use_features_in_dataset", (False, True))
+    def test_per_epoch_partial_fit_consistent_with_regular_fit(
+        self,
+        dataset: Dataset,
+        dataset_with_features: Dataset,
+        use_features_in_dataset: bool,
+        loss: str,
+    ) -> None:
+        if use_features_in_dataset:
+            dataset = dataset_with_features
+
+        epochs = 20
+
+        base_model_1 = LightFM(no_components=2, loss=loss, random_state=1)
+        model_1 = LightFMWrapperModel(model=base_model_1, epochs=epochs, num_threads=1).fit(dataset)
+
+        base_model_2 = LightFM(no_components=2, loss=loss, random_state=1)
+        model_2 = LightFMWrapperModel(model=base_model_2, epochs=epochs, num_threads=1)
+        for _ in range(epochs):
+            model_2.fit_partial(dataset, epochs=1)
+
+        assert np.allclose(model_1.model.item_biases, model_2.model.item_biases)
+        assert np.allclose(model_1.model.user_biases, model_2.model.user_biases)
+        assert np.allclose(model_1.model.item_embeddings, model_2.model.item_embeddings)
+        assert np.allclose(model_1.model.user_embeddings, model_2.model.user_embeddings)
+
     def test_fail_when_getting_cold_reco_with_no_biases(self, dataset: Dataset) -> None:
         class NoBiasesLightFMWrapperModel(LightFMWrapperModel):
             def _get_items_factors(self, dataset: Dataset) -> Factors: