Skip to content

Commit fd351da

Browse files
authored
Clear cached properties on partial_fit_* (#685)
The item_norms/user_norms were incorrect after calling the partial_fit methods, especially for new items. Fix.
1 parent 00623e6 commit fd351da

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

implicit/cpu/als.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,10 @@ def partial_fit_users(self, userids, user_items):
306306
# update the stored factors with the newly calculated values
307307
self.user_factors[userids] = user_factors
308308

309+
# clear any cached properties that are invalidated by this update
310+
self._user_norms = None
311+
self._XtX = None
312+
309313
def partial_fit_items(self, itemids, item_users):
310314
"""Incrementally updates item factors
311315
@@ -339,6 +343,10 @@ def partial_fit_items(self, itemids, item_users):
339343
# update the stored factors with the newly calculated values
340344
self.item_factors[itemids] = item_factors
341345

346+
# clear any cached properties that are invalidated by this update
347+
self._item_norms = None
348+
self._YtY = None
349+
342350
def explain(self, userid, user_items, itemid, user_weights=None, N=10):
343351
"""Provides explanations for why the item is liked by the user.
344352

implicit/gpu/als.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,10 @@ def partial_fit_users(self, userids, user_items):
234234

235235
self.user_factors.assign_rows(userids, user_factors)
236236

237+
# clear any cached properties that are invalidated by this update
238+
self._user_norms = self._user_norms_host = None
239+
self._XtX = None
240+
237241
def partial_fit_items(self, itemids, item_users):
238242
"""Incrementally updates item factors
239243
@@ -266,6 +270,10 @@ def partial_fit_items(self, itemids, item_users):
266270
# update the stored factors with the newly calculated values
267271
self.item_factors.assign_rows(itemids, item_factors)
268272

273+
# clear any cached properties that are invalidated by this update
274+
self._item_norms = self._item_norms_host = None
275+
self._YtY = None
276+
269277
@property
270278
def solver(self):
271279
if self._solver is None:

0 commit comments

Comments
 (0)