@@ -19,7 +19,6 @@ import multiprocessing
1919
2020cimport cython
2121from cython.parallel import prange
22- from cython cimport floating, integral
2322from libcpp cimport bool
2423from libc.math cimport abs
2524
@@ -28,27 +27,32 @@ cimport numpy as np
2827from tqdm.auto import trange
2928
3029
30+ ctypedef np.int64_t INT64_t
31+
32+
3133@ cython.boundscheck (False )
3234@ cython.wraparound (False )
33- def fit_sgd (integral[:] rid , integral[:] cid , floating[:] val ,
34- floating[:, :] U , floating[:, :] V ,
35- floating[:] Bu , floating[:] Bi ,
36- integral num_users , integral num_items ,
37- floating lr , floating reg , floating mu ,
35+ def fit_sgd (INT64_t[:] rid , INT64_t[:] cid , float[:] val ,
36+ float[:, :] U , float[:, :] V ,
37+ float[:] Bu , float[:] Bi ,
38+ float lr , float reg , float mu ,
3839 int max_iter , int num_threads ,
3940 bool use_bias , bool early_stop , bool verbose ):
4041 """ Fit the model parameters (U, V, Bu, Bi) with SGD"""
4142 cdef:
42- integral num_ratings = val.shape[0 ]
43- integral num_factors = U.shape[1 ]
43+ INT64_t num_ratings = val.shape[0 ]
44+ INT64_t u, i, j
45+
46+ int num_factors = U.shape[1 ]
47+ int f
4448
45- floating loss = 0
46- floating last_loss = 0
47- floating r, r_pred, error, u_f, i_f, delta_loss
48- integral u, i, f, j
49+ float loss = 0
50+ float last_loss = 0
51+ float r, r_pred, error, u_f, i_f, delta_loss
52+
4953
50- floating * user
51- floating * item
54+ float * user
55+ float * item
5256
5357 progress = trange(max_iter, disable = not verbose)
5458 for epoch in progress:
0 commit comments