Skip to content

Commit b47695f

Browse files
authored
Fix negative sampler config (#278)
Fixed negative sampler kwargs in transformer model config
1 parent 7cbc4ce commit b47695f

File tree

3 files changed

+3
-0
lines changed

3 files changed

+3
-0
lines changed

rectools/models/nn/transformers/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ class TransformerModelConfig(ModelConfig):
225225
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None
226226
pos_encoding_kwargs: tp.Optional[InitKwargs] = None
227227
lightning_module_kwargs: tp.Optional[InitKwargs] = None
228+
negative_sampler_kwargs: tp.Optional[InitKwargs] = None
228229
similarity_module_kwargs: tp.Optional[InitKwargs] = None
229230
backbone_kwargs: tp.Optional[InitKwargs] = None
230231

tests/models/nn/transformers/test_bert4rec.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]:
865865
"item_net_constructor_kwargs": None,
866866
"pos_encoding_kwargs": None,
867867
"lightning_module_kwargs": None,
868+
"negative_sampler_kwargs": None,
868869
"similarity_module_kwargs": None,
869870
"backbone_kwargs": None,
870871
}

tests/models/nn/transformers/test_sasrec.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]:
973973
"item_net_constructor_kwargs": None,
974974
"pos_encoding_kwargs": None,
975975
"lightning_module_kwargs": None,
976+
"negative_sampler_kwargs": None,
976977
"similarity_module_kwargs": None,
977978
"backbone_kwargs": None,
978979
}

0 commit comments

Comments
 (0)