Commit 69f6736

Feature/transformer_backbone (#277)
Added `backbone_type` and `backbone_kwargs` parameters to transformer-based models
1 parent 6b7fda0 commit 69f6736
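
A minimal usage sketch of the new parameters (the parameter names and their defaults come from the diffs below; the model choice and the empty `backbone_kwargs` are illustrative only):

```python
# Hedged sketch: pass the new `backbone_type` / `backbone_kwargs` arguments to a
# transformer-based model. `TransformerTorchBackbone` is the default backbone;
# any `TransformerBackboneBase` subclass should be accepted.
from rectools.models import SASRecModel
from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone

model = SASRecModel(
    backbone_type=TransformerTorchBackbone,  # default value, shown explicitly
    backbone_kwargs=None,                    # extra kwargs forwarded to the backbone constructor
)
```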

File tree

10 files changed: +219 -71 lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -8,8 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## Unreleased

 ### Added
-- `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
+- `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone` parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
 - `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276))
+- `TransformerBackboneBase`, `backbone_type` and `backbone_kwargs` parameters to transformer-based models ([#277](https://github.com/MobileTeleSystems/RecTools/pull/277))
 - `sampled_softmax` loss option for transformer models ([#274](https://github.com/MobileTeleSystems/RecTools/pull/274))

 ## [0.12.0] - 24.02.2025

examples/tutorials/transformers_advanced_training_guide.ipynb

Lines changed: 18 additions & 18 deletions
@@ -412,15 +412,15 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"epoch,step,train_loss,val_loss\r",
+"epoch,step,train_loss,val_loss\r\n",
 "\r\n",
-"0,1,,22.365339279174805\r",
+"0,1,,22.365339279174805\r\n",
 "\r\n",
-"0,1,22.38391876220703,\r",
+"0,1,22.38391876220703,\r\n",
 "\r\n",
-"1,3,,22.189851760864258\r",
+"1,3,,22.189851760864258\r\n",
 "\r\n",
-"1,3,22.898216247558594,\r",
+"1,3,22.898216247558594,\r\n",
 "\r\n"
 ]
 }
@@ -526,23 +526,23 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"epoch,step,train_loss,val_loss\r",
+"epoch,step,train_loss,val_loss\r\n",
 "\r\n",
-"0,1,,22.343637466430664\r",
+"0,1,,22.343637466430664\r\n",
 "\r\n",
-"0,1,22.36273765563965,\r",
+"0,1,22.36273765563965,\r\n",
 "\r\n",
-"1,3,,22.159835815429688\r",
+"1,3,,22.159835815429688\r\n",
 "\r\n",
-"1,3,22.33755874633789,\r",
+"1,3,22.33755874633789,\r\n",
 "\r\n",
-"2,5,,21.94308853149414\r",
+"2,5,,21.94308853149414\r\n",
 "\r\n",
-"2,5,22.244243621826172,\r",
+"2,5,22.244243621826172,\r\n",
 "\r\n",
-"3,7,,21.702259063720703\r",
+"3,7,,21.702259063720703\r\n",
 "\r\n",
-"3,7,22.196012496948242,\r",
+"3,7,22.196012496948242,\r\n",
 "\r\n"
 ]
 }
@@ -898,7 +898,7 @@
 "    ) -> None:\n",
 "        logits = outputs[\"logits\"]\n",
 "        if logits is None:\n",
-"            logits = pl_module.torch_model.encode_sessions(batch[\"x\"], pl_module.item_embs)[:, -1, :]\n",
+"            logits = pl_module.torch_model.encode_sessions(batch, pl_module.item_embs)[:, -1, :]\n",
 "        _, sorted_batch_recos = logits.topk(k=self.top_k)\n",
 "\n",
 "        batch_recos = sorted_batch_recos.tolist()\n",
@@ -2039,9 +2039,9 @@
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "rectools",
+"display_name": ".venv",
 "language": "python",
-"name": "rectools"
+"name": "python3"
 },
 "language_info": {
 "codemirror_mode": {
@@ -2053,7 +2053,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.12"
+"version": "3.10.13"
 }
 },
 "nbformat": 4,

rectools/models/nn/transformers/base.py

Lines changed: 22 additions & 5 deletions
@@ -47,7 +47,7 @@
     TransformerLayersBase,
 )
 from .similarity import DistanceSimilarityModule, SimilarityModuleBase
-from .torch_backbone import TransformerTorchBackbone
+from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone

 InitKwargs = tp.Dict[str, tp.Any]

@@ -108,6 +108,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
     ),
 ]

+TransformerBackboneType = tpe.Annotated[
+    tp.Type[TransformerBackboneBase],
+    BeforeValidator(_get_class_obj),
+    PlainSerializer(
+        func=get_class_or_function_full_path,
+        return_type=str,
+        when_used="json",
+    ),
+]
+
 TransformerDataPreparatorType = tpe.Annotated[
     tp.Type[TransformerDataPreparatorBase],
     BeforeValidator(_get_class_obj),
@@ -195,6 +205,7 @@ class TransformerModelConfig(ModelConfig):
     transformer_layers_type: TransformerLayersType = PreLNTransformerLayers
     lightning_module_type: TransformerLightningModuleType = TransformerLightningModule
     similarity_module_type: SimilarityModuleType = DistanceSimilarityModule
+    backbone_type: TransformerBackboneType = TransformerTorchBackbone
     get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
     get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
     data_preparator_kwargs: tp.Optional[InitKwargs] = None
@@ -203,6 +214,7 @@ class TransformerModelConfig(ModelConfig):
     pos_encoding_kwargs: tp.Optional[InitKwargs] = None
     lightning_module_kwargs: tp.Optional[InitKwargs] = None
     similarity_module_kwargs: tp.Optional[InitKwargs] = None
+    backbone_kwargs: tp.Optional[InitKwargs] = None


 TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig)
@@ -251,6 +263,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
         similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
+        backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         get_trainer_func: tp.Optional[TrainerCallable] = None,
         data_preparator_kwargs: tp.Optional[InitKwargs] = None,
@@ -259,6 +272,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
         lightning_module_kwargs: tp.Optional[InitKwargs] = None,
         similarity_module_kwargs: tp.Optional[InitKwargs] = None,
+        backbone_kwargs: tp.Optional[InitKwargs] = None,
         **kwargs: tp.Any,
     ) -> None:
         super().__init__(verbose=verbose)
@@ -288,6 +302,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         self.item_net_constructor_type = item_net_constructor_type
         self.pos_encoding_type = pos_encoding_type
         self.lightning_module_type = lightning_module_type
+        self.backbone_type = backbone_type
         self.get_val_mask_func = get_val_mask_func
         self.get_trainer_func = get_trainer_func
         self.data_preparator_kwargs = data_preparator_kwargs
@@ -296,6 +311,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         self.pos_encoding_kwargs = pos_encoding_kwargs
         self.lightning_module_kwargs = lightning_module_kwargs
         self.similarity_module_kwargs = similarity_module_kwargs
+        self.backbone_kwargs = backbone_kwargs

         self._init_data_preparator()
         self._init_trainer()
@@ -377,11 +393,11 @@ def _init_transformer_layers(self) -> TransformerLayersBase:
     def _init_similarity_module(self) -> SimilarityModuleBase:
         return self.similarity_module_type(**self._get_kwargs(self.similarity_module_kwargs))

-    def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone:
+    def _init_torch_model(self, item_model: ItemNetBase) -> TransformerBackboneBase:
         pos_encoding_layer = self._init_pos_encoding_layer()
         transformer_layers = self._init_transformer_layers()
         similarity_module = self._init_similarity_module()
-        return TransformerTorchBackbone(
+        return self.backbone_type(
             n_heads=self.n_heads,
             dropout_rate=self.dropout_rate,
             item_model=item_model,
@@ -390,11 +406,12 @@ def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone
             similarity_module=similarity_module,
             use_causal_attn=self.use_causal_attn,
             use_key_padding_mask=self.use_key_padding_mask,
+            **self._get_kwargs(self.backbone_kwargs),
         )

     def _init_lightning_model(
         self,
-        torch_model: TransformerTorchBackbone,
+        torch_model: TransformerBackboneBase,
         dataset_schema: DatasetSchemaDict,
         item_external_ids: ExternalIds,
         model_config: tp.Dict[str, tp.Any],
@@ -490,7 +507,7 @@ def _recommend_i2i(
         )

     @property
-    def torch_model(self) -> TransformerTorchBackbone:
+    def torch_model(self) -> TransformerBackboneBase:
         """Pytorch model."""
         return self.lightning_model.torch_model
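
With these `base.py` changes, `_init_torch_model` instantiates whatever class is configured in `backbone_type` and forwards `backbone_kwargs` to it. Below is a hedged sketch of a custom backbone built on top of `TransformerTorchBackbone`; the `head_scale` argument is hypothetical and only illustrates how `backbone_kwargs` reaches the constructor, and the `encode_sessions` signature is inferred from the `lightning.py` diff further down. See the `SASRecModel` sketch after the `sasrec.py` diff for how such a class would be plugged in.

```python
import typing as tp

import torch

from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone


class ScaledTorchBackbone(TransformerTorchBackbone):
    """Illustrative backbone that rescales session embeddings."""

    def __init__(self, *args: tp.Any, head_scale: float = 1.0, **kwargs: tp.Any) -> None:
        super().__init__(*args, **kwargs)
        self.head_scale = head_scale  # hypothetical option, arrives via `backbone_kwargs`

    def encode_sessions(self, batch: tp.Dict[str, torch.Tensor], item_embs: torch.Tensor) -> torch.Tensor:
        # Reuse the parent encoding, then apply the illustrative scaling factor.
        return super().encode_sessions(batch, item_embs) * self.head_scale
```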

rectools/models/nn/transformers/bert4rec.py

Lines changed: 10 additions & 0 deletions
@@ -45,6 +45,7 @@
     TransformerLayersBase,
 )
 from .similarity import DistanceSimilarityModule, SimilarityModuleBase
+from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone


 class BERT4RecDataPreparator(TransformerDataPreparatorBase):
@@ -259,6 +260,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
         Type of lightning module defining training procedure.
     similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule`
         Type of similarity module.
+    backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone`
+        Type of torch backbone.
     get_val_mask_func : Callable, default ``None``
         Function to get validation mask.
     get_trainer_func : Callable, default ``None``
@@ -295,6 +298,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
     similarity_module_kwargs: optional(dict), default ``None``
         Additional keyword arguments to pass during `similarity_module_type` initialization.
         Make sure all dict values have JSON serializable types.
+    backbone_kwargs: optional(dict), default ``None``
+        Additional keyword arguments to pass during `backbone_type` initialization.
+        Make sure all dict values have JSON serializable types.
     """

     config_class = BERT4RecModelConfig
@@ -327,6 +333,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
         similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
+        backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         get_trainer_func: tp.Optional[TrainerCallable] = None,
         recommend_batch_size: int = 256,
@@ -340,6 +347,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
         lightning_module_kwargs: tp.Optional[InitKwargs] = None,
         similarity_module_kwargs: tp.Optional[InitKwargs] = None,
+        backbone_kwargs: tp.Optional[InitKwargs] = None,
     ):
         self.mask_prob = mask_prob

@@ -373,6 +381,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
             item_net_constructor_type=item_net_constructor_type,
             pos_encoding_type=pos_encoding_type,
             lightning_module_type=lightning_module_type,
+            backbone_type=backbone_type,
             get_val_mask_func=get_val_mask_func,
             get_trainer_func=get_trainer_func,
             data_preparator_kwargs=data_preparator_kwargs,
@@ -382,6 +391,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
             pos_encoding_kwargs=pos_encoding_kwargs,
             lightning_module_kwargs=lightning_module_kwargs,
             similarity_module_kwargs=similarity_module_kwargs,
+            backbone_kwargs=backbone_kwargs,
         )

     def _init_data_preparator(self) -> None:
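
Because the new `backbone_type` config field (see `base.py` above) uses the same `BeforeValidator` / `PlainSerializer` pair as the other component types, a config should accept the backbone either as a class object or as its import-path string. A hedged sketch, assuming the standard RecTools `from_config` entry point; the config contents are illustrative, not taken from this commit:

```python
from rectools.models import BERT4RecModel

config = {
    # Import path resolved by the `_get_class_obj` validator; passing the class object should work too.
    "backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone",
    "backbone_kwargs": None,  # keep values JSON-serializable, as the new docstrings require
}
model = BERT4RecModel.from_config(config)
```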

rectools/models/nn/transformers/lightning.py

Lines changed: 14 additions & 17 deletions
@@ -26,7 +26,7 @@
 from rectools.types import InternalIdsArray

 from .data_preparator import TransformerDataPreparatorBase
-from .torch_backbone import TransformerTorchBackbone
+from .torch_backbone import TransformerBackboneBase

 # #### -------------- Lightning Base Model -------------- #### #

@@ -38,7 +38,7 @@ class TransformerLightningModuleBase(LightningModule): # pylint: disable=too-ma

     Parameters
     ----------
-    torch_model : TransformerTorchBackbone
+    torch_model : TransformerBackboneBase
         Torch model to make recommendations.
     lr : float
         Learning rate.
@@ -61,7 +61,7 @@

     def __init__(
         self,
-        torch_model: TransformerTorchBackbone,
+        torch_model: TransformerBackboneBase,
         model_config: tp.Dict[str, tp.Any],
         dataset_schema: DatasetSchemaDict,
         item_external_ids: ExternalIds,
@@ -250,13 +250,12 @@ def on_train_start(self) -> None:

     def get_batch_logits(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor:
         """Get bacth logits."""
-        x = batch["x"]  # x: [batch_size, session_max_len]
         if self._requires_negatives:
             y, negatives = batch["y"], batch["negatives"]
             pos_neg = torch.cat([y.unsqueeze(-1), negatives], dim=-1)
-            logits = self.torch_model(sessions=x, item_ids=pos_neg)
+            logits = self.torch_model(batch=batch, candidate_item_ids=pos_neg)
         else:
-            logits = self.torch_model(sessions=x)
+            logits = self.torch_model(batch=batch)
         return logits

     def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
@@ -338,7 +337,8 @@ def _get_user_item_embeddings(
         item_embs = self.torch_model.item_model.get_all_embeddings()
         user_embs = []
         for batch in recommend_dataloader:
-            batch_embs = self.torch_model.encode_sessions(batch["x"].to(device), item_embs)[:, -1, :]
+            batch = {k: v.to(device) for k, v in batch.items()}
+            batch_embs = self.torch_model.encode_sessions(batch, item_embs)[:, -1, :]
             user_embs.append(batch_embs)

         return torch.cat(user_embs), item_embs
@@ -360,17 +360,14 @@

         user_embs, item_embs = self._get_user_item_embeddings(recommend_dataloader, torch_device)

-        all_user_ids, all_reco_ids, all_scores = (
-            self.torch_model.similarity_module._recommend_u2i(  # pylint: disable=protected-access
-                user_embs=user_embs,
-                item_embs=item_embs,
-                user_ids=user_ids,
-                k=k,
-                sorted_item_ids_to_recommend=sorted_item_ids_to_recommend,
-                ui_csr_for_filter=ui_csr_for_filter,
-            )
+        return self.torch_model.similarity_module._recommend_u2i(  # pylint: disable=protected-access
+            user_embs=user_embs,
+            item_embs=item_embs,
+            user_ids=user_ids,
+            k=k,
+            sorted_item_ids_to_recommend=sorted_item_ids_to_recommend,
+            ui_csr_for_filter=ui_csr_for_filter,
         )
-        return all_user_ids, all_reco_ids, all_scores

     def _recommend_i2i(
         self,
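
The key behavioral change in `lightning.py` is that the backbone now receives the whole batch dict instead of just `batch["x"]`, both in `forward(batch=..., candidate_item_ids=...)` and in `encode_sessions(batch, item_embs)`. A hedged helper mirroring the two changed call sites (the function name is illustrative, not part of this commit):

```python
import typing as tp

import torch

from rectools.models.nn.transformers.torch_backbone import TransformerBackboneBase


def encode_last_positions(
    torch_model: TransformerBackboneBase,
    batch: tp.Dict[str, torch.Tensor],
    item_embs: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    # Move every tensor of the batch to the target device, as
    # `_get_user_item_embeddings` now does, then keep the embedding of the
    # last position of each encoded session.
    batch = {k: v.to(device) for k, v in batch.items()}
    return torch_model.encode_sessions(batch, item_embs)[:, -1, :]
```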

rectools/models/nn/transformers/sasrec.py

Lines changed: 10 additions & 0 deletions
@@ -45,6 +45,7 @@
     TransformerLayersBase,
 )
 from .similarity import DistanceSimilarityModule, SimilarityModuleBase
+from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone


 class SASRecDataPreparator(TransformerDataPreparatorBase):
@@ -339,6 +340,8 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
         Type of lightning module defining training procedure.
     similarity_module_type : type(SimilarityModuleBase), default `DistanceSimilarityModule`
         Type of similarity module.
+    backbone_type : type(TransformerBackboneBase), default `TransformerTorchBackbone`
+        Type of torch backbone.
     get_val_mask_func : Callable, default ``None``
         Function to get validation mask.
     get_trainer_func : Callable, default ``None``
@@ -375,6 +378,9 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
     similarity_module_kwargs: optional(dict), default ``None``
         Additional keyword arguments to pass during `similarity_module_type` initialization.
         Make sure all dict values have JSON serializable types.
+    backbone_kwargs: optional(dict), default ``None``
+        Additional keyword arguments to pass during `backbone_type` initialization.
+        Make sure all dict values have JSON serializable types.
     """

     config_class = SASRecModelConfig
@@ -406,6 +412,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
         similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
+        backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         get_trainer_func: tp.Optional[TrainerCallable] = None,
         recommend_batch_size: int = 256,
@@ -418,6 +425,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
         lightning_module_kwargs: tp.Optional[InitKwargs] = None,
         similarity_module_kwargs: tp.Optional[InitKwargs] = None,
+        backbone_kwargs: tp.Optional[InitKwargs] = None,
     ):
         super().__init__(
             transformer_layers_type=transformer_layers_type,
@@ -449,6 +457,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
             item_net_constructor_type=item_net_constructor_type,
             pos_encoding_type=pos_encoding_type,
             lightning_module_type=lightning_module_type,
+            backbone_type=backbone_type,
             get_val_mask_func=get_val_mask_func,
             get_trainer_func=get_trainer_func,
             data_preparator_kwargs=data_preparator_kwargs,
@@ -457,4 +466,5 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
             pos_encoding_kwargs=pos_encoding_kwargs,
             lightning_module_kwargs=lightning_module_kwargs,
             similarity_module_kwargs=similarity_module_kwargs,
+            backbone_kwargs=backbone_kwargs,
         )
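
The `SASRecModel` changes mirror `BERT4RecModel`: the two new arguments are plain pass-throughs to `TransformerModelBase`. A short hedged usage sketch plugging in the hypothetical `ScaledTorchBackbone` sketched after the `base.py` diff above:

```python
from rectools.models import SASRecModel

# `ScaledTorchBackbone` is the hypothetical subclass defined in the earlier sketch.
model = SASRecModel(
    backbone_type=ScaledTorchBackbone,    # defaults to TransformerTorchBackbone when omitted
    backbone_kwargs={"head_scale": 0.5},  # forwarded to the backbone constructor by `_init_torch_model`
)
```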
