
Commit 79dc992

Feature/trainer kwargs (#280)
Added `get_val_mask_func_kwargs` and `get_trainer_func_kwargs` arguments
1 parent a6ec1c1 commit 79dc992

File tree: 7 files changed, +187 -15 lines


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Python 3.13 support ([#227](https://github.com/MobileTeleSystems/RecTools/pull/227))
 - `fit_partial` implementation for transformer-based models ([#273](https://github.com/MobileTeleSystems/RecTools/pull/273))
 - `map_location` and `model_params_update` arguments for the function `load_from_checkpoint` for Transformer-based models. Use `map_location` to explicitly specify the new computing device and `model_params_update` to update original model parameters (e.g. remove training-specific parameters that are no longer needed) ([#281](https://github.com/MobileTeleSystems/RecTools/pull/281))
+- `get_val_mask_func_kwargs` and `get_trainer_func_kwargs` arguments for Transformer-based models to allow keyword arguments in custom functions used for model training ([#280](https://github.com/MobileTeleSystems/RecTools/pull/280))

 ## [0.13.0] - 10.04.2025


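For context, a minimal usage sketch of the two new arguments. Everything below is illustrative, not part of this commit: the function names, kwargs keys and values are hypothetical, and the sketch assumes the `pytorch_lightning.Trainer` and the `SASRecModel` / `Columns` imports provided by RecTools.

```python
# Illustrative only: names and values below are not part of the commit.
import numpy as np
import pandas as pd
from pytorch_lightning import Trainer

from rectools import Columns
from rectools.models import SASRecModel


def val_mask_by_days(interactions: pd.DataFrame, n_days: int = 1) -> np.ndarray:
    """Mark interactions from the last `n_days` as validation targets."""
    cutoff = interactions[Columns.Datetime].max() - pd.Timedelta(days=n_days)
    return (interactions[Columns.Datetime] > cutoff).to_numpy()


def quick_trainer(max_epochs: int = 3, accelerator: str = "cpu") -> Trainer:
    """Build a small Lightning trainer for quick experiments."""
    return Trainer(max_epochs=max_epochs, accelerator=accelerator, devices=1)


model = SASRecModel(
    get_val_mask_func=val_mask_by_days,
    get_val_mask_func_kwargs={"n_days": 7},      # forwarded to val_mask_by_days
    get_trainer_func=quick_trainer,
    get_trainer_func_kwargs={"max_epochs": 10},  # forwarded to quick_trainer
)
```
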
rectools/models/nn/transformers/base.py

Lines changed: 11 additions & 6 deletions
@@ -38,7 +38,7 @@
     ItemNetConstructorBase,
     SumOfEmbeddingsConstructor,
 )
-from .data_preparator import TransformerDataPreparatorBase
+from .data_preparator import InitKwargs, TransformerDataPreparatorBase
 from .lightning import TransformerLightningModule, TransformerLightningModuleBase
 from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
 from .net_blocks import (
@@ -50,8 +50,6 @@
 from .similarity import DistanceSimilarityModule, SimilarityModuleBase
 from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone

-InitKwargs = tp.Dict[str, tp.Any]
-
 # #### -------------- Transformer Model Config -------------- #### #


@@ -161,7 +159,7 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
 ]


-ValMaskCallable = Callable[[], np.ndarray]
+ValMaskCallable = Callable[..., np.ndarray]

 ValMaskCallableSerialized = tpe.Annotated[
     ValMaskCallable,
@@ -173,7 +171,7 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
     ),
 ]

-TrainerCallable = Callable[[], Trainer]
+TrainerCallable = Callable[..., Trainer]

 TrainerCallableSerialized = tpe.Annotated[
     TrainerCallable,
@@ -220,6 +218,8 @@ class TransformerModelConfig(ModelConfig):
     backbone_type: TransformerBackboneType = TransformerTorchBackbone
     get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
     get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
+    get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None
+    get_trainer_func_kwargs: tp.Optional[InitKwargs] = None
     data_preparator_kwargs: tp.Optional[InitKwargs] = None
     transformer_layers_kwargs: tp.Optional[InitKwargs] = None
     item_net_constructor_kwargs: tp.Optional[InitKwargs] = None
@@ -280,6 +280,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         get_trainer_func: tp.Optional[TrainerCallable] = None,
+        get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
+        get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
         data_preparator_kwargs: tp.Optional[InitKwargs] = None,
         transformer_layers_kwargs: tp.Optional[InitKwargs] = None,
         item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
@@ -321,6 +323,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         self.backbone_type = backbone_type
         self.get_val_mask_func = get_val_mask_func
         self.get_trainer_func = get_trainer_func
+        self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
+        self.get_trainer_func_kwargs = get_trainer_func_kwargs
         self.data_preparator_kwargs = data_preparator_kwargs
         self.transformer_layers_kwargs = transformer_layers_kwargs
         self.item_net_constructor_kwargs = item_net_constructor_kwargs
@@ -354,6 +358,7 @@ def _init_data_preparator(self) -> None:
             negative_sampler=self._init_negative_sampler() if requires_negatives else None,
             n_negatives=self.n_negatives if requires_negatives else None,
             get_val_mask_func=self.get_val_mask_func,
+            get_val_mask_func_kwargs=self.get_val_mask_func_kwargs,
             **self._get_kwargs(self.data_preparator_kwargs),
         )

@@ -370,7 +375,7 @@ def _init_trainer(self) -> None:
                 devices=1,
             )
         else:
-            self._trainer = self.get_trainer_func()
+            self._trainer = self.get_trainer_func(**self._get_kwargs(self.get_trainer_func_kwargs))

     def _init_negative_sampler(self) -> TransformerNegativeSamplerBase:
         return self.negative_sampler_type(

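`_init_trainer` now forwards `get_trainer_func_kwargs` into the user-supplied factory, so a custom `get_trainer_func` can simply accept keyword arguments. A hedged sketch: the callback choice, parameter names and the "val_loss" monitor key below are assumptions, not taken from the commit.

```python
# Illustrative trainer factory; callback choice and the "val_loss" monitor key are assumptions.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping


def build_trainer(max_epochs: int = 100, patience: int = 3, deterministic: bool = True) -> Trainer:
    return Trainer(
        max_epochs=max_epochs,
        deterministic=deterministic,
        callbacks=[EarlyStopping(monitor="val_loss", patience=patience)],
    )


# With get_trainer_func=build_trainer and get_trainer_func_kwargs={"max_epochs": 20, "patience": 5},
# _init_trainer effectively runs: self._trainer = build_trainer(max_epochs=20, patience=5)
```
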
rectools/models/nn/transformers/bert4rec.py

Lines changed: 17 additions & 2 deletions
@@ -27,7 +27,6 @@
     SumOfEmbeddingsConstructor,
 )
 from .base import (
-    InitKwargs,
     TrainerCallable,
     TransformerDataPreparatorType,
     TransformerLightningModule,
@@ -37,7 +36,7 @@
     ValMaskCallable,
 )
 from .constants import MASKING_VALUE, PADDING_VALUE
-from .data_preparator import TransformerDataPreparatorBase
+from .data_preparator import InitKwargs, TransformerDataPreparatorBase
 from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
 from .net_blocks import (
     LearnableInversePositionalEncoding,
@@ -72,6 +71,9 @@ class BERT4RecDataPreparator(TransformerDataPreparatorBase):
         Negative sampler.
     mask_prob : float, default 0.15
         Probability of masking an item in interactions sequence.
+    get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
+        Additional keyword arguments for the get_val_mask_func.
+        Make sure all dict values have JSON serializable types.
     """

     train_session_max_len_addition: int = 0
@@ -88,6 +90,7 @@ def __init__(
         mask_prob: float = 0.15,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         shuffle_train: bool = True,
+        get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
         **kwargs: tp.Any,
     ) -> None:
         super().__init__(
@@ -99,6 +102,7 @@ def __init__(
             train_min_user_interactions=train_min_user_interactions,
             shuffle_train=shuffle_train,
             get_val_mask_func=get_val_mask_func,
+            get_val_mask_func_kwargs=get_val_mask_func_kwargs,
         )
         self.mask_prob = mask_prob

@@ -301,6 +305,12 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
        When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
        If you want to change this parameter after model is initialized,
        you can manually assign new value to model `recommend_torch_device` attribute.
+    get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
+        Additional keyword arguments for the get_val_mask_func.
+        Make sure all dict values have JSON serializable types.
+    get_trainer_func_kwargs: optional(InitKwargs), default ``None``
+        Additional keyword arguments for the get_trainer_func.
+        Make sure all dict values have JSON serializable types.
     data_preparator_kwargs: optional(dict), default ``None``
        Additional keyword arguments to pass during `data_preparator_type` initialization.
        Make sure all dict values have JSON serializable types.
@@ -361,6 +371,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         get_trainer_func: tp.Optional[TrainerCallable] = None,
+        get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
+        get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
         recommend_batch_size: int = 256,
         recommend_torch_device: tp.Optional[str] = None,
         recommend_use_torch_ranking: bool = True,
@@ -411,6 +423,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
             backbone_type=backbone_type,
             get_val_mask_func=get_val_mask_func,
             get_trainer_func=get_trainer_func,
+            get_val_mask_func_kwargs=get_val_mask_func_kwargs,
+            get_trainer_func_kwargs=get_trainer_func_kwargs,
             data_preparator_kwargs=data_preparator_kwargs,
             transformer_layers_kwargs=transformer_layers_kwargs,
             item_net_block_kwargs=item_net_block_kwargs,
@@ -433,5 +447,6 @@ def _init_data_preparator(self) -> None:
             train_min_user_interactions=self.train_min_user_interactions,
             mask_prob=self.mask_prob,
             get_val_mask_func=self.get_val_mask_func,
+            get_val_mask_func_kwargs=self.get_val_mask_func_kwargs,
             **self._get_kwargs(self.data_preparator_kwargs),
         )

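The new docstrings require every value in these kwargs dicts to be JSON serializable, since the dicts become part of the model config. A small illustration; the specific keys and values are hypothetical.

```python
import pandas as pd

# Fine: plain str / int / float / bool / None values keep the model config serializable.
ok_kwargs = {"n_days": 7, "use_deterministic": True}

# Risky: complex objects (timestamps, arrays, callables) are not JSON serializable as-is.
not_ok_kwargs = {"cutoff": pd.Timestamp("2024-01-01")}
```
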
rectools/models/nn/transformers/data_preparator.py

Lines changed: 18 additions & 2 deletions
@@ -31,6 +31,8 @@
 from .constants import PADDING_VALUE
 from .negative_sampler import TransformerNegativeSamplerBase

+InitKwargs = tp.Dict[str, tp.Any]
+

 class SequenceDataset(TorchDataset):
     """
@@ -84,7 +86,7 @@ def from_interactions(
         return cls(sessions=sessions, weights=weights)


-class TransformerDataPreparatorBase:
+class TransformerDataPreparatorBase:  # pylint: disable=too-many-instance-attributes
     """
     Base class for data preparator. To change train/recommend dataset processing, train/recommend dataloaders inherit
     from this class and pass your custom data preparator to your model parameters.
@@ -109,6 +111,9 @@ class TransformerDataPreparatorBase:
         Number of negatives for BCE, gBCE and sampled_softmax losses.
     negative_sampler: optional(TransformerNegativeSamplerBase), default ``None``
         Negative sampler.
+    get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
+        Additional keyword arguments for the get_val_mask_func.
+        Make sure all dict values have JSON serializable types.
     """

     # We sometimes need data preparators to add +1 to actual session_max_len
@@ -127,6 +132,7 @@ def __init__(
         shuffle_train: bool = True,
         n_negatives: tp.Optional[int] = None,
         negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
+        get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
         **kwargs: tp.Any,
     ) -> None:
         self.item_id_map: IdMap
@@ -141,6 +147,7 @@ def __init__(
         self.train_min_user_interactions = train_min_user_interactions
         self.shuffle_train = shuffle_train
         self.get_val_mask_func = get_val_mask_func
+        self.get_val_mask_func_kwargs = get_val_mask_func_kwargs

     def get_known_items_sorted_internal_ids(self) -> np.ndarray:
         """Return internal item ids from processed dataset in sorted order."""
@@ -150,6 +157,13 @@ def get_known_item_ids(self) -> np.ndarray:
         """Return external item ids from processed dataset in sorted order."""
         return self.item_id_map.get_external_sorted_by_internal()[self.n_item_extra_tokens :]

+    @staticmethod
+    def _ensure_kwargs_dict(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs:
+        kwargs = {}
+        if actual_kwargs is not None:
+            kwargs = actual_kwargs
+        return kwargs
+
     @property
     def n_item_extra_tokens(self) -> int:
         """Return number of padding elements"""
@@ -194,7 +208,9 @@ def process_dataset_train(self, dataset: Dataset) -> None:
         # Exclude val interaction targets from train if needed
         interactions = raw_interactions
         if self.get_val_mask_func is not None:
-            val_mask = self.get_val_mask_func(raw_interactions)
+            val_mask = self.get_val_mask_func(
+                raw_interactions, **self._ensure_kwargs_dict(self.get_val_mask_func_kwargs)
+            )
             interactions = raw_interactions[~val_mask]
             interactions.reset_index(drop=True, inplace=True)


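`process_dataset_train` now calls `get_val_mask_func(raw_interactions, **kwargs)` and expects a boolean mask aligned with the interaction rows. A hedged sketch of such a function; the leave-last-n-per-user logic and the `n_last` parameter name are illustrative, not from the commit.

```python
# Illustrative validation mask: flag the last `n_last` interactions of every user.
import numpy as np
import pandas as pd

from rectools import Columns


def leave_last_n_mask(interactions: pd.DataFrame, n_last: int = 1) -> np.ndarray:
    # Rank interactions within each user from newest (rank 1) to oldest.
    ranks = interactions.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=False)
    return (ranks <= n_last).to_numpy()


# The data preparator would then effectively call:
# val_mask = leave_last_n_mask(raw_interactions, **{"n_last": 2})
```
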
rectools/models/nn/transformers/sasrec.py

Lines changed: 14 additions & 2 deletions
@@ -27,7 +27,6 @@
     SumOfEmbeddingsConstructor,
 )
 from .base import (
-    InitKwargs,
     TrainerCallable,
     TransformerDataPreparatorType,
     TransformerLayersType,
@@ -37,7 +36,7 @@
     TransformerModelConfig,
     ValMaskCallable,
 )
-from .data_preparator import TransformerDataPreparatorBase
+from .data_preparator import InitKwargs, TransformerDataPreparatorBase
 from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
 from .net_blocks import (
     LearnableInversePositionalEncoding,
@@ -72,6 +71,9 @@ class SASRecDataPreparator(TransformerDataPreparatorBase):
         Number of negatives for BCE, gBCE and sampled_softmax losses.
     negative_sampler: optional(TransformerNegativeSamplerBase), default ``None``
         Negative sampler.
+    get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
+        Additional keyword arguments for the get_val_mask_func.
+        Make sure all dict values have JSON serializable types.
     """

     train_session_max_len_addition: int = 1
@@ -379,6 +381,12 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
        When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
        If you want to change this parameter after model is initialized,
        you can manually assign new value to model `recommend_torch_device` attribute.
+    get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
+        Additional keyword arguments for the get_val_mask_func.
+        Make sure all dict values have JSON serializable types.
+    get_trainer_func_kwargs: optional(InitKwargs), default ``None``
+        Additional keyword arguments for the get_trainer_func.
+        Make sure all dict values have JSON serializable types.
     data_preparator_kwargs: optional(dict), default ``None``
        Additional keyword arguments to pass during `data_preparator_type` initialization.
        Make sure all dict values have JSON serializable types.
@@ -438,6 +446,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         get_trainer_func: tp.Optional[TrainerCallable] = None,
+        get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
+        get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
         recommend_batch_size: int = 256,
         recommend_torch_device: tp.Optional[str] = None,
         recommend_use_torch_ranking: bool = True,
@@ -485,6 +495,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
             backbone_type=backbone_type,
             get_val_mask_func=get_val_mask_func,
             get_trainer_func=get_trainer_func,
+            get_val_mask_func_kwargs=get_val_mask_func_kwargs,
+            get_trainer_func_kwargs=get_trainer_func_kwargs,
             data_preparator_kwargs=data_preparator_kwargs,
             transformer_layers_kwargs=transformer_layers_kwargs,
             item_net_constructor_kwargs=item_net_constructor_kwargs,
