diff --git a/pytabkit/models/nn_models/tabr.py b/pytabkit/models/nn_models/tabr.py index 5fabd96..27f056a 100644 --- a/pytabkit/models/nn_models/tabr.py +++ b/pytabkit/models/nn_models/tabr.py @@ -388,7 +388,11 @@ def forward( probs = F.softmax(similarities, dim=-1) probs = self.dropout(probs) - context_y_emb = self.label_encoder(candidate_y[context_idx][..., None]) + context_y = candidate_y[context_idx][..., None] + if isinstance(self.label_encoder, nn.Sequential): + context_y_emb = self.label_encoder(context_y.long()) + else: + context_y_emb = self.label_encoder(context_y.float()) values = context_y_emb + self.T(k[:, None] - context_k) context_x = (probs[:, None] @ values).squeeze(1) x = x + context_x