Skip to content

Commit 527c062

Browse files
authored
Feature/similarity_module (#272)
Added the `similarity_module_type` parameter to transformer models, and added `DistanceSimilarityModule` with `dot` and `cosine` distance options.
1 parent c5dc5f7 commit 527c062

File tree

13 files changed

+412
-157
lines changed

13 files changed

+412
-157
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88

9+
### Added
10+
- `SimilarityModuleBase` and `DistanceSimilarityModule` classes, a similarity module in `TransformerTorchBackbone`, and the `similarity_module_type` and `similarity_module_kwargs` parameters for transformer-based models ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
11+
912
## [0.12.0] - 24.02.2025
1013

1114
### Added

rectools/models/nn/transformers/base.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
PreLNTransformerLayers,
4747
TransformerLayersBase,
4848
)
49+
from .similarity import DistanceSimilarityModule, SimilarityModuleBase
4950
from .torch_backbone import TransformerTorchBackbone
5051

5152
InitKwargs = tp.Dict[str, tp.Any]
@@ -97,6 +98,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
9798
),
9899
]
99100

101+
SimilarityModuleType = tpe.Annotated[
102+
tp.Type[SimilarityModuleBase],
103+
BeforeValidator(_get_class_obj),
104+
PlainSerializer(
105+
func=get_class_or_function_full_path,
106+
return_type=str,
107+
when_used="json",
108+
),
109+
]
110+
100111
TransformerDataPreparatorType = tpe.Annotated[
101112
tp.Type[TransformerDataPreparatorBase],
102113
BeforeValidator(_get_class_obj),
@@ -183,13 +194,15 @@ class TransformerModelConfig(ModelConfig):
183194
pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding
184195
transformer_layers_type: TransformerLayersType = PreLNTransformerLayers
185196
lightning_module_type: TransformerLightningModuleType = TransformerLightningModule
197+
similarity_module_type: SimilarityModuleType = DistanceSimilarityModule
186198
get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
187199
get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
188200
data_preparator_kwargs: tp.Optional[InitKwargs] = None
189201
transformer_layers_kwargs: tp.Optional[InitKwargs] = None
190202
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None
191203
pos_encoding_kwargs: tp.Optional[InitKwargs] = None
192204
lightning_module_kwargs: tp.Optional[InitKwargs] = None
205+
similarity_module_kwargs: tp.Optional[InitKwargs] = None
193206

194207

195208
TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig)
@@ -237,13 +250,15 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
237250
item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
238251
pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
239252
lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
253+
similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
240254
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
241255
get_trainer_func: tp.Optional[TrainerCallable] = None,
242256
data_preparator_kwargs: tp.Optional[InitKwargs] = None,
243257
transformer_layers_kwargs: tp.Optional[InitKwargs] = None,
244258
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
245259
pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
246260
lightning_module_kwargs: tp.Optional[InitKwargs] = None,
261+
similarity_module_kwargs: tp.Optional[InitKwargs] = None,
247262
**kwargs: tp.Any,
248263
) -> None:
249264
super().__init__(verbose=verbose)
@@ -268,6 +283,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
268283
self.recommend_batch_size = recommend_batch_size
269284
self.recommend_torch_device = recommend_torch_device
270285
self.train_min_user_interactions = train_min_user_interactions
286+
self.similarity_module_type = similarity_module_type
271287
self.item_net_block_types = item_net_block_types
272288
self.item_net_constructor_type = item_net_constructor_type
273289
self.pos_encoding_type = pos_encoding_type
@@ -279,6 +295,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
279295
self.item_net_constructor_kwargs = item_net_constructor_kwargs
280296
self.pos_encoding_kwargs = pos_encoding_kwargs
281297
self.lightning_module_kwargs = lightning_module_kwargs
298+
self.similarity_module_kwargs = similarity_module_kwargs
282299

283300
self._init_data_preparator()
284301
self._init_trainer()
@@ -295,12 +312,13 @@ def _get_kwargs(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs:
295312
return kwargs
296313

297314
def _init_data_preparator(self) -> None:
315+
requires_negatives = self.lightning_module_type.requires_negatives(self.loss)
298316
self.data_preparator = self.data_preparator_type(
299317
session_max_len=self.session_max_len,
300318
batch_size=self.batch_size,
301319
dataloader_num_workers=self.dataloader_num_workers,
302320
train_min_user_interactions=self.train_min_user_interactions,
303-
n_negatives=self.n_negatives if self.loss != "softmax" else None,
321+
n_negatives=self.n_negatives if requires_negatives else None,
304322
get_val_mask_func=self.get_val_mask_func,
305323
shuffle_train=True,
306324
**self._get_kwargs(self.data_preparator_kwargs),
@@ -356,15 +374,20 @@ def _init_transformer_layers(self) -> TransformerLayersBase:
356374
**self._get_kwargs(self.transformer_layers_kwargs),
357375
)
358376

377+
def _init_similarity_module(self) -> SimilarityModuleBase:
378+
return self.similarity_module_type(**self._get_kwargs(self.similarity_module_kwargs))
379+
359380
def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone:
360381
pos_encoding_layer = self._init_pos_encoding_layer()
361382
transformer_layers = self._init_transformer_layers()
383+
similarity_module = self._init_similarity_module()
362384
return TransformerTorchBackbone(
363385
n_heads=self.n_heads,
364386
dropout_rate=self.dropout_rate,
365387
item_model=item_model,
366388
pos_encoding_layer=pos_encoding_layer,
367389
transformer_layers=transformer_layers,
390+
similarity_module=similarity_module,
368391
use_causal_attn=self.use_causal_attn,
369392
use_key_padding_mask=self.use_key_padding_mask,
370393
)

rectools/models/nn/transformers/bert4rec.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
PreLNTransformerLayers,
4545
TransformerLayersBase,
4646
)
47+
from .similarity import DistanceSimilarityModule, SimilarityModuleBase
4748

