Commit 19d2921

LiGR layers (#295)
added LiGR layers
1 parent 46deae3 commit 19d2921

File tree

5 files changed: +369 −4 lines

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## Unreleased
+
+### Added
+- LiGR transformer layers from "From Features to Transformers: Redefining Ranking for Scalable Impact" ([#295](https://github.com/MobileTeleSystems/RecTools/pull/295))
+
 ## [0.16.0] - 27.07.2025

 ### Added

README.md

Lines changed: 10 additions & 1 deletion
@@ -33,6 +33,15 @@ faster than ever before.
 - In [HSTU tutorial](examples/tutorials/transformers_HSTU_tutorial.ipynb) we show that original metrics reported for HSTU on public Movielens datasets may actually be **underestimated**
 - Configurable, customizable, callback-friendly, checkpoints-included, logs-out-of-the-box, custom-validation-ready, multi-gpu-compatible! See [Transformers Advanced Training User Guide](examples/tutorials/transformers_advanced_training_guide.ipynb) and [Transformers Customization Guide](examples/tutorials/transformers_customization_guide.ipynb)

+
+## ✨ Highlights: RecTools framework at ACM RecSys'25 ✨
+
+**RecTools implementations are featured in ACM RecSys'25: ["eSASRec: Enhancing Transformer-based Recommendations in a Modular Fashion"](https://www.arxiv.org/abs/2508.06450):**
+- The article presents a systematic benchmark of Transformer modifications using RecTools models. It offers a detailed evaluation of training objectives, Transformer architectures, loss functions, and negative sampling strategies in realistic, production-like settings
+- We introduce a new SOTA baseline, **eSASRec**, which combines SASRec’s training objective with LiGR Transformer layers and Sampled Softmax loss, forming a simple yet powerful recipe
+- **eSASRec** shows a 23% boost over SOTA models, such as ActionPiece, on academic benchmarks
+- [LiGR](https://arxiv.org/pdf/2502.03417) Transformer layers used in **eSASRec** are now available in RecTools
+
 Please note that we always compare the quality of our implementations to academic papers' results. [Public benchmarks for transformer models SASRec and BERT4Rec](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results) show that RecTools implementations achieve the highest scores on multiple datasets compared to other published results.


@@ -107,7 +116,7 @@ The table below lists recommender models that are available in RecTools.
 | Model | Type | Description (🎏 for user/item features, 🔆 for warm inference, ❄️ for cold inference support) | Tutorials & Benchmarks |
 |---------------------|----|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|
 | HSTU | Neural Network | `rectools.models.HSTUModel` - Sequential model with unidirectional pointwise aggregated attention mechanism, incorporating relative attention bias from positional and temporal information, introduced in ["Actions speak louder than words..."](https://arxiv.org/pdf/2402.17152), combined with "Shifted Sequence" training objective as in original public benchmarks<br>🎏 | 📓 [HSTU Theory & Practice](examples/tutorials/transformers_HSTU_tutorial.ipynb) <br> 📕 [Transformers Theory & Practice](examples/tutorials/transformers_tutorial.ipynb)<br> 📗 [Advanced training guide](examples/tutorials/transformers_advanced_training_guide.ipynb) <br> 🚀 [Top performance on public datasets](examples/tutorials/transformers_HSTU_tutorial.ipynb)
-| SASRec | Neural Network | `rectools.models.SASRecModel` - Transformer-based sequential model with unidirectional attention mechanism and "Shifted Sequence" training objective <br>🎏 | 📕 [Transformers Theory & Practice](examples/tutorials/transformers_tutorial.ipynb)<br> 📗 [Advanced training guide](examples/tutorials/transformers_advanced_training_guide.ipynb) <br> 📘 [Customization guide](examples/tutorials/transformers_customization_guide.ipynb) <br> 🚀 [Top performance on public benchmarks](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results) |
+| SASRec | Neural Network | `rectools.models.SASRecModel` - Transformer-based sequential model with unidirectional attention mechanism and "Shifted Sequence" training objective. <br> For the eSASRec variant, specify `rectools.models.nn.transformers.ligr.LiGRLayers` for `transformer_layers_type` and `sampled_softmax` for `loss` <br>🎏 | 📕 [Transformers Theory & Practice](examples/tutorials/transformers_tutorial.ipynb)<br> 📗 [Advanced training guide](examples/tutorials/transformers_advanced_training_guide.ipynb) <br> 📘 [Customization guide](examples/tutorials/transformers_customization_guide.ipynb) <br> 🚀 [Top performance on public benchmarks](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results) |
 | BERT4Rec | Neural Network | `rectools.models.BERT4RecModel` - Transformer-based sequential model with bidirectional attention mechanism and "MLM" (masked item) training objective <br>🎏 | 📕 [Transformers Theory & Practice](examples/tutorials/transformers_tutorial.ipynb)<br> 📗 [Advanced training guide](examples/tutorials/transformers_advanced_training_guide.ipynb) <br> 📘 [Customization guide](examples/tutorials/transformers_customization_guide.ipynb) <br> 🚀 [Top performance on public benchmarks](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results) |
 | [implicit](https://github.com/benfred/implicit) ALS Wrapper | Matrix Factorization | `rectools.models.ImplicitALSWrapperModel` - Alternating Least Squares Matrix Factorization algorithm for implicit feedback. <br>🎏 | 📙 [Theory & Practice](https://rectools.readthedocs.io/en/latest/examples/tutorials/baselines_extended_tutorial.html#Implicit-ALS)<br> 🚀 [50% boost to metrics with user & item features](examples/5_benchmark_iALS_with_features.ipynb) |
 | [implicit](https://github.com/benfred/implicit) BPR-MF Wrapper | Matrix Factorization | `rectools.models.ImplicitBPRWrapperModel` - Bayesian Personalized Ranking Matrix Factorization algorithm. | 📙 [Theory & Practice](https://rectools.readthedocs.io/en/latest/examples/tutorials/baselines_extended_tutorial.html#Bayesian-Personalized-Ranking-Matrix-Factorization-(BPR-MF)) |
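
The eSASRec recipe referenced in the SASRec row above can be assembled in a few lines. Below is a minimal sketch, not code from this commit: the toy interactions table, the `Dataset.construct` call and all unspecified hyperparameters are illustrative assumptions; only `transformer_layers_type=LiGRLayers` and `loss="sampled_softmax"` come from the README note.

import pandas as pd

from rectools import Columns
from rectools.dataset import Dataset
from rectools.models import SASRecModel
from rectools.models.nn.transformers.ligr import LiGRLayers

# Toy implicit-feedback log; any interactions table with these columns works.
interactions = pd.DataFrame(
    {
        Columns.User: [1, 1, 1, 2, 2, 2],
        Columns.Item: [10, 20, 30, 10, 30, 40],
        Columns.Weight: 1,
        Columns.Datetime: pd.to_datetime(
            ["2025-01-01", "2025-01-02", "2025-01-03", "2025-01-01", "2025-01-02", "2025-01-03"]
        ),
    }
)
dataset = Dataset.construct(interactions)

model = SASRecModel(
    transformer_layers_type=LiGRLayers,  # LiGR transformer blocks added in this commit
    loss="sampled_softmax",              # sampled softmax loss, as the README row suggests
)
model.fit(dataset)
reco = model.recommend(users=[1, 2], dataset=dataset, k=2, filter_viewed=True)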
rectools/models/nn/transformers/ligr.py (new file)

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
import typing as tp

import torch
from torch import nn

from rectools.models.nn.transformers.net_blocks import TransformerLayersBase

from .net_blocks import init_feed_forward


class LiGRLayer(nn.Module):
    """
    Transformer Layer as described in "From Features to Transformers:
    Redefining Ranking for Scalable Impact" https://arxiv.org/pdf/2502.03417

    Parameters
    ----------
    n_factors: int
        Latent embeddings size.
    n_heads: int
        Number of attention heads.
    dropout_rate: float
        Probability of a hidden unit to be zeroed.
    ff_factors_multiplier: int, default 4
        Feed-forward layers latent embedding size multiplier.
    bias_in_ff: bool, default ``False``
        Add bias in Linear layers of Feed Forward.
    ff_activation: {"swiglu", "relu", "gelu"}, default "swiglu"
        Activation function to use.
    """

    def __init__(
        self,
        n_factors: int,
        n_heads: int,
        dropout_rate: float,
        ff_factors_multiplier: int = 4,
        bias_in_ff: bool = False,
        ff_activation: str = "swiglu",
    ):
        super().__init__()
        self.multi_head_attn = nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True)
        self.layer_norm_1 = nn.LayerNorm(n_factors)
        self.dropout_1 = nn.Dropout(dropout_rate)
        self.layer_norm_2 = nn.LayerNorm(n_factors)
        self.feed_forward = init_feed_forward(n_factors, ff_factors_multiplier, dropout_rate, ff_activation, bias_in_ff)
        self.dropout_2 = nn.Dropout(dropout_rate)

        self.gating_linear_1 = nn.Linear(n_factors, n_factors)
        self.gating_linear_2 = nn.Linear(n_factors, n_factors)

    def forward(
        self,
        seqs: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor],
        key_padding_mask: tp.Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Forward pass through transformer block.

        Parameters
        ----------
        seqs: torch.Tensor
            User sequences of item embeddings.
        attn_mask: torch.Tensor, optional
            Optional mask to use in forward pass of multi-head attention as `attn_mask`.
        key_padding_mask: torch.Tensor, optional
            Optional mask to use in forward pass of multi-head attention as `key_padding_mask`.

        Returns
        -------
        torch.Tensor
            User sequences passed through transformer layers.
        """
        mha_input = self.layer_norm_1(seqs)
        mha_output, _ = self.multi_head_attn(
            mha_input,
            mha_input,
            mha_input,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        gated_skip = torch.nn.functional.sigmoid(self.gating_linear_1(seqs))
        seqs = seqs + torch.mul(gated_skip, self.dropout_1(mha_output))

        ff_input = self.layer_norm_2(seqs)
        ff_output = self.feed_forward(ff_input)
        gated_skip = torch.nn.functional.sigmoid(self.gating_linear_2(seqs))
        seqs = seqs + torch.mul(gated_skip, self.dropout_2(ff_output))
        return seqs


class LiGRLayers(TransformerLayersBase):
    """
    LiGR Transformer blocks.

    Parameters
    ----------
    n_blocks: int
        Number of transformer blocks.
    n_factors: int
        Latent embeddings size.
    n_heads: int
        Number of attention heads.
    dropout_rate: float
        Probability of a hidden unit to be zeroed.
    ff_factors_multiplier: int, default 4
        Feed-forward layers latent embedding size multiplier. Pass in ``transformer_layers_kwargs`` to override.
    ff_activation: {"swiglu", "relu", "gelu"}, default "swiglu"
        Activation function to use. Pass in ``transformer_layers_kwargs`` to override.
    bias_in_ff: bool, default ``False``
        Add bias in Linear layers of Feed Forward. Pass in ``transformer_layers_kwargs`` to override.
    """

    def __init__(
        self,
        n_blocks: int,
        n_factors: int,
        n_heads: int,
        dropout_rate: float,
        ff_factors_multiplier: int = 4,
        ff_activation: str = "swiglu",
        bias_in_ff: bool = False,
    ):
        super().__init__()
        self.n_blocks = n_blocks
        self.n_factors = n_factors
        self.n_heads = n_heads
        self.dropout_rate = dropout_rate
        self.ff_factors_multiplier = ff_factors_multiplier
        self.ff_activation = ff_activation
        self.bias_in_ff = bias_in_ff
        self.transformer_blocks = nn.ModuleList([self._init_transformer_block() for _ in range(self.n_blocks)])

    def _init_transformer_block(self) -> nn.Module:
        return LiGRLayer(
            self.n_factors,
            self.n_heads,
            self.dropout_rate,
            self.ff_factors_multiplier,
            bias_in_ff=self.bias_in_ff,
            ff_activation=self.ff_activation,
        )

    def forward(
        self,
        seqs: torch.Tensor,
        timeline_mask: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor],
        key_padding_mask: tp.Optional[torch.Tensor],
        **kwargs: tp.Any,
    ) -> torch.Tensor:
        """
        Forward pass through transformer blocks.

        Parameters
        ----------
        seqs: torch.Tensor
            User sequences of item embeddings.
        timeline_mask: torch.Tensor
            Mask indicating padding elements.
        attn_mask: torch.Tensor, optional
            Optional mask to use in forward pass of multi-head attention as `attn_mask`.
        key_padding_mask: torch.Tensor, optional
            Optional mask to use in forward pass of multi-head attention as `key_padding_mask`.

        Returns
        -------
        torch.Tensor
            User sequences passed through transformer layers.
        """
        for block_idx in range(self.n_blocks):
            seqs = self.transformer_blocks[block_idx](seqs, attn_mask, key_padding_mask)
        return seqs
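
As a quick standalone check (not part of the commit), the sketch below drives the new LiGRLayers directly, the way a transformer backbone would. Tensor shapes and the boolean mask convention (True means a position is blocked or padded) follow torch.nn.MultiheadAttention; the concrete sizes and values are arbitrary assumptions.

import torch

from rectools.models.nn.transformers.ligr import LiGRLayers

batch_size, session_len, n_factors = 2, 5, 64
layers = LiGRLayers(n_blocks=2, n_factors=n_factors, n_heads=4, dropout_rate=0.1)

seqs = torch.randn(batch_size, session_len, n_factors)  # item embeddings per position
timeline_mask = torch.ones(batch_size, session_len, 1)  # 1 = real item, 0 = padding; not used inside LiGRLayers itself
attn_mask = torch.triu(  # causal mask: True disallows attention to future positions
    torch.ones(session_len, session_len, dtype=torch.bool), diagonal=1
)
key_padding_mask = torch.zeros(batch_size, session_len, dtype=torch.bool)  # True would mark padded positions

out = layers(seqs, timeline_mask, attn_mask, key_padding_mask)
assert out.shape == (batch_size, session_len, n_factors)  # gated residual blocks preserve the shape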

rectools/models/nn/transformers/net_blocks.py

Lines changed: 93 additions & 3 deletions
@@ -33,14 +33,18 @@ class PointWiseFeedForward(nn.Module):
         Probability of a hidden unit to be zeroed.
     activation: torch.nn.Module
         Activation function module.
+    bias: bool, default ``True``
+        If ``True``, add bias to linear layers.
     """

-    def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float, activation: torch.nn.Module) -> None:
+    def __init__(
+        self, n_factors: int, n_factors_ff: int, dropout_rate: float, activation: torch.nn.Module, bias: bool = True
+    ) -> None:
         super().__init__()
-        self.ff_linear_1 = nn.Linear(n_factors, n_factors_ff)
+        self.ff_linear_1 = nn.Linear(n_factors, n_factors_ff, bias)
         self.ff_dropout_1 = torch.nn.Dropout(dropout_rate)
         self.ff_activation = activation
-        self.ff_linear_2 = nn.Linear(n_factors_ff, n_factors)
+        self.ff_linear_2 = nn.Linear(n_factors_ff, n_factors, bias)

     def forward(self, seqs: torch.Tensor) -> torch.Tensor:
         """
@@ -61,6 +65,92 @@ def forward(self, seqs: torch.Tensor) -> torch.Tensor:
         return fin


+class SwigluFeedForward(nn.Module):
+    """
+    Feed-Forward network to introduce nonlinearity into the transformer model.
+    This implementation is based on FuXi and LLama SwigLU https://arxiv.org/pdf/2502.03036,
+    LiGR https://arxiv.org/pdf/2502.03417
+
+    Parameters
+    ----------
+    n_factors : int
+        Latent embeddings size.
+    n_factors_ff : int
+        How many hidden units to use in the network.
+    dropout_rate : float
+        Probability of a hidden unit to be zeroed.
+    bias: bool, default ``True``
+        If ``True``, add bias to linear layers.
+    """
+
+    def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float, bias: bool = True) -> None:
+        super().__init__()
+        self.ff_linear_1 = nn.Linear(n_factors, n_factors_ff, bias=bias)
+        self.ff_dropout_1 = torch.nn.Dropout(dropout_rate)
+        self.ff_activation = torch.nn.SiLU()
+        self.ff_linear_2 = nn.Linear(n_factors_ff, n_factors, bias=bias)
+        self.ff_linear_3 = nn.Linear(n_factors, n_factors_ff, bias=bias)
+
+    def forward(self, seqs: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass.
+
+        Parameters
+        ----------
+        seqs : torch.Tensor
+            User sequences of item embeddings.
+
+        Returns
+        -------
+        torch.Tensor
+            User sequence that passed through all layers.
+        """
+        output = self.ff_activation(self.ff_linear_1(seqs)) * self.ff_linear_3(seqs)
+        fin = self.ff_linear_2(self.ff_dropout_1(output))
+        return fin
+
+
+def init_feed_forward(
+    n_factors: int, ff_factors_multiplier: int, dropout_rate: float, ff_activation: str, bias: bool = True
+) -> nn.Module:
+    """
+    Initialise Feed-Forward network with one of activation functions: "swiglu", "relu", "gelu".
+
+    Parameters
+    ----------
+    n_factors : int
+        Latent embeddings size.
+    ff_factors_multiplier : int
+        How many hidden units to use in the network.
+    dropout_rate : float
+        Probability of a hidden unit to be zeroed.
+    ff_activation : {"swiglu", "relu", "gelu"}
+        Activation function to use.
+    bias: bool, default ``True``
+        If ``True``, add bias to linear layers.
+
+    Returns
+    -------
+    nn.Module
+        Feed-Forward network.
+    """
+    if ff_activation == "swiglu":
+        return SwigluFeedForward(n_factors, n_factors * ff_factors_multiplier, dropout_rate, bias=bias)
+    if ff_activation == "gelu":
+        return PointWiseFeedForward(
+            n_factors, n_factors * ff_factors_multiplier, dropout_rate, activation=torch.nn.GELU(), bias=bias
+        )
+    if ff_activation == "relu":
+        return PointWiseFeedForward(
+            n_factors,
+            n_factors * ff_factors_multiplier,
+            dropout_rate,
+            activation=torch.nn.ReLU(),
+            bias=bias,
+        )
+    raise ValueError(f"Unsupported ff_activation: {ff_activation}")
+
+
 class TransformerLayersBase(nn.Module):
     """Base class for transformer layers."""

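To show what the new factory wires up, here is a small sketch (not part of the diff): with ff_activation="swiglu", init_feed_forward returns SwigluFeedForward, which computes W2(dropout(SiLU(W1 x) * W3 x)), while "relu" and "gelu" fall back to PointWiseFeedForward. The shapes below are arbitrary assumptions.

import torch

from rectools.models.nn.transformers.net_blocks import init_feed_forward

n_factors = 64
ff = init_feed_forward(
    n_factors=n_factors,
    ff_factors_multiplier=4,   # hidden size becomes 4 * n_factors
    dropout_rate=0.1,
    ff_activation="swiglu",    # "relu" or "gelu" would return PointWiseFeedForward instead
    bias=False,                # LiGR layers pass bias_in_ff=False through this argument
)
x = torch.randn(2, 5, n_factors)
assert ff(x).shape == x.shape  # the feed-forward block preserves the embedding size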