@@ -38,6 +38,7 @@
 )
 from .constants import MASKING_VALUE, PADDING_VALUE
 from .data_preparator import TransformerDataPreparatorBase
+from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
 from .net_blocks import (
     LearnableInversePositionalEncoding,
     PositionalEncodingBase,
@@ -49,7 +50,29 @@
 
 
 class BERT4RecDataPreparator(TransformerDataPreparatorBase):
-    """Data Preparator for BERT4RecModel."""
+    """Data Preparator for BERT4RecModel.
+
+    Parameters
+    ----------
+    session_max_len : int
+        Maximum length of user sequence.
+    batch_size : int
+        How many samples per batch to load.
+    dataloader_num_workers : int
+        Number of loader worker processes.
+    shuffle_train : bool, default True
+        If ``True``, reshuffles data at each epoch.
+    train_min_user_interactions : int, default 2
+        Minimum length of user sequence. Cannot be less than 2.
+    get_val_mask_func : Callable, default None
+        Function to get validation mask.
+    n_negatives : optional(int), default ``None``
+        Number of negatives for BCE, gBCE and sampled_softmax losses.
+    negative_sampler : optional(TransformerNegativeSamplerBase), default ``None``
+        Negative sampler.
+    mask_prob : float, default 0.15
+        Probability of masking an item in interactions sequence.
+    """
 
     train_session_max_len_addition: int = 0
     item_extra_tokens: tp.Sequence[Hashable] = (PADDING_VALUE, MASKING_VALUE)
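As a quick orientation, a direct construction of the preparator under the new signature might look like the sketch below. All parameter names come from the docstring above; the values are illustrative, and in normal use `BERT4RecModel` builds the preparator internally (see `_init_data_preparator` in the last hunk).

```python
# Illustrative only: building the preparator by hand with the documented
# parameters. BERT4RecModel normally constructs this itself.
preparator = BERT4RecDataPreparator(
    session_max_len=32,
    batch_size=128,
    dataloader_num_workers=0,
    shuffle_train=True,
    train_min_user_interactions=2,
    get_val_mask_func=None,
    n_negatives=None,        # set together with a sampler for BCE / gBCE / sampled_softmax
    negative_sampler=None,   # e.g. a CatalogUniformSampler instance (new in this change)
    mask_prob=0.15,
)
```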
@@ -61,6 +84,7 @@ def __init__(
         batch_size: int,
         dataloader_num_workers: int,
         train_min_user_interactions: int,
+        negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
         mask_prob: float = 0.15,
         shuffle_train: bool = True,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
@@ -69,6 +93,7 @@
         super().__init__(
             session_max_len=session_max_len,
             n_negatives=n_negatives,
+            negative_sampler=negative_sampler,
             batch_size=batch_size,
             dataloader_num_workers=dataloader_num_workers,
             train_min_user_interactions=train_min_user_interactions,
@@ -119,13 +144,10 @@ def _collate_fn_train(
             yw[i, -len(ses) :] = ses_weights  # ses_weights: [session_len] -> yw[i]: [session_max_len]
 
         batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)}
-        if self.n_negatives is not None:
-            negatives = torch.randint(
-                low=self.n_item_extra_tokens,
-                high=self.item_id_map.size,
-                size=(batch_size, self.session_max_len, self.n_negatives),
-            )  # [batch_size, session_max_len, n_negatives]
-            batch_dict["negatives"] = negatives
+        if self.negative_sampler is not None:
+            batch_dict["negatives"] = self.negative_sampler.get_negatives(
+                batch_dict, lowest_id=self.n_item_extra_tokens, highest_id=self.item_id_map.size
+            )
         return batch_dict
 
     def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
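The call site above replaces the inline `torch.randint` sampling with `self.negative_sampler.get_negatives(...)`. A minimal sketch of a catalog-uniform sampler that satisfies this interface is shown below; it simply mirrors the removed inline logic, and the real `CatalogUniformSampler` in `.negative_sampler` may differ in details (the class name `SketchUniformSampler` and its constructor are hypothetical).

```python
import typing as tp

import torch


class SketchUniformSampler:
    """Hypothetical stand-in mirroring the torch.randint logic removed above."""

    def __init__(self, n_negatives: int) -> None:
        self.n_negatives = n_negatives

    def get_negatives(
        self,
        batch_dict: tp.Dict[str, torch.Tensor],
        lowest_id: int,
        highest_id: int,
        session_len_limit: tp.Optional[int] = None,
    ) -> torch.Tensor:
        # Sample item ids uniformly from the catalog, skipping the extra
        # tokens (PAD/MASK) that occupy ids below `lowest_id`.
        batch_size, session_max_len = batch_dict["x"].shape
        session_len = session_len_limit if session_len_limit is not None else session_max_len
        return torch.randint(
            low=lowest_id,
            high=highest_id,
            size=(batch_size, session_len, self.n_negatives),
        )  # [batch_size, session_len, n_negatives]
```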
@@ -147,13 +169,10 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
             yw[i, -1:] = ses_weights[target_idx]  # yw[i]: [1]
 
         batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)}
-        if self.n_negatives is not None:
-            negatives = torch.randint(
-                low=self.n_item_extra_tokens,
-                high=self.item_id_map.size,
-                size=(batch_size, 1, self.n_negatives),
-            )  # [batch_size, 1, n_negatives]
-            batch_dict["negatives"] = negatives
+        if self.negative_sampler is not None:
+            batch_dict["negatives"] = self.negative_sampler.get_negatives(
+                batch_dict, lowest_id=self.n_item_extra_tokens, highest_id=self.item_id_map.size, session_len_limit=1
+            )
         return batch_dict
 
     def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
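The only difference between the train and validation call sites is `session_len_limit=1`, which restricts sampling to the single target position. A quick shape check using the hypothetical sketch class from the previous snippet (`lowest_id=2` reflects the two extra tokens, PAD and MASK):

```python
import torch

sampler = SketchUniformSampler(n_negatives=4)
batch = {"x": torch.zeros(8, 32, dtype=torch.long)}  # batch_size=8, session_max_len=32

train_negs = sampler.get_negatives(batch, lowest_id=2, highest_id=100)
val_negs = sampler.get_negatives(batch, lowest_id=2, highest_id=100, session_len_limit=1)

assert train_negs.shape == (8, 32, 4)  # [batch_size, session_max_len, n_negatives]
assert val_negs.shape == (8, 1, 4)     # [batch_size, 1, n_negatives]
```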
@@ -213,7 +232,7 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
     loss : {"softmax", "BCE", "gBCE", "sampled_softmax"}, default "softmax"
         Loss function.
     n_negatives : int, default 1
-        Number of negatives for BCE and gBCE losses.
+        Number of negatives for BCE, gBCE and sampled_softmax losses.
     gbce_t : float, default 0.2
         Calibration parameter for gBCE loss.
     lr : float, default 0.001
@@ -258,6 +277,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
         Type of data preparator used for dataset processing and dataloader creation.
     lightning_module_type : type(TransformerLightningModuleBase), default `TransformerLightningModule`
         Type of lightning module defining training procedure.
+    negative_sampler_type : type(TransformerNegativeSamplerBase), default `CatalogUniformSampler`
+        Type of negative sampler.
     similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule`
         Type of similarity module.
     backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone`
@@ -295,6 +316,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
     lightning_module_kwargs: optional(dict), default ``None``
         Additional keyword arguments to pass during `lightning_module_type` initialization.
         Make sure all dict values have JSON serializable types.
+    negative_sampler_kwargs: optional(dict), default ``None``
+        Additional keyword arguments to pass during `negative_sampler_type` initialization.
+        Make sure all dict values have JSON serializable types.
     similarity_module_kwargs: optional(dict), default ``None``
         Additional keyword arguments to pass during `similarity_module_type` initialization.
         Make sure all dict values have JSON serializable types.
@@ -332,6 +356,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers,
         data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
+        negative_sampler_type: tp.Type[TransformerNegativeSamplerBase] = CatalogUniformSampler,
         similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
         backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
@@ -346,6 +371,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
         item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
         pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
         lightning_module_kwargs: tp.Optional[InitKwargs] = None,
+        negative_sampler_kwargs: tp.Optional[InitKwargs] = None,
         similarity_module_kwargs: tp.Optional[InitKwargs] = None,
         backbone_kwargs: tp.Optional[InitKwargs] = None,
     ):
@@ -381,6 +407,7 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
             item_net_constructor_type=item_net_constructor_type,
             pos_encoding_type=pos_encoding_type,
             lightning_module_type=lightning_module_type,
+            negative_sampler_type=negative_sampler_type,
             backbone_type=backbone_type,
             get_val_mask_func=get_val_mask_func,
             get_trainer_func=get_trainer_func,
@@ -390,14 +417,17 @@ def __init__(  # pylint: disable=too-many-arguments, too-many-locals
             item_net_constructor_kwargs=item_net_constructor_kwargs,
             pos_encoding_kwargs=pos_encoding_kwargs,
             lightning_module_kwargs=lightning_module_kwargs,
+            negative_sampler_kwargs=negative_sampler_kwargs,
             similarity_module_kwargs=similarity_module_kwargs,
             backbone_kwargs=backbone_kwargs,
         )
 
     def _init_data_preparator(self) -> None:
+        requires_negatives = self.lightning_module_type.requires_negatives(self.loss)
         self.data_preparator: TransformerDataPreparatorBase = self.data_preparator_type(
             session_max_len=self.session_max_len,
-            n_negatives=self.n_negatives if self.loss != "softmax" else None,
+            n_negatives=self.n_negatives if requires_negatives else None,
+            negative_sampler=self._init_negative_sampler() if requires_negatives else None,
             batch_size=self.batch_size,
             dataloader_num_workers=self.dataloader_num_workers,
             train_min_user_interactions=self.train_min_user_interactions,
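At the model level the new plumbing is opt-in by loss choice: `requires_negatives` gates both `n_negatives` and the sampler, so nothing changes for the default `"softmax"` loss. A hedged usage sketch follows, assuming `BERT4RecModel` and `CatalogUniformSampler` are importable from the package paths implied by this file's relative imports:

```python
from rectools.models import BERT4RecModel
from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler

# gBCE needs negatives, so a sampler is built and handed to the data preparator.
model = BERT4RecModel(
    loss="gBCE",
    n_negatives=256,
    negative_sampler_type=CatalogUniformSampler,  # the documented default
    negative_sampler_kwargs=None,  # extra kwargs must stay JSON serializable
)

# With the default loss="softmax", requires_negatives is False: the data
# preparator receives n_negatives=None and negative_sampler=None.
```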