Skip to content

Commit 00623e6

Browse files
authored
Add check_csr for partial_fit methods (#683)
We were seeing some poor result on partial_fit_*, that ended up caused by passing a CSC instead of a CSR matrix. (#682) Add the same `check_csr` code to the partial_fit methods that is already done in the fit method - this will warn if passed a non-csr matrix, and automatically convert.
1 parent 4283257 commit 00623e6

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

implicit/cpu/als.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def recalculate_user(self, userid, user_items):
215215
Sparse matrix of (users, items) that contain the users that liked
216216
each item.
217217
"""
218+
user_items = check_csr(user_items)
218219

219220
# we're using the cholesky solver here on purpose, since for a full recompute
220221
users = 1 if np.isscalar(userid) else len(userid)
@@ -252,6 +253,7 @@ def recalculate_item(self, itemid, item_users):
252253
Sparse matrix of (items, users) that contain the users that liked
253254
each item
254255
"""
256+
item_users = check_csr(item_users)
255257

256258
if self.alpha != 1.0:
257259
item_users = self.alpha * item_users

implicit/gpu/_cuda.pyx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ from libc.stdint cimport uint16_t
99
from libcpp cimport bool
1010
from libcpp.utility cimport move, pair
1111

12+
from implicit.utils import check_csr
13+
1214
from .als cimport LeastSquaresSolver as CppLeastSquaresSolver
1315
from .bpr cimport bpr_update as cpp_bpr_update
1416
from .knn cimport KnnQuery as CppKnnQuery
@@ -220,6 +222,8 @@ cdef class CSRMatrix(object):
220222
cdef CppCSRMatrix * c_matrix
221223

222224
def __cinit__(self, X):
225+
X = check_csr(X)
226+
223227
cdef int[:] indptr = X.indptr
224228
cdef int[:] indices = X.indices
225229
cdef float[:] data = X.data.astype(np.float32)

0 commit comments

Comments
 (0)