Commit 3e8db67

Feature/transformers_batch_extras (#287)
- `extra_cols` argument for `TransformerDataPreparatorBase`
- optional `extras` key in batch for Transformer-based models
- `item_tower_forward` and `session_tower_forward` methods for `SimilarityModuleBase`
1 parent ea266cd commit 3e8db67

10 files changed: +170 -96 lines changed
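
The net effect of the changes below is that each element yielded by `SequenceDataset` is now a 3-tuple rather than a 2-tuple. A minimal sketch of the new shape (the `watched_pct` column and its values are invented for illustration, not part of this commit):

    import typing as tp

    # After this commit, one SequenceDataset element is:
    # (item ids of the session, interaction weights, dict of extra per-interaction columns)
    BatchElement = tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]

    element: BatchElement = (
        [3, 7, 15],                      # user session, internal item ids
        [1.0, 1.0, 2.0],                 # weights, aligned with the session
        {"watched_pct": [10, 100, 55]},  # extras: one aligned list per extra column
    )
    session, weights, extras = element  # built-in collate fns unpack and ignore extras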

CHANGELOG.md
Lines changed: 4 additions & 0 deletions

@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## Unreleased
+### Added
+- `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287))
+
 ## [0.14.0] - 16.05.2025
 
 ### Added

rectools/dataset/dataset.py
Lines changed: 4 additions & 1 deletion

@@ -348,7 +348,10 @@ def get_user_item_matrix(
         return matrix
 
     def get_raw_interactions(
-        self, include_weight: bool = True, include_datetime: bool = True, include_extra_cols: bool = True
+        self,
+        include_weight: bool = True,
+        include_datetime: bool = True,
+        include_extra_cols: tp.Union[bool, tp.List[str]] = True,
     ) -> pd.DataFrame:
         """
         Return interactions as a `pd.DataFrame` object with replacing internal user and item ids to external ones.
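
A hedged usage sketch of the widened parameter: `True`/`False` keep the old all-or-nothing behaviour, while a list selects extra columns by name. It assumes a dataset built with `Dataset.construct(..., keep_extra_cols=True)` from the existing RecTools API; the `watched_pct` column is made up:

    import pandas as pd
    from rectools import Columns
    from rectools.dataset import Dataset

    df = pd.DataFrame(
        {
            Columns.User: [10, 10, 20],
            Columns.Item: [1, 2, 1],
            Columns.Weight: [1.0, 1.0, 1.0],
            Columns.Datetime: pd.to_datetime(["2025-01-01", "2025-01-02", "2025-01-03"]),
            "watched_pct": [40, 90, 100],  # extra column, kept by keep_extra_cols=True
        }
    )
    dataset = Dataset.construct(df, keep_extra_cols=True)

    all_cols = dataset.get_raw_interactions(include_extra_cols=True)             # old behaviour
    only_pct = dataset.get_raw_interactions(include_extra_cols=["watched_pct"])  # new: select by name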

rectools/dataset/interactions.py
Lines changed: 10 additions & 5 deletions

@@ -167,7 +167,7 @@ def to_external(
         item_id_map: IdMap,
         include_weight: bool = True,
         include_datetime: bool = True,
-        include_extra_cols: bool = True,
+        include_extra_cols: tp.Union[bool, tp.List[str]] = True,
     ) -> pd.DataFrame:
         """
         Convert itself to `pd.DataFrame` with replacing internal user and item ids to external ones.
@@ -182,8 +182,9 @@ def to_external(
            Whether to include weight column into resulting table or not
         include_datetime : bool, default ``True``
            Whether to include datetime column into resulting table or not.
-        include_extra_cols: bool, default ``True``
-            Whether to include extra columns into resulting table or not.
+        include_extra_cols: bool or List[str], default ``True``
+            If bool, indicates whether to include all extra columns into resulting table or not.
+            If list of strings, indicates which extra columns to include into resulting table.
 
         Returns
         -------
@@ -201,9 +202,13 @@ def to_external(
             cols_to_add.append(Columns.Weight)
         if include_datetime:
             cols_to_add.append(Columns.Datetime)
-        if include_extra_cols:
+
+        extra_cols = []
+        if isinstance(include_extra_cols, list):
+            extra_cols = [col for col in include_extra_cols if col in self.df and col not in Columns.Interactions]
+        elif include_extra_cols:
             extra_cols = [col for col in self.df if col not in Columns.Interactions]
-            cols_to_add.extend(extra_cols)
+        cols_to_add.extend(extra_cols)
 
         for col in cols_to_add:
             res[col] = self.df[col]
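
The selection rule added here is worth stating on its own: a list argument is filtered down to columns that exist in the frame and are not core interaction columns, while `True` keeps every non-core column. A standalone sketch of the same logic, with `INTERACTIONS_COLS` standing in for `Columns.Interactions`:

    import typing as tp

    INTERACTIONS_COLS = ["user_id", "item_id", "weight", "datetime"]  # stands in for Columns.Interactions
    df_cols = INTERACTIONS_COLS + ["watched_pct", "source"]           # columns of self.df

    def select_extra_cols(include_extra_cols: tp.Union[bool, tp.List[str]]) -> tp.List[str]:
        extra_cols: tp.List[str] = []
        if isinstance(include_extra_cols, list):
            # Unknown names and core interaction columns are silently dropped
            extra_cols = [col for col in include_extra_cols if col in df_cols and col not in INTERACTIONS_COLS]
        elif include_extra_cols:
            extra_cols = [col for col in df_cols if col not in INTERACTIONS_COLS]
        return extra_cols

    assert select_extra_cols(True) == ["watched_pct", "source"]
    assert select_extra_cols(["watched_pct", "missing", "user_id"]) == ["watched_pct"]
    assert select_extra_cols(False) == []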

rectools/models/nn/transformers/bert4rec.py
Lines changed: 7 additions & 7 deletions

@@ -36,7 +36,7 @@
     ValMaskCallable,
 )
 from .constants import MASKING_VALUE, PADDING_VALUE
-from .data_preparator import InitKwargs, TransformerDataPreparatorBase
+from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase
 from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
 from .net_blocks import (
     LearnableInversePositionalEncoding,
@@ -128,7 +128,7 @@ def _mask_session(
 
     def _collate_fn_train(
         self,
-        batch: List[Tuple[List[int], List[float]]],
+        batch: tp.List[BatchElement],
     ) -> Dict[str, torch.Tensor]:
         """
         Mask session elements to receive `x`.
@@ -141,7 +141,7 @@ def _collate_fn_train(
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, self.session_max_len))
         yw = np.zeros((batch_size, self.session_max_len))
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             masked_session, target = self._mask_session(ses)
             x[i, -len(ses) :] = masked_session  # ses: [session_len] -> x[i]: [session_max_len]
             y[i, -len(ses) :] = target  # ses: [session_len] -> y[i]: [session_max_len]
@@ -154,12 +154,12 @@ def _collate_fn_train(
         )
         return batch_dict
 
-    def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+    def _collate_fn_val(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]:
         batch_size = len(batch)
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, 1))  # until only leave-one-strategy
         yw = np.zeros((batch_size, 1))  # until only leave-one-strategy
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]
             session = input_session.copy()
 
@@ -179,14 +179,14 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st
         )
         return batch_dict
 
-    def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+    def _collate_fn_recommend(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]:
         """
         Right truncation, left padding to `session_max_len`
         During inference model will use (`session_max_len` - 1) interactions
         and one extra "MASK" token will be added for making predictions.
         """
         x = np.zeros((len(batch), self.session_max_len))
-        for i, (ses, _) in enumerate(batch):
+        for i, (ses, _, _) in enumerate(batch):
             session = ses.copy()
             session = session + [self.extra_token_ids[MASKING_VALUE]]
             x[i, -len(ses) - 1 :] = session[-self.session_max_len :]
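
Both preparators now unpack the third slot and discard it (`_`), so training is unchanged unless a subclass opts in. Below is a hypothetical subclass consuming a numeric extra column in the train collate; the commit message names an optional `extras` key in the batch, but the dict layout under that key, the `watched_pct` column, and the left-padding convention are assumptions, not code from this commit:

    import typing as tp

    import numpy as np
    import torch

    from rectools.models.nn.transformers.bert4rec import BERT4RecDataPreparator
    from rectools.models.nn.transformers.data_preparator import BatchElement


    class WatchedPctBERT4RecDataPreparator(BERT4RecDataPreparator):
        """Hypothetical preparator forwarding a per-interaction numeric column to the model."""

        def _collate_fn_train(self, batch: tp.List[BatchElement]) -> tp.Dict[str, torch.Tensor]:
            batch_dict = super()._collate_fn_train(batch)  # "x", "y", "yw" as before
            extra = np.zeros((len(batch), self.session_max_len))
            for i, (ses, _, extras) in enumerate(batch):
                extra[i, -len(ses):] = extras["watched_pct"]  # left padding, same layout as "x"
            batch_dict["extras"] = {"watched_pct": torch.FloatTensor(extra)}  # assumed dict layout
            return batch_dict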

rectools/models/nn/transformers/data_preparator.py
Lines changed: 50 additions & 17 deletions

@@ -32,6 +32,8 @@
 from .negative_sampler import TransformerNegativeSamplerBase
 
 InitKwargs = tp.Dict[str, tp.Any]
+# (user session, session weights, extra columns)
+BatchElement = tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]
 
 
 class SequenceDataset(TorchDataset):
@@ -46,17 +48,26 @@ class SequenceDataset(TorchDataset):
         Weight of each interaction from the session.
     """
 
-    def __init__(self, sessions: tp.List[tp.List[int]], weights: tp.List[tp.List[float]]):
+    def __init__(
+        self,
+        sessions: tp.List[tp.List[int]],
+        weights: tp.List[tp.List[float]],
+        extras: tp.Optional[tp.Dict[str, tp.List[tp.Any]]] = None,
+    ):
         self.sessions = sessions
         self.weights = weights
+        self.extras = extras
 
     def __len__(self) -> int:
         return len(self.sessions)
 
-    def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float]]:
+    def __getitem__(self, index: int) -> BatchElement:
         session = self.sessions[index]  # [session_len]
         weights = self.weights[index]  # [session_len]
-        return session, weights
+        extras = (
+            {feature_name: features[index] for feature_name, features in self.extras.items()} if self.extras else {}
+        )
+        return session, weights, extras
 
     @classmethod
     def from_interactions(
@@ -73,17 +84,19 @@ def from_interactions(
         interactions : pd.DataFrame
             User-item interactions.
         """
+        cols_to_agg = [col for col in interactions.columns if col != Columns.User]
         sessions = (
             interactions.sort_values(Columns.Datetime, kind="stable")
-            .groupby(Columns.User, sort=sort_users)[[Columns.Item, Columns.Weight]]
+            .groupby(Columns.User, sort=sort_users)[cols_to_agg]
             .agg(list)
         )
-        sessions, weights = (
+        sessions_items, weights = (
             sessions[Columns.Item].to_list(),
             sessions[Columns.Weight].to_list(),
        )
-
-        return cls(sessions=sessions, weights=weights)
+        extra_cols = [col for col in interactions.columns if col not in Columns.Interactions]
+        extras = {col: sessions[col].to_list() for col in extra_cols} if len(extra_cols) > 0 else None
+        return cls(sessions=sessions_items, weights=weights, extras=extras)
 
 
 class TransformerDataPreparatorBase:  # pylint: disable=too-many-instance-attributes
@@ -114,6 +127,8 @@ class TransformerDataPreparatorBase: # pylint: disable=too-many-instance-attrib
     get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
         Additional keyword arguments for the get_val_mask_func.
         Make sure all dict values have JSON serializable types.
+    extra_cols: optional(List[str]), default ``None``
+        Extra columns to keep in train and recommend datasets.
     """
 
     # We sometimes need data preparators to add +1 to actual session_max_len
@@ -133,6 +148,7 @@ def __init__(
         n_negatives: tp.Optional[int] = None,
         negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
         get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
+        extra_cols: tp.Optional[tp.List[str]] = None,
         **kwargs: tp.Any,
     ) -> None:
         self.item_id_map: IdMap
@@ -148,6 +164,7 @@
         self.shuffle_train = shuffle_train
         self.get_val_mask_func = get_val_mask_func
         self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
+        self.extra_cols = extra_cols
 
     def get_known_items_sorted_internal_ids(self) -> np.ndarray:
         """Return internal item ids from processed dataset in sorted order."""
@@ -203,7 +220,8 @@ def _filter_train_interactions(self, train_interactions: pd.DataFrame) -> pd.Dat
 
     def process_dataset_train(self, dataset: Dataset) -> None:
         """Process train dataset and save data."""
-        raw_interactions = dataset.get_raw_interactions()
+        extra_cols = False if self.extra_cols is None else self.extra_cols
+        raw_interactions = dataset.get_raw_interactions(include_extra_cols=extra_cols)
 
         # Exclude val interaction targets from train if needed
         interactions = raw_interactions
@@ -231,7 +249,12 @@
 
         # Prepare train dataset
         # User features are dropped for now because model doesn't support them
-        final_interactions = Interactions.from_raw(interactions, user_id_map, item_id_map, keep_extra_cols=True)
+        final_interactions = Interactions.from_raw(
+            interactions,
+            user_id_map,
+            item_id_map,
+            keep_extra_cols=True,
+        )
         self.train_dataset = Dataset(user_id_map, item_id_map, final_interactions, item_features=item_features)
         self.item_id_map = self.train_dataset.item_id_map
         self._init_extra_token_ids()
@@ -246,7 +269,9 @@
         val_interactions = interactions[interactions[Columns.User].isin(val_targets[Columns.User].unique())].copy()
         val_interactions[Columns.Weight] = 0
         val_interactions = pd.concat([val_interactions, val_targets], axis=0)
-        self.val_interactions = Interactions.from_raw(val_interactions, user_id_map, item_id_map).df
+        self.val_interactions = Interactions.from_raw(
+            val_interactions, user_id_map, item_id_map, keep_extra_cols=True
+        ).df
 
     def _init_extra_token_ids(self) -> None:
         extra_token_ids = self.item_id_map.convert_to_internal(self.item_extra_tokens)
@@ -340,7 +365,10 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset
         Final item_id_map is model item_id_map constructed during training.
         """
         # Filter interactions in dataset internal ids
-        interactions = dataset.interactions.df
+        required_cols = Columns.Interactions
+        if self.extra_cols is not None:
+            required_cols = required_cols + self.extra_cols
+        interactions = dataset.interactions.df[required_cols]
         users_internal = dataset.user_id_map.convert_to_internal(users, strict=False)
         items_internal = dataset.item_id_map.convert_to_internal(self.get_known_item_ids(), strict=False)
         interactions = interactions[interactions[Columns.User].isin(users_internal)]
@@ -359,7 +387,9 @@
         if n_filtered > 0:
             explanation = f"""{n_filtered} target users were considered cold because of missing known items"""
             warnings.warn(explanation)
-        filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map)
+        filtered_interactions = Interactions.from_raw(
+            interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=True
+        )
         filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions)
         return filtered_dataset
 
@@ -381,26 +411,29 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset:
         Final user_id_map is the same as dataset original.
         Final item_id_map is model item_id_map constructed during training.
         """
-        interactions = dataset.get_raw_interactions()
+        extra_cols = False if self.extra_cols is None else self.extra_cols
+        interactions = dataset.get_raw_interactions(include_extra_cols=extra_cols)
         interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())]
-        filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map)
+        filtered_interactions = Interactions.from_raw(
+            interactions, dataset.user_id_map, self.item_id_map, keep_extra_cols=True
+        )
         filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions)
         return filtered_dataset
 
     def _collate_fn_train(
         self,
-        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+        batch: tp.List[BatchElement],
     ) -> tp.Dict[str, torch.Tensor]:
         raise NotImplementedError()
 
     def _collate_fn_val(
         self,
-        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+        batch: tp.List[BatchElement],
     ) -> tp.Dict[str, torch.Tensor]:
         raise NotImplementedError()
 
     def _collate_fn_recommend(
         self,
-        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+        batch: tp.List[BatchElement],
    ) -> tp.Dict[str, torch.Tensor]:
         raise NotImplementedError()
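
A small sketch of the reworked `from_interactions`: every column outside `Columns.Interactions` is aggregated into per-session lists and handed to `extras` (the `watched_pct` column is invented for illustration):

    import pandas as pd
    from rectools import Columns
    from rectools.models.nn.transformers.data_preparator import SequenceDataset

    interactions = pd.DataFrame(
        {
            Columns.User: [1, 1, 2],
            Columns.Item: [10, 11, 10],
            Columns.Weight: [1.0, 1.0, 1.0],
            Columns.Datetime: pd.to_datetime(["2025-01-01", "2025-01-02", "2025-01-03"]),
            "watched_pct": [25, 80, 100],  # extra column -> goes to `extras`
        }
    )
    ds = SequenceDataset.from_interactions(interactions)
    session, weights, extras = ds[0]  # BatchElement 3-tuple
    # session == [10, 11]; weights == [1.0, 1.0]; extras == {"watched_pct": [25, 80]}
    # (datetime is a core interaction column, so it is not treated as an extra)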

rectools/models/nn/transformers/lightning.py
Lines changed: 2 additions & 0 deletions

@@ -387,7 +387,9 @@ def _get_user_item_embeddings(
         for batch in recommend_dataloader:
             batch = {k: v.to(device) for k, v in batch.items()}
             batch_embs = self.torch_model.encode_sessions(batch, item_embs)[:, -1, :]
+            batch_embs = self.torch_model.similarity_module.session_tower_forward(batch_embs)
             user_embs.append(batch_embs.cpu())
+        item_embs = self.torch_model.similarity_module.item_tower_forward(item_embs)
 
         return torch.cat(user_embs), item_embs
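
The two new hook calls let a similarity module project session and item embeddings through separate towers before embeddings are returned. A toy stand-in of the pattern (a plain `torch.nn.Module`, not the real `SimilarityModuleBase`, which presumably defaults both methods to identity):

    import torch
    from torch import nn


    class TwoTowerProjection(nn.Module):
        """Toy module showing what the session/item tower hooks enable."""

        def __init__(self, n_factors: int) -> None:
            super().__init__()
            self.session_proj = nn.Linear(n_factors, n_factors)
            self.item_proj = nn.Linear(n_factors, n_factors)

        def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor:
            return self.session_proj(session_embs)

        def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor:
            return self.item_proj(item_embs)


    towers = TwoTowerProjection(n_factors=64)
    user_embs = towers.session_tower_forward(torch.randn(8, 64))  # batch of session embeddings
    item_embs = towers.item_tower_forward(torch.randn(100, 64))   # full item catalog
    scores = user_embs @ item_embs.T                              # (8, 100) relevance scores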

rectools/models/nn/transformers/sasrec.py
Lines changed: 8 additions & 8 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import typing as tp
-from typing import Dict, List, Tuple
+from typing import Dict
 
 import numpy as np
 import torch
@@ -36,7 +36,7 @@
     TransformerModelConfig,
     ValMaskCallable,
 )
-from .data_preparator import InitKwargs, TransformerDataPreparatorBase
+from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase
 from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
 from .net_blocks import (
     LearnableInversePositionalEncoding,
@@ -80,7 +80,7 @@ class SASRecDataPreparator(TransformerDataPreparatorBase):
 
     def _collate_fn_train(
         self,
-        batch: List[Tuple[List[int], List[float]]],
+        batch: tp.List[BatchElement],
     ) -> Dict[str, torch.Tensor]:
         """
         Truncate each session from right to keep `session_max_len` items.
@@ -91,7 +91,7 @@ def _collate_fn_train(
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, self.session_max_len))
         yw = np.zeros((batch_size, self.session_max_len))
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             x[i, -len(ses) + 1 :] = ses[:-1]  # ses: [session_len] -> x[i]: [session_max_len]
             y[i, -len(ses) + 1 :] = ses[1:]  # ses: [session_len] -> y[i]: [session_max_len]
             yw[i, -len(ses) + 1 :] = ses_weights[1:]  # ses_weights: [session_len] -> yw[i]: [session_max_len]
@@ -103,12 +103,12 @@ def _collate_fn_train(
         )
         return batch_dict
 
-    def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+    def _collate_fn_val(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]:
         batch_size = len(batch)
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, 1))  # Only leave-one-strategy is supported for losses
         yw = np.zeros((batch_size, 1))  # Only leave-one-strategy is supported for losses
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]
 
             # take only first target for leave-one-strategy
@@ -126,10 +126,10 @@ def _collate_fn_val(
         )
         return batch_dict
 
-    def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+    def _collate_fn_recommend(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]:
         """Right truncation, left padding to session_max_len"""
         x = np.zeros((len(batch), self.session_max_len))
-        for i, (ses, _) in enumerate(batch):
+        for i, (ses, _, _) in enumerate(batch):
             x[i, -len(ses) :] = ses[-self.session_max_len :]
         return {"x": torch.LongTensor(x)}
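
For reference, the recommend-time collate reduces to left padding of the first tuple slot; weights and extras are ignored. A standalone replica on a toy batch (`session_max_len=5` is arbitrary, the sessions are invented):

    import numpy as np
    import torch

    session_max_len = 5
    batch = [
        ([3, 7, 15], [1.0, 1.0, 1.0], {"watched_pct": [10, 100, 55]}),
        ([2, 4, 6, 8, 10], [1.0] * 5, {"watched_pct": [50] * 5}),
    ]

    x = np.zeros((len(batch), session_max_len))
    for i, (ses, _, _) in enumerate(batch):  # 3-tuple unpacking, as in the new collate fns
        x[i, -len(ses):] = ses[-session_max_len:]

    print(torch.LongTensor(x))
    # tensor([[ 0,  0,  3,  7, 15],
    #         [ 2,  4,  6,  8, 10]])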