4849

4950
class BERT4RecDataPreparator(TransformerDataPreparatorBase):
@@ -256,6 +257,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
256257
Type of data preparator used for dataset processing and dataloader creation.
257258
lightning_module_type : type(TransformerLightningModuleBase), default `TransformerLightningModule`
258259
Type of lightning module defining training procedure.
260+
similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule`
261+
Type of similarity module.
259262
get_val_mask_func : Callable, default ``None``
260263
Function to get validation mask.
261264
get_trainer_func : Callable, default ``None``
@@ -289,6 +292,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
289292
lightning_module_kwargs: optional(dict), default ``None``
290293
Additional keyword arguments to pass during `lightning_module_type` initialization.
291294
Make sure all dict values have JSON serializable types.
295+
similarity_module_kwargs: optional(dict), default ``None``
296+
Additional keyword arguments to pass during `similarity_module_type` initialization.
297+
Make sure all dict values have JSON serializable types.
292298
"""
293299

294300
config_class = BERT4RecModelConfig
@@ -320,6 +326,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
320326
transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers,
321327
data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator,
322328
lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
329+
similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
323330
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
324331
get_trainer_func: tp.Optional[TrainerCallable] = None,
325332
recommend_batch_size: int = 256,
@@ -332,6 +339,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
332339
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
333340
pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
334341
lightning_module_kwargs: tp.Optional[InitKwargs] = None,
342+
similarity_module_kwargs: tp.Optional[InitKwargs] = None,
335343
):
336344
self.mask_prob = mask_prob
337345

@@ -360,6 +368,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
360368
recommend_n_threads=recommend_n_threads,
361369
recommend_use_torch_ranking=recommend_use_torch_ranking,
362370
train_min_user_interactions=train_min_user_interactions,
371+
similarity_module_type=similarity_module_type,
363372
item_net_block_types=item_net_block_types,
364373
item_net_constructor_type=item_net_constructor_type,
365374
pos_encoding_type=pos_encoding_type,
@@ -372,6 +381,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
372381
item_net_constructor_kwargs=item_net_constructor_kwargs,
373382
pos_encoding_kwargs=pos_encoding_kwargs,
374383
lightning_module_kwargs=lightning_module_kwargs,
384+
similarity_module_kwargs=similarity_module_kwargs,
375385
)
376386

377387
def _init_data_preparator(self) -> None:

rectools/models/nn/transformers/data_preparator.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,18 @@ def _process_features_for_id_map(
168168
full_feature_values = np.vstack([extra_token_feature_values, sorted_features.values])
169169
return DenseFeatures.from_iterables(values=full_feature_values, names=raw_features.names)
170170

171+
def _filter_train_interactions(self, train_interactions: pd.DataFrame) -> pd.DataFrame:
172+
"""Filter train interactions."""
173+
user_stats = train_interactions[Columns.User].value_counts()
174+
users = user_stats[user_stats >= self.train_min_user_interactions].index
175+
train_interactions = train_interactions[(train_interactions[Columns.User].isin(users))]
176+
train_interactions = (
177+
train_interactions.sort_values(Columns.Datetime, kind="stable")
178+
.groupby(Columns.User, sort=False)
179+
.tail(self.session_max_len + self.train_session_max_len_addition)
180+
)
181+
return train_interactions
182+
171183
def process_dataset_train(self, dataset: Dataset) -> None:
172184
"""Process train dataset and save data."""
173185
raw_interactions = dataset.get_raw_interactions()
@@ -179,14 +191,7 @@ def process_dataset_train(self, dataset: Dataset) -> None:
179191
interactions = raw_interactions[~val_mask]
180192

181193
# Filter train interactions
182-
user_stats = interactions[Columns.User].value_counts()
183-
users = user_stats[user_stats >= self.train_min_user_interactions].index
184-
interactions = interactions[(interactions[Columns.User].isin(users))]
185-
interactions = (
186-
interactions.sort_values(Columns.Datetime, kind="stable")
187-
.groupby(Columns.User, sort=False)
188-
.tail(self.session_max_len + self.train_session_max_len_addition)
189-
)
194+
interactions = self._filter_train_interactions(interactions)
190195

191196
# Prepare id maps
192197
user_id_map = IdMap.from_values(interactions[Columns.User].values)

0 commit comments

Comments
 (0)