 from .negative_sampler import TransformerNegativeSamplerBase
 
 InitKwargs = tp.Dict[str, tp.Any]
+# (user session, session weights, extra columns)
+BatchElement = tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]
 
 
 class SequenceDataset(TorchDataset):
@@ -46,17 +48,26 @@ class SequenceDataset(TorchDataset):
         Weight of each interaction from the session.
     """
 
-    def __init__(self, sessions: tp.List[tp.List[int]], weights: tp.List[tp.List[float]]):
+    def __init__(
+        self,
+        sessions: tp.List[tp.List[int]],
+        weights: tp.List[tp.List[float]],
+        extras: tp.Optional[tp.Dict[str, tp.List[tp.Any]]] = None,
+    ):
         self.sessions = sessions
         self.weights = weights
+        self.extras = extras
 
     def __len__(self) -> int:
         return len(self.sessions)
 
-    def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float]]:
+    def __getitem__(self, index: int) -> BatchElement:
         session = self.sessions[index]  # [session_len]
         weights = self.weights[index]  # [session_len]
-        return session, weights
+        extras = (
+            {feature_name: features[index] for feature_name, features in self.extras.items()} if self.extras else {}
+        )
+        return session, weights, extras
 
     @classmethod
     def from_interactions(
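
With this change each dataset element is a `BatchElement` triple rather than a pair. A minimal usage sketch with hypothetical data (the `hour_of_day` feature is an assumption for illustration, not part of the library):

ds = SequenceDataset(
    sessions=[[10, 11, 12], [20, 21]],
    weights=[[1.0, 1.0, 2.0], [1.0, 1.0]],
    extras={"hour_of_day": [[9, 14, 20], [8, 8]]},  # hypothetical: one list per session
)
session, weights, extras = ds[0]
# session -> [10, 11, 12]; weights -> [1.0, 1.0, 2.0]
# extras  -> {"hour_of_day": [9, 14, 20]}; an empty dict when extras is None
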
@@ -73,17 +84,19 @@ def from_interactions(
         interactions : pd.DataFrame
             User-item interactions.
         """
+        cols_to_agg = [col for col in interactions.columns if col != Columns.User]
         sessions = (
             interactions.sort_values(Columns.Datetime, kind="stable")
-            .groupby(Columns.User, sort=sort_users)[[Columns.Item, Columns.Weight]]
+            .groupby(Columns.User, sort=sort_users)[cols_to_agg]
             .agg(list)
         )
-        sessions, weights = (
+        sessions_items, weights = (
             sessions[Columns.Item].to_list(),
             sessions[Columns.Weight].to_list(),
         )
-
-        return cls(sessions=sessions, weights=weights)
+        extra_cols = [col for col in interactions.columns if col not in Columns.Interactions]
+        extras = {col: sessions[col].to_list() for col in extra_cols} if len(extra_cols) > 0 else None
+        return cls(sessions=sessions_items, weights=weights, extras=extras)
 
 
 class TransformerDataPreparatorBase:  # pylint: disable=too-many-instance-attributes
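
A hedged usage sketch of the updated classmethod: every column besides `Columns.User` is aggregated per user, and anything outside `Columns.Interactions` ends up in `extras`. Assumes the existing `sort_users` default; the `hour_of_day` column is hypothetical:

import pandas as pd
from rectools import Columns

interactions = pd.DataFrame(
    {
        Columns.User: [1, 1, 2],
        Columns.Item: [10, 11, 10],
        Columns.Weight: [1.0, 1.0, 1.0],
        Columns.Datetime: pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
        "hour_of_day": [9, 14, 20],  # hypothetical extra column
    }
)
ds = SequenceDataset.from_interactions(interactions)
# ds.sessions -> [[10, 11], [10]]
# ds.weights  -> [[1.0, 1.0], [1.0]]
# ds.extras   -> {"hour_of_day": [[9, 14], [20]]}
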
@@ -114,6 +127,8 @@ class TransformerDataPreparatorBase: # pylint: disable=too-many-instance-attrib
     get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
         Additional keyword arguments for the get_val_mask_func.
         Make sure all dict values have JSON serializable types.
+    extra_cols: optional(List[str]), default ``None``
+        Extra columns to keep in train and recommend datasets.
     """
 
     # We sometimes need data preparators to add +1 to actual session_max_len
@@ -133,6 +148,7 @@ def __init__(
         n_negatives: tp.Optional[int] = None,
         negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
         get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
+        extra_cols: tp.Optional[tp.List[str]] = None,
         **kwargs: tp.Any,
     ) -> None:
         self.item_id_map: IdMap
@@ -148,6 +164,7 @@ def __init__(
         self.shuffle_train = shuffle_train
         self.get_val_mask_func = get_val_mask_func
         self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
+        self.extra_cols = extra_cols
 
     def get_known_items_sorted_internal_ids(self) -> np.ndarray:
         """Return internal item ids from processed dataset in sorted order."""
@@ -203,7 +220,8 @@ def _filter_train_interactions(self, train_interactions: pd.DataFrame) -> pd.Dat
 
     def process_dataset_train(self, dataset: Dataset) -> None:
         """Process train dataset and save data."""
-        raw_interactions = dataset.get_raw_interactions()
+        extra_cols = False if self.extra_cols is None else self.extra_cols
+        raw_interactions = dataset.get_raw_interactions(include_extra_cols=extra_cols)
 
         # Exclude val interaction targets from train if needed
         interactions = raw_interactions
@@ -231,7 +249,12 @@ def process_dataset_train(self, dataset: Dataset) -> None:
 
         # Prepare train dataset
         # User features are dropped for now because model doesn't support them
-        final_interactions = Interactions.from_raw(interactions, user_id_map, item_id_map, keep_extra_cols=True)
+        final_interactions = Interactions.from_raw(
+            interactions,
+            user_id_map,
+            item_id_map,
+            keep_extra_cols=True,
+        )
         self.train_dataset = Dataset(user_id_map, item_id_map, final_interactions, item_features=item_features)
         self.item_id_map = self.train_dataset.item_id_map
         self._init_extra_token_ids()
@@ -246,7 +269,9 @@ def process_dataset_train(self, dataset: Dataset) -> None:
         val_interactions = interactions[interactions[Columns.User].isin(val_targets[Columns.User].unique())].copy()
         val_interactions[Columns.Weight] = 0
         val_interactions = pd.concat([val_interactions, val_targets], axis=0)
-        self.val_interactions = Interactions.from_raw(val_interactions, user_id_map, item_id_map).df
+        self.val_interactions = Interactions.from_raw(
+            val_interactions, user_id_map, item_id_map, keep_extra_cols=True
+        ).df
 
     def _init_extra_token_ids(self) -> None:
         extra_token_ids = self.item_id_map.convert_to_internal(self.item_extra_tokens)
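
For intuition, `keep_extra_cols=True` lets columns beyond the core four survive the external-to-internal id conversion. A minimal sketch using the public rectools API (the data and the `hour_of_day` column are hypothetical):

import pandas as pd
from rectools import Columns
from rectools.dataset import IdMap, Interactions

df = pd.DataFrame(
    {
        Columns.User: ["u1", "u1"],
        Columns.Item: ["i7", "i9"],
        Columns.Weight: [1.0, 1.0],
        Columns.Datetime: pd.to_datetime(["2024-01-01", "2024-01-02"]),
        "hour_of_day": [9, 14],  # hypothetical extra column
    }
)
user_id_map = IdMap.from_values(df[Columns.User].values)
item_id_map = IdMap.from_values(df[Columns.Item].values)
final = Interactions.from_raw(df, user_id_map, item_id_map, keep_extra_cols=True)
# final.df holds internal user/item ids and still contains the "hour_of_day" column
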
@@ -340,7 +365,10 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset
         Final item_id_map is model item_id_map constructed during training.
         """
         # Filter interactions in dataset internal ids
-        interactions = dataset.interactions.df
+        required_cols = Columns.Interactions
+        if self.extra_cols is not None:
+            required_cols = required_cols + self.extra_cols
+        interactions = dataset.interactions.df[required_cols]
         users_internal = dataset.user_id_map.convert_to_internal(users, strict=False)
         items_internal = dataset.item_id_map.convert_to_internal(self.get_known_item_ids(), strict=False)
         interactions = interactions[interactions[Columns.User].isin(users_internal)]
@@ -359,7 +387,9 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset
         if n_filtered > 0:
             explanation = f"""{n_filtered} target users were considered cold because of missing known items"""
             warnings.warn(explanation)
-        filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map)
+        filtered_interactions = Interactions.from_raw(
+            interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=True
+        )
         filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions)
         return filtered_dataset
 
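
For clarity, the `required_cols` selection above resolves to the four core interaction columns plus any requested extras (the `hour_of_day` column is hypothetical):

# Columns.Interactions is [Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]
required_cols = Columns.Interactions + ["hour_of_day"]  # when extra_cols=["hour_of_day"]
# -> ["user_id", "item_id", "weight", "datetime", "hour_of_day"]
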
@@ -381,26 +411,29 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset:
         Final user_id_map is the same as dataset original.
         Final item_id_map is model item_id_map constructed during training.
         """
-        interactions = dataset.get_raw_interactions()
+        extra_cols = False if self.extra_cols is None else self.extra_cols
+        interactions = dataset.get_raw_interactions(include_extra_cols=extra_cols)
         interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())]
-        filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map)
+        filtered_interactions = Interactions.from_raw(
+            interactions, dataset.user_id_map, self.item_id_map, keep_extra_cols=True
+        )
         filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions)
         return filtered_dataset
 
     def _collate_fn_train(
         self,
-        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+        batch: tp.List[BatchElement],
     ) -> tp.Dict[str, torch.Tensor]:
         raise NotImplementedError()
 
     def _collate_fn_val(
         self,
-        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+        batch: tp.List[BatchElement],
     ) -> tp.Dict[str, torch.Tensor]:
         raise NotImplementedError()
 
     def _collate_fn_recommend(
         self,
-        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+        batch: tp.List[BatchElement],
     ) -> tp.Dict[str, torch.Tensor]:
         raise NotImplementedError()
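
The collate functions stay abstract; subclasses now receive `BatchElement` triples. A minimal sketch of a train-style collate that right-aligns each session in zero-padded tensors and stacks one numeric extra feature; the function name, the `session_max_len` handling, and the `hour_of_day` key are assumptions for illustration, not the library's implementation:

import typing as tp

import torch

def collate_train_sketch(
    batch: tp.List[BatchElement],
    session_max_len: int,
) -> tp.Dict[str, torch.Tensor]:
    batch_size = len(batch)
    x = torch.zeros((batch_size, session_max_len), dtype=torch.long)
    w = torch.zeros((batch_size, session_max_len))
    hours = torch.zeros((batch_size, session_max_len), dtype=torch.long)
    for i, (session, weights, extras) in enumerate(batch):
        # Keep only the most recent interactions; assumes non-empty sessions
        session = session[-session_max_len:]
        weights = weights[-session_max_len:]
        x[i, -len(session):] = torch.tensor(session)
        w[i, -len(weights):] = torch.tensor(weights)
        if "hour_of_day" in extras:  # hypothetical extra feature
            hours[i, -len(session):] = torch.tensor(extras["hour_of_day"][-session_max_len:])
    return {"x": x, "w": w, "hour_of_day": hours}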