Commit 7cbc4ce

Feature/Negative sampling (#275)
Added negative sampling customization to transformer-based models
1 parent 69f6736 commit 7cbc4ce

9 files changed (+249, -44 lines)

CHANGELOG.md

Lines changed: 1 addition & 1 deletion

@@ -8,11 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## Unreleased
 
 ### Added
+- `TransformerNegativeSamplerBase` and `CatalogUniformSampler` classes, `negative_sampler_type` and `negative_sampler_kwargs` parameters to transformer-based models ([#275](https://github.com/MobileTeleSystems/RecTools/pull/275))
 - `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone` parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
 - `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276))
 - `TransformerBackboneBase`, `backbone_type` and `backbone_kwargs` parameters to transformer-based models ([#277](https://github.com/MobileTeleSystems/RecTools/pull/277))
 - `sampled_softmax` loss option for transformer models ([#274](https://github.com/MobileTeleSystems/RecTools/pull/274))
-
 ## [0.12.0] - 24.02.2025
 
 ### Added
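The headline entry is the first bullet: transformer models now expose which negative sampler to use and how to configure it. A minimal usage sketch, assuming an interactions DataFrame in the usual RecTools format (`interactions_df` below is illustrative, not part of the commit):

```python
from rectools.dataset import Dataset
from rectools.models import BERT4RecModel
from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler

# `interactions_df` with user_id / item_id / weight / datetime columns is assumed to exist.
dataset = Dataset.construct(interactions_df)

model = BERT4RecModel(
    loss="BCE",                                   # negatives are sampled only for BCE, gBCE and sampled_softmax
    n_negatives=5,                                # forwarded to the sampler by _init_negative_sampler()
    negative_sampler_type=CatalogUniformSampler,  # the new default; any TransformerNegativeSamplerBase subclass fits
    negative_sampler_kwargs=None,                 # extra JSON-serializable kwargs for the sampler, if it takes any
)
model.fit(dataset)
```

With `loss="softmax"` no negatives are needed, so the sampler is not instantiated at all (see `requires_negatives` in the diffs below).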

rectools/models/nn/transformers/base.py

Lines changed: 23 additions & 0 deletions

@@ -40,6 +40,7 @@
 )
 from .data_preparator import TransformerDataPreparatorBase
 from .lightning import TransformerLightningModule, TransformerLightningModuleBase
+from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
 from .net_blocks import (
     LearnableInversePositionalEncoding,
     PositionalEncodingBase,
@@ -128,6 +129,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
     ),
 ]
 
+TransformerNegativeSamplerType = tpe.Annotated[
+    tp.Type[TransformerNegativeSamplerBase],
+    BeforeValidator(_get_class_obj),
+    PlainSerializer(
+        func=get_class_or_function_full_path,
+        return_type=str,
+        when_used="json",
+    ),
+]
+
 
 ItemNetConstructorType = tpe.Annotated[
     tp.Type[ItemNetConstructorBase],
@@ -204,6 +215,7 @@ class TransformerModelConfig(ModelConfig):
     pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding
     transformer_layers_type: TransformerLayersType = PreLNTransformerLayers
     lightning_module_type: TransformerLightningModuleType = TransformerLightningModule
+    negative_sampler_type: TransformerNegativeSamplerType = CatalogUniformSampler
     similarity_module_type: SimilarityModuleType = DistanceSimilarityModule
     backbone_type: TransformerBackboneType = TransformerTorchBackbone
     get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
@@ -262,6 +274,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
         pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
+        negative_sampler_type: tp.Type[TransformerNegativeSamplerBase] = CatalogUniformSampler,
         similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
         backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
@@ -271,6 +284,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
         pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
         lightning_module_kwargs: tp.Optional[InitKwargs] = None,
+        negative_sampler_kwargs: tp.Optional[InitKwargs] = None,
         similarity_module_kwargs: tp.Optional[InitKwargs] = None,
         backbone_kwargs: tp.Optional[InitKwargs] = None,
         **kwargs: tp.Any,
@@ -302,6 +316,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         self.item_net_constructor_type = item_net_constructor_type
         self.pos_encoding_type = pos_encoding_type
         self.lightning_module_type = lightning_module_type
+        self.negative_sampler_type = negative_sampler_type
         self.backbone_type = backbone_type
         self.get_val_mask_func = get_val_mask_func
         self.get_trainer_func = get_trainer_func
@@ -310,6 +325,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         self.item_net_constructor_kwargs = item_net_constructor_kwargs
         self.pos_encoding_kwargs = pos_encoding_kwargs
         self.lightning_module_kwargs = lightning_module_kwargs
+        self.negative_sampler_kwargs = negative_sampler_kwargs
         self.similarity_module_kwargs = similarity_module_kwargs
         self.backbone_kwargs = backbone_kwargs
 
@@ -334,6 +350,7 @@ def _init_data_preparator(self) -> None:
             batch_size=self.batch_size,
             dataloader_num_workers=self.dataloader_num_workers,
             train_min_user_interactions=self.train_min_user_interactions,
+            negative_sampler=self._init_negative_sampler() if requires_negatives else None,
             n_negatives=self.n_negatives if requires_negatives else None,
             get_val_mask_func=self.get_val_mask_func,
             shuffle_train=True,
@@ -355,6 +372,12 @@ def _init_trainer(self) -> None:
         else:
             self._trainer = self.get_trainer_func()
 
+    def _init_negative_sampler(self) -> TransformerNegativeSamplerBase:
+        return self.negative_sampler_type(
+            n_negatives=self.n_negatives,
+            **self._get_kwargs(self.negative_sampler_kwargs),
+        )
+
     def _construct_item_net(self, dataset: Dataset) -> ItemNetBase:
         return self.item_net_constructor_type.from_dataset(
             dataset,
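The new `TransformerNegativeSamplerType` annotation mirrors the existing `*Type` fields: `BeforeValidator(_get_class_obj)` accepts either the class object or its import-path string in a config, and `PlainSerializer(get_class_or_function_full_path, when_used="json")` writes it back out as that string. A hedged round-trip sketch, assuming the usual RecTools `from_config` / `get_config(simple_types=True)` interface:

```python
from rectools.models import BERT4RecModel

# The sampler class can be referenced by its import path in a plain-dict config.
config = {
    "loss": "gBCE",
    "n_negatives": 10,
    "negative_sampler_type": "rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler",
}
model = BERT4RecModel.from_config(config)

# Serializing back to simple types should yield the same path string.
print(model.get_config(simple_types=True)["negative_sampler_type"])
```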

rectools/models/nn/transformers/bert4rec.py

Lines changed: 47 additions & 17 deletions

@@ -38,6 +38,7 @@
 )
 from .constants import MASKING_VALUE, PADDING_VALUE
 from .data_preparator import TransformerDataPreparatorBase
+from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
 from .net_blocks import (
     LearnableInversePositionalEncoding,
     PositionalEncodingBase,
@@ -49,7 +50,29 @@
 
 
 class BERT4RecDataPreparator(TransformerDataPreparatorBase):
-    """Data Preparator for BERT4RecModel."""
+    """Data Preparator for BERT4RecModel.
+
+    Parameters
+    ----------
+    session_max_len : int
+        Maximum length of user sequence.
+    batch_size : int
+        How many samples per batch to load.
+    dataloader_num_workers : int
+        Number of loader worker processes.
+    shuffle_train : bool, default True
+        If ``True``, reshuffles data at each epoch.
+    train_min_user_interactions : int, default 2
+        Minimum length of user sequence. Cannot be less than 2.
+    get_val_mask_func : Callable, default None
+        Function to get validation mask.
+    n_negatives : optional(int), default ``None``
+        Number of negatives for BCE, gBCE and sampled_softmax losses.
+    negative_sampler: optional(TransformerNegativeSamplerBase), default ``None``
+        Negative sampler.
+    mask_prob : float, default 0.15
+        Probability of masking an item in interactions sequence.
+    """
 
     train_session_max_len_addition: int = 0
     item_extra_tokens: tp.Sequence[Hashable] = (PADDING_VALUE, MASKING_VALUE)
@@ -61,6 +84,7 @@ def __init__(
         batch_size: int,
         dataloader_num_workers: int,
         train_min_user_interactions: int,
+        negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
         mask_prob: float = 0.15,
         shuffle_train: bool = True,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
@@ -69,6 +93,7 @@
         super().__init__(
             session_max_len=session_max_len,
             n_negatives=n_negatives,
+            negative_sampler=negative_sampler,
             batch_size=batch_size,
             dataloader_num_workers=dataloader_num_workers,
             train_min_user_interactions=train_min_user_interactions,
@@ -119,13 +144,10 @@ def _collate_fn_train(
             yw[i, -len(ses) :] = ses_weights  # ses_weights: [session_len] -> yw[i]: [session_max_len]
 
         batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)}
-        if self.n_negatives is not None:
-            negatives = torch.randint(
-                low=self.n_item_extra_tokens,
-                high=self.item_id_map.size,
-                size=(batch_size, self.session_max_len, self.n_negatives),
-            )  # [batch_size, session_max_len, n_negatives]
-            batch_dict["negatives"] = negatives
+        if self.negative_sampler is not None:
+            batch_dict["negatives"] = self.negative_sampler.get_negatives(
+                batch_dict, lowest_id=self.n_item_extra_tokens, highest_id=self.item_id_map.size
+            )
         return batch_dict
 
     def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
@@ -147,13 +169,10 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st
             yw[i, -1:] = ses_weights[target_idx]  # yw[i]: [1]
 
         batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)}
-        if self.n_negatives is not None:
-            negatives = torch.randint(
-                low=self.n_item_extra_tokens,
-                high=self.item_id_map.size,
-                size=(batch_size, 1, self.n_negatives),
-            )  # [batch_size, 1, n_negatives]
-            batch_dict["negatives"] = negatives
+        if self.negative_sampler is not None:
+            batch_dict["negatives"] = self.negative_sampler.get_negatives(
+                batch_dict, lowest_id=self.n_item_extra_tokens, highest_id=self.item_id_map.size, session_len_limit=1
+            )
         return batch_dict
 
     def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
@@ -213,7 +232,7 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
     loss : {"softmax", "BCE", "gBCE", "sampled_softmax"}, default "softmax"
         Loss function.
     n_negatives : int, default 1
-        Number of negatives for BCE and gBCE losses.
+        Number of negatives for BCE, gBCE and sampled_softmax losses.
     gbce_t : float, default 0.2
         Calibration parameter for gBCE loss.
     lr : float, default 0.001
@@ -258,6 +277,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
        Type of data preparator used for dataset processing and dataloader creation.
     lightning_module_type : type(TransformerLightningModuleBase), default `TransformerLightningModule`
         Type of lightning module defining training procedure.
+    negative_sampler_type: type(TransformerNegativeSamplerBase), default `CatalogUniformSampler`
+        Type of negative sampler.
     similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule`
         Type of similarity module.
     backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone`
@@ -295,6 +316,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
     lightning_module_kwargs: optional(dict), default ``None``
         Additional keyword arguments to pass during `lightning_module_type` initialization.
         Make sure all dict values have JSON serializable types.
+    negative_sampler_kwargs: optional(dict), default ``None``
+        Additional keyword arguments to pass during `negative_sampler_type` initialization.
+        Make sure all dict values have JSON serializable types.
     similarity_module_kwargs: optional(dict), default ``None``
         Additional keyword arguments to pass during `similarity_module_type` initialization.
         Make sure all dict values have JSON serializable types.
@@ -332,6 +356,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers,
         data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
+        negative_sampler_type: tp.Type[TransformerNegativeSamplerBase] = CatalogUniformSampler,
         similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
         backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
@@ -346,6 +371,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
         pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
         lightning_module_kwargs: tp.Optional[InitKwargs] = None,
+        negative_sampler_kwargs: tp.Optional[InitKwargs] = None,
         similarity_module_kwargs: tp.Optional[InitKwargs] = None,
         backbone_kwargs: tp.Optional[InitKwargs] = None,
     ):
@@ -381,6 +407,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
            item_net_constructor_type=item_net_constructor_type,
             pos_encoding_type=pos_encoding_type,
             lightning_module_type=lightning_module_type,
+            negative_sampler_type=negative_sampler_type,
             backbone_type=backbone_type,
             get_val_mask_func=get_val_mask_func,
             get_trainer_func=get_trainer_func,
@@ -390,14 +417,17 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
             item_net_constructor_kwargs=item_net_constructor_kwargs,
             pos_encoding_kwargs=pos_encoding_kwargs,
             lightning_module_kwargs=lightning_module_kwargs,
+            negative_sampler_kwargs=negative_sampler_kwargs,
             similarity_module_kwargs=similarity_module_kwargs,
             backbone_kwargs=backbone_kwargs,
         )
 
     def _init_data_preparator(self) -> None:
+        requires_negatives = self.lightning_module_type.requires_negatives(self.loss)
         self.data_preparator: TransformerDataPreparatorBase = self.data_preparator_type(
             session_max_len=self.session_max_len,
-            n_negatives=self.n_negatives if self.loss != "softmax" else None,
+            n_negatives=self.n_negatives if requires_negatives else None,
+            negative_sampler=self._init_negative_sampler() if requires_negatives else None,
             batch_size=self.batch_size,
             dataloader_num_workers=self.dataloader_num_workers,
             train_min_user_interactions=self.train_min_user_interactions,
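Both collate functions now delegate negative generation to `self.negative_sampler.get_negatives(...)`, which is what makes the strategy swappable per model. A sketch of a custom sampler follows; the constructor argument and the `get_negatives` signature are inferred from `_init_negative_sampler` and from the call sites in this file (the base class itself lives in `negative_sampler.py`, which is not shown in this diff), so treat the interface details as assumptions:

```python
import typing as tp

import torch

from rectools.models.nn.transformers.negative_sampler import TransformerNegativeSamplerBase


class SeededUniformSampler(TransformerNegativeSamplerBase):
    """Hypothetical sampler: uniform over the catalog, with a fixed generator for reproducible negatives."""

    def __init__(self, n_negatives: int, seed: int = 32, **kwargs: tp.Any) -> None:
        super().__init__(n_negatives=n_negatives, **kwargs)  # base __init__ signature is assumed
        self.n_negatives = n_negatives
        self.generator = torch.Generator().manual_seed(seed)

    def get_negatives(
        self,
        batch_dict: tp.Dict[str, torch.Tensor],
        lowest_id: int,
        highest_id: int,
        session_len_limit: tp.Optional[int] = None,
        **kwargs: tp.Any,
    ) -> torch.Tensor:
        batch_size, session_len = batch_dict["x"].shape  # "x" holds the padded item-id sequences
        if session_len_limit is not None:  # the validation collate passes session_len_limit=1
            session_len = session_len_limit
        # Same contract as the removed inline torch.randint: [batch_size, session_len, n_negatives]
        return torch.randint(
            low=lowest_id,
            high=highest_id,
            size=(batch_size, session_len, self.n_negatives),
            generator=self.generator,
        )
```

Passing `negative_sampler_type=SeededUniformSampler, negative_sampler_kwargs={"seed": 7}` to the model would then route the extra kwarg through `_init_negative_sampler`.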

rectools/models/nn/transformers/data_preparator.py

Lines changed: 9 additions & 1 deletion

@@ -29,6 +29,7 @@
 from rectools.dataset.identifiers import IdMap
 
 from .constants import PADDING_VALUE
+from .negative_sampler import TransformerNegativeSamplerBase
 
 
 class SequenceDataset(TorchDataset):
@@ -104,6 +105,10 @@ class TransformerDataPreparatorBase:
         Minimum length of user sequence. Cannot be less than 2.
     get_val_mask_func : Callable, default None
        Function to get validation mask.
+    n_negatives : optional(int), default ``None``
+        Number of negatives for BCE, gBCE and sampled_softmax losses.
+    negative_sampler: optional(TransformerNegativeSamplerBase), default ``None``
+        Negative sampler.
     """
 
     # We sometimes need data preparators to add +1 to actual session_max_len
@@ -119,15 +124,17 @@ def __init__(
         dataloader_num_workers: int,
         shuffle_train: bool = True,
         train_min_user_interactions: int = 2,
-        n_negatives: tp.Optional[int] = None,
         get_val_mask_func: tp.Optional[tp.Callable] = None,
+        n_negatives: tp.Optional[int] = None,
+        negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
         **kwargs: tp.Any,
     ) -> None:
         self.item_id_map: IdMap
         self.extra_token_ids: tp.Dict
         self.train_dataset: Dataset
         self.val_interactions: tp.Optional[pd.DataFrame] = None
         self.session_max_len = session_max_len
+        self.negative_sampler = negative_sampler
         self.n_negatives = n_negatives
         self.batch_size = batch_size
         self.dataloader_num_workers = dataloader_num_workers
@@ -189,6 +196,7 @@ def process_dataset_train(self, dataset: Dataset) -> None:
         if self.get_val_mask_func is not None:
             val_mask = self.get_val_mask_func(raw_interactions)
             interactions = raw_interactions[~val_mask]
+            interactions.reset_index(drop=True, inplace=True)
 
         # Filter train interactions
         interactions = self._filter_train_interactions(interactions)
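The added `reset_index(drop=True)` matters because `get_val_mask_func` returns a boolean mask over the raw interactions, and `raw_interactions[~val_mask]` keeps the original, now gappy, index. A hedged illustration of such a mask function (the helper is illustrative and not part of the library):

```python
import pandas as pd

from rectools import Columns


def leave_last_out_mask(interactions: pd.DataFrame) -> pd.Series:
    """Mark each user's chronologically last interaction as validation."""
    last_idx = interactions.sort_values(Columns.Datetime).groupby(Columns.User).tail(1).index
    return pd.Series(interactions.index.isin(last_idx), index=interactions.index)
```

Filtering with the inverse of such a mask removes scattered rows, so without the reset the train interactions would carry non-contiguous positional indices downstream.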

rectools/models/nn/transformers/lightning.py

Lines changed: 43 additions & 3 deletions

@@ -40,8 +40,18 @@ class TransformerLightningModuleBase(LightningModule): # pylint: disable=too-ma
     ----------
     torch_model : TransformerBackboneBase
         Torch model to make recommendations.
+    model_config: Dict[str, Any]
+        Model config.
+    dataset_schema: DatasetSchemaDict
+        Dataset schema.
+    item_external_ids: ExternalIds
+        External item ids from train dataset.
+    item_extra_tokens : Sequence(Hashable)
+        Elements used for sequence padding.
     lr : float
         Learning rate.
+    gbce_t : float
+        Calibration parameter for gBCE loss.
     loss : str, default "softmax"
         Loss function.
     adam_betas : Tuple[float, float], default (0.9, 0.98)
@@ -240,7 +250,37 @@ def _recommend_i2i(
 
 
 class TransformerLightningModule(TransformerLightningModuleBase):
-    """Lightning module to train transformer models."""
+    """Lightning module to train transformer models.
+
+    Parameters
+    ----------
+    torch_model : TransformerBackboneBase
+        Torch model to make recommendations.
+    model_config: Dict[str, Any]
+        Model config.
+    dataset_schema: DatasetSchemaDict
+        Dataset schema.
+    item_external_ids: ExternalIds
+        External item ids from train dataset.
+    item_extra_tokens : Sequence(Hashable)
+        Elements used for sequence padding.
+    lr : float
+        Learning rate.
+    gbce_t : float
+        Calibration parameter for gBCE loss.
+    loss : str, default "softmax"
+        Loss function.
+    adam_betas : Tuple[float, float], default (0.9, 0.98)
+        Coefficients for running averages of gradient and its square.
+    data_preparator : TransformerDataPreparatorBase
+        Data preparator.
+    verbose : int, default 0
+        Verbosity level.
+    train_loss_name : str, default "train_loss"
+        Name of the training loss.
+    val_loss_name : str, default "val_loss"
+        Name of the validation loss.
+    """
 
     i2i_dist = Distance.COSINE
 
@@ -296,7 +336,7 @@ def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) ->
             type_logits = "pos_neg_logits" if self._requires_negatives else "logits"
             outputs = {
                 "loss": loss,
-                type_logits: logits,
+                type_logits: logits.squeeze(),
             }
         else:
             outputs = self._calc_custom_loss_outputs(batch, batch_idx)  # pragma: no cover
@@ -339,7 +379,7 @@ def _get_user_item_embeddings(
         for batch in recommend_dataloader:
             batch = {k: v.to(device) for k, v in batch.items()}
             batch_embs = self.torch_model.encode_sessions(batch, item_embs)[:, -1, :]
-            user_embs.append(batch_embs)
+            user_embs.append(batch_embs.cpu())
 
         return torch.cat(user_embs), item_embs
 
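The two one-line code changes in this file are small but deliberate: `logits.squeeze()` drops the singleton session dimension from validation logits before they reach downstream consumers, and `batch_embs.cpu()` moves each batch of user embeddings to host memory as soon as it is produced. A hedged sketch of the memory effect of the second change (sizes and device are illustrative):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

user_embs = []
for _ in range(10):  # stand-in for iterating recommend_dataloader
    batch_embs = torch.randn(512, 128, device=device)  # stand-in for encode_sessions(...)[:, -1, :]
    user_embs.append(batch_embs.cpu())  # without .cpu(), every batch would stay resident on the accelerator

all_user_embs = torch.cat(user_embs)  # concatenation then happens in host memory
```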