
Commit 6b7fda0

Feature/Sampled_softmax (#274)
Added sampled softmax loss for transformer models
1 parent f15e7d3 commit 6b7fda0
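
For orientation, a minimal usage sketch of the new option. Only `loss="sampled_softmax"` and `n_negatives` come from this commit; the dataset construction and `fit`/`recommend` calls assume the standard public RecTools API and are not part of the diff:

import pandas as pd

from rectools import Columns
from rectools.dataset import Dataset
from rectools.models import SASRecModel

# Toy interactions table with the columns Dataset.construct expects.
interactions = pd.DataFrame(
    {
        Columns.User: [10, 10, 30, 30, 40, 40],
        Columns.Item: [13, 14, 13, 15, 13, 17],
        Columns.Weight: 1,
        Columns.Datetime: pd.to_datetime("2025-02-01"),
    }
)
dataset = Dataset.construct(interactions)

# "sampled_softmax" is the loss added in this commit; like BCE and gBCE it trains
# against sampled negatives, so n_negatives applies to it as well.
model = SASRecModel(loss="sampled_softmax", n_negatives=10)
model.fit(dataset)
reco = model.recommend(users=[10, 30], dataset=dataset, k=3, filter_viewed=True)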

File tree

7 files changed, +48 -4 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -5,10 +5,12 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## Unreleased
 
 ### Added
 - `SimilarityModuleBase`, `DistanceSimilarityModule`, similarity module to `TransformerTorchBackbone`, parameters to transformer-based models `similarity_module_type`, `similarity_module_kwargs` ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
 - `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276))
+- `sampled_softmax` loss option for transformer models ([#274](https://github.com/MobileTeleSystems/RecTools/pull/274))
 
 ## [0.12.0] - 24.02.2025

rectools/models/nn/transformers/bert4rec.py

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
     train_min_user_interactions : int, default 2
         Minimum number of interactions user should have to be used for training. Should be greater
         than 1.
-    loss : {"softmax", "BCE", "gBCE"}, default "softmax"
+    loss : {"softmax", "BCE", "gBCE", "sampled_softmax"}, default "softmax"
         Loss function.
     n_negatives : int, default 1
         Number of negatives for BCE and gBCE losses.

rectools/models/nn/transformers/lightning.py

Lines changed: 11 additions & 1 deletion
@@ -102,7 +102,7 @@ def requires_negatives(loss: str) -> tp.Optional[bool]:
         if loss == "softmax":
             return False
 
-        if loss in ["BCE", "gBCE"]:
+        if loss in ["BCE", "gBCE", "sampled_softmax"]:
            return True
 
         return None
@@ -120,6 +120,9 @@ def get_loss_calculator(
         if self.loss == "gBCE":
             return self._calc_gbce_loss
 
+        if self.loss == "sampled_softmax":
+            return self._calc_sampled_softmax_loss
+
         return None
 
     @classmethod
@@ -185,6 +188,13 @@ def _calc_gbce_loss(self, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor
         loss = self._calc_bce_loss(logits, y, w)
         return loss
 
+    def _calc_sampled_softmax_loss(self, logits: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+        # We put positive logits at index 1 since index 0 is used to ignore padding
+        logits[:, :, [0, 1]] = logits[:, :, [1, 0]]
+        target = (y != 0).long()
+        loss = self._calc_softmax_loss(logits, target, w)
+        return loss
+
     def configure_optimizers(self) -> torch.optim.Adam:
         """Choose what optimizers and learning-rate schedulers to use in optimization"""
         optimizer = torch.optim.Adam(self.torch_model.parameters(), lr=self.lr, betas=self.adam_betas)
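
The key piece is `_calc_sampled_softmax_loss`: it swaps the positive logit into slot 1 and then reuses the existing softmax-loss path. A standalone sketch of the idea, assuming logits of shape `(batch, session_len, 1 + n_negatives)` with the positive score in slot 0, and assuming `_calc_softmax_loss` boils down to cross-entropy with `ignore_index=0` (these shapes and internals are inferences, not code from this commit):

import torch
import torch.nn.functional as F


def sampled_softmax_loss(logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch: cross-entropy over one positive and its sampled negatives."""
    logits = logits.clone()
    # Class 0 of the cross-entropy is reserved for "ignore this position" (padding),
    # so the positive score is moved from slot 0 to slot 1 first.
    logits[:, :, [0, 1]] = logits[:, :, [1, 0]]
    # Real positions get target class 1 (the positive's new slot); padded positions (y == 0) get 0.
    target = (y != 0).long()
    # cross_entropy normalizes each position over the positive plus the sampled negatives,
    # i.e. loss = -log(exp(s_pos) / sum_j exp(s_j)), and skips padded positions.
    return F.cross_entropy(logits.transpose(1, 2), target, ignore_index=0)


# Tiny smoke check: one session of length 3, last position padded, 2 sampled negatives.
logits = torch.randn(1, 3, 1 + 2)
y = torch.tensor([[5, 7, 0]])  # target item ids; 0 marks padding
print(sampled_softmax_loss(logits, y))

Compared to the full "softmax" loss, which normalizes over the whole catalogue, this normalizes only over the sampled candidates, which is why it shares the `n_negatives` setting with BCE and gBCE (see the `requires_negatives` change above).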

rectools/models/nn/transformers/sasrec.py

Lines changed: 1 addition & 1 deletion
@@ -289,7 +289,7 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
     train_min_user_interactions : int, default 2
         Minimum number of interactions user should have to be used for training. Should be greater
         than 1.
-    loss : {"softmax", "BCE", "gBCE"}, default "softmax"
+    loss : {"softmax", "BCE", "gBCE", "sampled_softmax"}, default "softmax"
         Loss function.
     n_negatives : int, default 1
         Number of negatives for BCE and gBCE losses.

tests/models/nn/transformers/test_base.py

Lines changed: 1 addition & 1 deletion
@@ -261,7 +261,7 @@ def test_raises_when_load_weights_from_checkpoint_not_fitted_model(
             (True, ["epoch", "step", "train_loss", "val_loss"]),
         ),
     )
-    @pytest.mark.parametrize("loss", ("softmax", "BCE", "gBCE"))
+    @pytest.mark.parametrize("loss", ("softmax", "BCE", "gBCE", "sampled_softmax"))
     def test_log_metrics(
         self,
         model_cls: tp.Type[TransformerModelBase],

tests/models/nn/transformers/test_bert4rec.py

Lines changed: 10 additions & 0 deletions
@@ -295,6 +295,16 @@ def get_trainer() -> Trainer:
                 }
             ),
         ),
+        (
+            "sampled_softmax",
+            pd.DataFrame(
+                {
+                    Columns.User: [30, 40, 40],
+                    Columns.Item: [12, 12, 13],
+                    Columns.Rank: [1, 1, 2],
+                }
+            ),
+        ),
     ),
 )
 @pytest.mark.parametrize("u2i_dist", ("dot", "cosine"))

tests/models/nn/transformers/test_sasrec.py

Lines changed: 22 additions & 0 deletions
@@ -326,6 +326,17 @@ def get_trainer() -> Trainer:
             ),
             "dot",
         ),
+        (
+            "sampled_softmax",
+            pd.DataFrame(
+                {
+                    Columns.User: [10, 10, 30, 30, 30, 40, 40, 40],
+                    Columns.Item: [17, 15, 13, 17, 14, 13, 14, 15],
+                    Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3],
+                }
+            ),
+            "dot",
+        ),
         (
             "BCE",
             pd.DataFrame(
@@ -348,6 +359,17 @@ def get_trainer() -> Trainer:
             ),
             "cosine",
         ),
+        (
+            "sampled_softmax",
+            pd.DataFrame(
+                {
+                    Columns.User: [10, 10, 30, 30, 30, 40, 40, 40],
+                    Columns.Item: [17, 15, 13, 14, 17, 13, 14, 15],
+                    Columns.Rank: [1, 2, 1, 2, 3, 1, 2, 3],
+                }
+            ),
+            "cosine",
+        ),
     ),
 )
 def test_u2i_losses(
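
The parametrize cases above pair each loss with the recommendations the test expects the model to produce in its controlled setup. As a generic illustration of how such an expected-reco case can be asserted (the usual pandas pattern, not the repository's actual test body):

import pandas as pd

from rectools import Columns


def assert_reco_equal(actual: pd.DataFrame, expected: pd.DataFrame) -> None:
    # Compare user/item/rank triples, ignoring row order and index.
    cols = [Columns.User, Columns.Item, Columns.Rank]
    pd.testing.assert_frame_equal(
        actual[cols].sort_values(cols).reset_index(drop=True),
        expected[cols].sort_values(cols).reset_index(drop=True),
        check_dtype=False,
    )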

0 commit comments
