1616
1717import typing as tp
1818import warnings
19- from enum import Enum
2019
2120import implicit .cpu
2221import implicit .gpu
2726
2827from rectools import InternalIds
2928from rectools .models .base import Scores
29+ from rectools .models .rank .rank import Distance
30+ from rectools .models .utils import convert_arr_to_implicit_gpu_matrix
3031from rectools .types import InternalIdsArray
3132
32- from .utils import convert_arr_to_implicit_gpu_matrix
33-
34-
35- class Distance (Enum ):
36- """Distance metric"""
37-
38- DOT = 1 # Bigger value means closer vectors
39- COSINE = 2 # Bigger value means closer vectors
40- EUCLIDEAN = 3 # Smaller value means closer vectors
41-
4233
4334class ImplicitRanker :
4435 """
@@ -58,18 +49,28 @@ class ImplicitRanker:
5849 objects_factors : np.ndarray
5950 Array with embeddings of all objects, shape (n_objects, n_factors).
6051 For item-item similarity models item similarity vectors are viewed as factors.
52+ num_threads : int, default 0
53+ Will be used as `num_threads` parameter for `implicit.cpu.topk.topk`. Ignored if `use_gpu` is True.
54+ use_gpu : bool, default False
55+ If True `implicit.gpu.KnnQuery().topk` will be used instead of classic cpu version.
6156 """
6257
6358 def __init__ (
64- self , distance : Distance , subjects_factors : tp .Union [np .ndarray , sparse .csr_matrix ], objects_factors : np .ndarray
59+ self ,
60+ distance : Distance ,
61+ subjects_factors : tp .Union [np .ndarray , sparse .csr_matrix ],
62+ objects_factors : np .ndarray ,
63+ num_threads : int = 0 ,
64+ use_gpu : bool = False ,
6565 ) -> None :
66-
6766 if isinstance (subjects_factors , sparse .csr_matrix ) and distance != Distance .DOT :
6867 raise ValueError ("To use `sparse.csr_matrix` distance must be `Distance.DOT`" )
6968
7069 self .distance = distance
7170 self .subjects_factors : np .ndarray = subjects_factors .astype (np .float32 )
7271 self .objects_factors : np .ndarray = objects_factors .astype (np .float32 )
72+ self .num_threads = num_threads
73+ self .use_gpu = use_gpu
7374
7475 self .subjects_norms : np .ndarray
7576 if distance == Distance .COSINE :
@@ -85,7 +86,8 @@ def _get_neginf_score(self) -> float:
8586 # we're comparing `scores <= neginf_score`
8687 return float (
8788 np .asarray (
88- np .asarray (- np .finfo (np .float32 ).max , dtype = np .float32 ).view (np .uint32 ) - 1 , dtype = np .uint32
89+ np .asarray (- np .finfo (np .float32 ).max , dtype = np .float32 ).view (np .uint32 ) - 1 ,
90+ dtype = np .uint32 ,
8991 ).view (np .float32 )
9092 )
9193
@@ -118,7 +120,6 @@ def _get_mask_for_correct_scores(self, scores: np.ndarray) -> tp.List[bool]:
118120 def _process_implicit_scores (
119121 self , subject_ids : InternalIds , ids : np .ndarray , scores : np .ndarray
120122 ) -> tp .Tuple [InternalIds , InternalIds , Scores ]:
121-
122123 all_target_ids = []
123124 all_reco_ids : tp .List [np .ndarray ] = []
124125 all_scores : tp .List [np .ndarray ] = []
@@ -152,7 +153,6 @@ def _rank_on_gpu(
152153 object_norms : tp .Optional [np .ndarray ],
153154 filter_query_items : tp .Optional [tp .Union [sparse .csr_matrix , sparse .csr_array ]],
154155 ) -> tp .Tuple [np .ndarray , np .ndarray ]: # pragma: no cover
155-
156156 object_factors = convert_arr_to_implicit_gpu_matrix (object_factors )
157157
158158 if isinstance (subject_factors , sparse .spmatrix ):
@@ -184,19 +184,17 @@ def _rank_on_gpu(
184184 def rank ( # pylint: disable=too-many-branches
185185 self ,
186186 subject_ids : InternalIds ,
187- k : int ,
187+ k : tp . Optional [ int ] = None ,
188188 filter_pairs_csr : tp .Optional [sparse .csr_matrix ] = None ,
189189 sorted_object_whitelist : tp .Optional [InternalIdsArray ] = None ,
190- num_threads : int = 0 ,
191- use_gpu : bool = False ,
192190 ) -> tp .Tuple [InternalIds , InternalIds , Scores ]:
193191 """Rank objects for inference using the implicit library top-k method (CPU or GPU).
194192
195193 Parameters
196194 ----------
197195 subject_ids : InternalIds
198196 Array of ids to recommend for.
199- k : int
197+ k : int, optional, default ``None``
200198 Desired number of recommendations for every subject id. If ``None``, it is set to the number of candidate objects.
201199 filter_pairs_csr : sparse.csr_matrix, optional, default ``None``
202200 Subject-object interactions that should be filtered from recommendations.
@@ -205,16 +203,16 @@ def rank( # pylint: disable=too-many-branches
205203 Whitelist of object ids.
206204 If given, only these items will be used for recommendations.
207205 Otherwise all items from dataset will be used.
208- num_threads : int, default 0
209- Will be used as `num_threads` parameter for `implicit.cpu.topk.topk`. Omitted if use_gpu is True
210- use_gpu : bool, default False
211- If True `implicit.gpu.KnnQuery().topk` will be used instead of classic cpu version.
212206
213207 Returns
214208 -------
215209 (InternalIds, InternalIds, Scores)
216210 Array of subject ids, array of recommended items, sorted by score descending and array of scores.
217211 """
212+ if filter_pairs_csr is not None and filter_pairs_csr .shape [0 ] != len (subject_ids ):
213+ explanation = "Number of rows in `filter_pairs_csr` must be equal to `len(sublect_ids)`"
214+ raise ValueError (explanation )
215+
218216 if sorted_object_whitelist is not None :
219217 object_factors = self .objects_factors [sorted_object_whitelist ]
220218
@@ -229,6 +227,9 @@ def rank( # pylint: disable=too-many-branches
229227 object_factors = self .objects_factors
230228 filter_query_items = filter_pairs_csr
231229
230+ if k is None :
231+ k = object_factors .shape [0 ]
232+
232233 subject_factors = self .subjects_factors [subject_ids ]
233234
234235 object_norms = None # for DOT and EUCLIDEAN distance
@@ -243,6 +244,7 @@ def rank( # pylint: disable=too-many-branches
243244
244245 real_k = min (k , object_factors .shape [0 ])
245246
247+ use_gpu = self .use_gpu
246248 if use_gpu and not HAS_CUDA :
247249 warnings .warn ("Forced rank() on CPU" )
248250 use_gpu = False
@@ -263,7 +265,7 @@ def rank( # pylint: disable=too-many-branches
263265 item_norms = object_norms , # query norms for COSINE distance are applied afterwards
264266 filter_query_items = filter_query_items , # queries x objects csr matrix for getting neginf scores
265267 filter_items = None , # rectools doesn't support blacklist for now
266- num_threads = num_threads ,
268+ num_threads = self . num_threads ,
267269 )
268270
269271 if sorted_object_whitelist is not None :
0 commit comments