Skip to content

Commit b49f0cd

Browse files
authored
Feature/torch_ranker (#251)
`TorchRanker` to allow ranking with gpu using pytorch
1 parent 5664a3c commit b49f0cd

File tree

15 files changed

+1190
-272
lines changed

15 files changed

+1190
-272
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8-
98
## Unreleased
109

1110
### Added
1211
- `use_gpu` for PureSVD ([#229](https://github.com/MobileTeleSystems/RecTools/pull/229))
1312
- `from_params` method for models and `model_from_params` function ([#252](https://github.com/MobileTeleSystems/RecTools/pull/252))
13+
- `TorchRanker` ranker which calculates scores using torch. Supports GPU. ([#251](https://github.com/MobileTeleSystems/RecTools/pull/251))
14+
- `Ranker` protocol which unifies ranker calls. ([#251](https://github.com/MobileTeleSystems/RecTools/pull/251))
15+
16+
### Changed
1417

18+
- `ImplicitRanker.rank` method is now compatible with the `Ranker` protocol. `use_gpu` and `num_threads` params moved from the `rank` method to `__init__`. ([#251](https://github.com/MobileTeleSystems/RecTools/pull/251))
1519

1620
## [0.10.0] - 16.01.2025
1721

rectools/models/ease.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def __init__(
8585
recommend_use_gpu_ranking: bool = True,
8686
verbose: int = 0,
8787
):
88-
8988
super().__init__(verbose=verbose)
9089
self.weight: np.ndarray
9190
self.regularization = regularization
@@ -146,16 +145,17 @@ def _recommend_u2i(
146145
distance=Distance.DOT,
147146
subjects_factors=user_items,
148147
objects_factors=self.weight,
148+
use_gpu=self.recommend_use_gpu_ranking and HAS_CUDA,
149+
num_threads=self.recommend_n_threads,
149150
)
151+
150152
ui_csr_for_filter = user_items[user_ids] if filter_viewed else None
151153

152154
all_user_ids, all_reco_ids, all_scores = ranker.rank(
153155
subject_ids=user_ids,
154156
k=k,
155157
filter_pairs_csr=ui_csr_for_filter,
156158
sorted_object_whitelist=sorted_item_ids_to_recommend,
157-
num_threads=self.recommend_n_threads,
158-
use_gpu=self.recommend_use_gpu_ranking and HAS_CUDA,
159159
)
160160

161161
return all_user_ids, all_reco_ids, all_scores

rectools/models/rank/__init__.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2022-2025 MTS (Mobile Telesystems)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pylint: disable=wrong-import-position
16+
17+
"""
18+
Rankers (:mod:`rectools.models.rank`)
19+
==============================================
20+
21+
Rankers to build recs from embeddings.
22+
23+
24+
Rankers
25+
-------
26+
`rank.ImplicitRanker`
27+
`rank.TorchRanker`
28+
"""
29+
30+
try:
31+
from .rank_torch import TorchRanker
32+
except ImportError: # pragma: no cover
33+
from .compat import TorchRanker # type: ignore
34+
35+
from rectools.models.rank.rank import Distance, Ranker
36+
from rectools.models.rank.rank_implicit import ImplicitRanker
37+
38+
__all__ = [
39+
"TorchRanker",
40+
"ImplicitRanker",
41+
"Distance",
42+
"Ranker",
43+
]

rectools/models/rank/compat.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from rectools.compat import RequirementUnavailable
2+
3+
4+
class TorchRanker(RequirementUnavailable):
5+
"""Dummy class, which is returned if there are no dependencies required for the model"""
6+
7+
requirement = "torch"

rectools/models/rank/rank.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import typing as tp
2+
from enum import Enum
3+
4+
from scipy import sparse
5+
6+
from rectools import InternalIds
7+
from rectools.models.base import Scores
8+
from rectools.types import InternalIdsArray
9+
10+
11+
class Distance(Enum):
12+
"""Distance metric"""
13+
14+
DOT = 1 # Bigger value means closer vectors
15+
COSINE = 2 # Bigger value means closer vectors
16+
EUCLIDEAN = 3 # Smaller value means closer vectors
17+
18+
19+
class Ranker(tp.Protocol):
20+
"""Protocol for all rankers"""
21+
22+
def rank(
23+
self,
24+
subject_ids: InternalIds,
25+
k: tp.Optional[int] = None,
26+
filter_pairs_csr: tp.Optional[sparse.csr_matrix] = None,
27+
sorted_object_whitelist: tp.Optional[InternalIdsArray] = None,
28+
) -> tp.Tuple[InternalIds, InternalIds, Scores]: # pragma: no cover
29+
"""Rank objects by corresponding embeddings.
30+
31+
Parameters
32+
----------
33+
subject_ids : InternalIds
34+
Array of ids to recommend for.
35+
k : int, optional, default ``None``
36+
Desired number of recommendations for every subject id.
37+
Return all recs if None.
38+
filter_pairs_csr : sparse.csr_matrix, optional, default ``None``
39+
Subject-object interactions that should be filtered from recommendations.
40+
This is relevant for u2i case.
41+
sorted_object_whitelist : InternalIdsArray, optional, default ``None``
42+
Whitelist of object ids.
43+
If given, only these items will be used for recommendations.
44+
Otherwise all items from dataset will be used.
45+
46+
Returns
47+
-------
48+
(InternalIds, InternalIds, Scores)
49+
Array of subject ids, array of recommended items, sorted by score descending and array of scores.
50+
"""
Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import typing as tp
1818
import warnings
19-
from enum import Enum
2019

2120
import implicit.cpu
2221
import implicit.gpu
@@ -27,18 +26,10 @@
2726

2827
from rectools import InternalIds
2928
from rectools.models.base import Scores
29+
from rectools.models.rank.rank import Distance
30+
from rectools.models.utils import convert_arr_to_implicit_gpu_matrix
3031
from rectools.types import InternalIdsArray
3132

32-
from .utils import convert_arr_to_implicit_gpu_matrix
33-
34-
35-
class Distance(Enum):
36-
"""Distance metric"""
37-
38-
DOT = 1 # Bigger value means closer vectors
39-
COSINE = 2 # Bigger value means closer vectors
40-
EUCLIDEAN = 3 # Smaller value means closer vectors
41-
4233

4334
class ImplicitRanker:
4435
"""
@@ -58,18 +49,28 @@ class ImplicitRanker:
5849
objects_factors : np.ndarray
5950
Array with embeddings of all objects, shape (n_objects, n_factors).
6051
For item-item similarity models item similarity vectors are viewed as factors.
52+
num_threads : int, default 0
53+
Will be used as `num_threads` parameter for `implicit.cpu.topk.topk`. Ignored if `use_gpu` is True.
54+
use_gpu : bool, default False
55+
If True `implicit.gpu.KnnQuery().topk` will be used instead of classic cpu version.
6156
"""
6257

6358
def __init__(
64-
self, distance: Distance, subjects_factors: tp.Union[np.ndarray, sparse.csr_matrix], objects_factors: np.ndarray
59+
self,
60+
distance: Distance,
61+
subjects_factors: tp.Union[np.ndarray, sparse.csr_matrix],
62+
objects_factors: np.ndarray,
63+
num_threads: int = 0,
64+
use_gpu: bool = False,
6565
) -> None:
66-
6766
if isinstance(subjects_factors, sparse.csr_matrix) and distance != Distance.DOT:
6867
raise ValueError("To use `sparse.csr_matrix` distance must be `Distance.DOT`")
6968

7069
self.distance = distance
7170
self.subjects_factors: np.ndarray = subjects_factors.astype(np.float32)
7271
self.objects_factors: np.ndarray = objects_factors.astype(np.float32)
72+
self.num_threads = num_threads
73+
self.use_gpu = use_gpu
7374

7475
self.subjects_norms: np.ndarray
7576
if distance == Distance.COSINE:
@@ -85,7 +86,8 @@ def _get_neginf_score(self) -> float:
8586
# we're comparing `scores <= neginf_score`
8687
return float(
8788
np.asarray(
88-
np.asarray(-np.finfo(np.float32).max, dtype=np.float32).view(np.uint32) - 1, dtype=np.uint32
89+
np.asarray(-np.finfo(np.float32).max, dtype=np.float32).view(np.uint32) - 1,
90+
dtype=np.uint32,
8991
).view(np.float32)
9092
)
9193

@@ -118,7 +120,6 @@ def _get_mask_for_correct_scores(self, scores: np.ndarray) -> tp.List[bool]:
118120
def _process_implicit_scores(
119121
self, subject_ids: InternalIds, ids: np.ndarray, scores: np.ndarray
120122
) -> tp.Tuple[InternalIds, InternalIds, Scores]:
121-
122123
all_target_ids = []
123124
all_reco_ids: tp.List[np.ndarray] = []
124125
all_scores: tp.List[np.ndarray] = []
@@ -152,7 +153,6 @@ def _rank_on_gpu(
152153
object_norms: tp.Optional[np.ndarray],
153154
filter_query_items: tp.Optional[tp.Union[sparse.csr_matrix, sparse.csr_array]],
154155
) -> tp.Tuple[np.ndarray, np.ndarray]: # pragma: no cover
155-
156156
object_factors = convert_arr_to_implicit_gpu_matrix(object_factors)
157157

158158
if isinstance(subject_factors, sparse.spmatrix):
@@ -184,19 +184,17 @@ def _rank_on_gpu(
184184
def rank( # pylint: disable=too-many-branches
185185
self,
186186
subject_ids: InternalIds,
187-
k: int,
187+
k: tp.Optional[int] = None,
188188
filter_pairs_csr: tp.Optional[sparse.csr_matrix] = None,
189189
sorted_object_whitelist: tp.Optional[InternalIdsArray] = None,
190-
num_threads: int = 0,
191-
use_gpu: bool = False,
192190
) -> tp.Tuple[InternalIds, InternalIds, Scores]:
193191
"""Rank objects to proceed inference using implicit library topk cpu method.
194192
195193
Parameters
196194
----------
197195
subject_ids : csr_matrix
198196
Array of ids to recommend for.
199-
k : int
197+
k : int, optional, default ``None``
200198
Desired number of recommendations for every subject id.
201199
filter_pairs_csr : sparse.csr_matrix, optional, default ``None``
202200
Subject-object interactions that should be filtered from recommendations.
@@ -205,16 +203,16 @@ def rank( # pylint: disable=too-many-branches
205203
Whitelist of object ids.
206204
If given, only these items will be used for recommendations.
207205
Otherwise all items from dataset will be used.
208-
num_threads : int, default 0
209-
Will be used as `num_threads` parameter for `implicit.cpu.topk.topk`. Omitted if use_gpu is True
210-
use_gpu : bool, default False
211-
If True `implicit.gpu.KnnQuery().topk` will be used instead of classic cpu version.
212206
213207
Returns
214208
-------
215209
(InternalIds, InternalIds, Scores)
216210
Array of subject ids, array of recommended items, sorted by score descending and array of scores.
217211
"""
212+
if filter_pairs_csr is not None and filter_pairs_csr.shape[0] != len(subject_ids):
213+
explanation = "Number of rows in `filter_pairs_csr` must be equal to `len(sublect_ids)`"
214+
raise ValueError(explanation)
215+
218216
if sorted_object_whitelist is not None:
219217
object_factors = self.objects_factors[sorted_object_whitelist]
220218

@@ -229,6 +227,9 @@ def rank( # pylint: disable=too-many-branches
229227
object_factors = self.objects_factors
230228
filter_query_items = filter_pairs_csr
231229

230+
if k is None:
231+
k = object_factors.shape[0]
232+
232233
subject_factors = self.subjects_factors[subject_ids]
233234

234235
object_norms = None # for DOT and EUCLIDEAN distance
@@ -243,6 +244,7 @@ def rank( # pylint: disable=too-many-branches
243244

244245
real_k = min(k, object_factors.shape[0])
245246

247+
use_gpu = self.use_gpu
246248
if use_gpu and not HAS_CUDA:
247249
warnings.warn("Forced rank() on CPU")
248250
use_gpu = False
@@ -263,7 +265,7 @@ def rank( # pylint: disable=too-many-branches
263265
item_norms=object_norms, # query norms for COSINE distance are applied afterwards
264266
filter_query_items=filter_query_items, # queries x objects csr matrix for getting neginf scores
265267
filter_items=None, # rectools doesn't support blacklist for now
266-
num_threads=num_threads,
268+
num_threads=self.num_threads,
267269
)
268270

269271
if sorted_object_whitelist is not None:

0 commit comments

Comments
 (0)