     PreLNTransformerLayers,
     TransformerLayersBase,
 )
+from .similarity import DistanceSimilarityModule, SimilarityModuleBase
 from .torch_backbone import TransformerTorchBackbone
 
 InitKwargs = tp.Dict[str, tp.Any]
@@ -97,6 +98,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
     ),
 ]
 
+SimilarityModuleType = tpe.Annotated[
+    tp.Type[SimilarityModuleBase],
+    BeforeValidator(_get_class_obj),
+    PlainSerializer(
+        func=get_class_or_function_full_path,
+        return_type=str,
+        when_used="json",
+    ),
+]
+
 TransformerDataPreparatorType = tpe.Annotated[
     tp.Type[TransformerDataPreparatorBase],
     BeforeValidator(_get_class_obj),
@@ -183,13 +194,15 @@ class TransformerModelConfig(ModelConfig):
     pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding
     transformer_layers_type: TransformerLayersType = PreLNTransformerLayers
     lightning_module_type: TransformerLightningModuleType = TransformerLightningModule
+    similarity_module_type: SimilarityModuleType = DistanceSimilarityModule
     get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
     get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
     data_preparator_kwargs: tp.Optional[InitKwargs] = None
     transformer_layers_kwargs: tp.Optional[InitKwargs] = None
     item_net_constructor_kwargs: tp.Optional[InitKwargs] = None
     pos_encoding_kwargs: tp.Optional[InitKwargs] = None
     lightning_module_kwargs: tp.Optional[InitKwargs] = None
+    similarity_module_kwargs: tp.Optional[InitKwargs] = None
 
 
 TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig)
@@ -237,13 +250,15 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
         pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
         lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
+        similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
         get_trainer_func: tp.Optional[TrainerCallable] = None,
         data_preparator_kwargs: tp.Optional[InitKwargs] = None,
         transformer_layers_kwargs: tp.Optional[InitKwargs] = None,
         item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
         pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
         lightning_module_kwargs: tp.Optional[InitKwargs] = None,
+        similarity_module_kwargs: tp.Optional[InitKwargs] = None,
         **kwargs: tp.Any,
     ) -> None:
         super().__init__(verbose=verbose)
@@ -268,6 +283,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         self.recommend_batch_size = recommend_batch_size
         self.recommend_torch_device = recommend_torch_device
         self.train_min_user_interactions = train_min_user_interactions
+        self.similarity_module_type = similarity_module_type
         self.item_net_block_types = item_net_block_types
         self.item_net_constructor_type = item_net_constructor_type
         self.pos_encoding_type = pos_encoding_type
@@ -279,6 +295,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
         self.item_net_constructor_kwargs = item_net_constructor_kwargs
         self.pos_encoding_kwargs = pos_encoding_kwargs
         self.lightning_module_kwargs = lightning_module_kwargs
+        self.similarity_module_kwargs = similarity_module_kwargs
 
         self._init_data_preparator()
         self._init_trainer()
@@ -295,12 +312,13 @@ def _get_kwargs(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs:
         return kwargs
 
     def _init_data_preparator(self) -> None:
+        requires_negatives = self.lightning_module_type.requires_negatives(self.loss)
         self.data_preparator = self.data_preparator_type(
             session_max_len=self.session_max_len,
             batch_size=self.batch_size,
             dataloader_num_workers=self.dataloader_num_workers,
             train_min_user_interactions=self.train_min_user_interactions,
-            n_negatives=self.n_negatives if self.loss != "softmax" else None,
+            n_negatives=self.n_negatives if requires_negatives else None,
             get_val_mask_func=self.get_val_mask_func,
             shuffle_train=True,
             **self._get_kwargs(self.data_preparator_kwargs),
@@ -356,15 +374,20 @@ def _init_transformer_layers(self) -> TransformerLayersBase:
             **self._get_kwargs(self.transformer_layers_kwargs),
         )
 
+    def _init_similarity_module(self) -> SimilarityModuleBase:
+        return self.similarity_module_type(**self._get_kwargs(self.similarity_module_kwargs))
+
     def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone:
         pos_encoding_layer = self._init_pos_encoding_layer()
         transformer_layers = self._init_transformer_layers()
+        similarity_module = self._init_similarity_module()
         return TransformerTorchBackbone(
             n_heads=self.n_heads,
             dropout_rate=self.dropout_rate,
             item_model=item_model,
             pos_encoding_layer=pos_encoding_layer,
             transformer_layers=transformer_layers,
+            similarity_module=similarity_module,
             use_causal_attn=self.use_causal_attn,
             use_key_padding_mask=self.use_key_padding_mask,
         )
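
Note: the data preparator hunk above replaces the inline check `self.loss != "softmax"` with a call to `self.lightning_module_type.requires_negatives(self.loss)`, letting a custom lightning module decide for its own losses whether sampled negatives are needed. The sketch below only illustrates that idea; the class name is hypothetical and it simply reproduces the behaviour of the replaced condition, which may not match every loss the library actually supports.

# Hypothetical sketch, not the rectools implementation.
class MyLightningModule:  # would subclass TransformerLightningModuleBase in practice
    @classmethod
    def requires_negatives(cls, loss: str) -> bool:
        # Mirrors the replaced inline check: only the full softmax loss trains without sampled negatives.
        return loss != "softmax"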
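
Note: with this change the similarity module is configurable through the constructor and the config. Below is a hypothetical usage sketch, not taken from this diff: it assumes that a concrete model such as SASRecModel forwards the new keyword arguments to the base `__init__` shown above, and that the import path for DistanceSimilarityModule matches the relative import added at the top of the file.

# Hypothetical usage sketch; model class and import paths are assumptions.
from rectools.models import SASRecModel
from rectools.models.nn.transformers.similarity import DistanceSimilarityModule

model = SASRecModel(
    similarity_module_type=DistanceSimilarityModule,  # default introduced by this diff
    similarity_module_kwargs=None,  # extra kwargs are forwarded to the module constructor
)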