Commit 200581d

Bug fixing (#441)
1 parent 22c70ad commit 200581d

File tree

2 files changed (+1, -6 lines)


tests/layers/test_lm_head.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ def _reverse_kl_loss(
     loss_per_sample = torch.nn.functional.kl_div(
         teacher_log_probs, student_log_probs, reduction="none", log_target=True
     ).sum(dim=-1)
-    loss = (loss_per_sample * loss_mask.flatten()).sum() / loss_mask.sum()
+    loss = (loss_per_sample * loss_mask.flatten()).mean()
     return loss
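Context for the fix: the reduction of the per-position reverse-KL loss changes from a masked average (masked sum divided by the number of unmasked positions) to a plain mean over all positions, with masked positions contributing zero. A minimal sketch of the two reductions with dummy tensors (shapes and values here are illustrative assumptions, not taken from the test suite):

import torch

# Illustrative per-position losses for a flattened batch of 4 positions
# (values are assumptions for demonstration, not from the test suite).
loss_per_sample = torch.tensor([1.0, 2.0, 3.0, 4.0])
# Mask selecting which positions contribute to the loss.
loss_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])

# Old reduction: masked sum normalized by the number of unmasked positions.
old_loss = (loss_per_sample * loss_mask.flatten()).sum() / loss_mask.sum()
# New reduction: plain mean over all positions, masked ones counting as zero.
new_loss = (loss_per_sample * loss_mask.flatten()).mean()

print(old_loss)  # tensor(2.3333): (1 + 2 + 4) / 3
print(new_loss)  # tensor(1.7500): (1 + 2 + 4) / 4

The two reductions agree only when every position is unmasked; otherwise they differ by the factor loss_mask.sum() / loss_mask.numel().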

tests/layers/test_ssm.py

Lines changed: 0 additions & 5 deletions
@@ -19,11 +19,6 @@
     Apriel2GatedDeltaNet = None
     Apriel2Mamba = None
 
-try:
-    from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention
-except ImportError:
-    KimiDeltaAttention = None
-
 HIDDEN_SIZE = 16
 SEQ_LEN = 65
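The removed lines are an instance of the optional-import guard used at the top of this test file: import a class from an external package and fall back to None when it is unavailable. A sketch of how such a guard is typically paired with a pytest skip marker (the skipif wiring and the test body are hypothetical illustrations, not repository code):

import pytest

# Optional-import guard, as in the lines this commit removes: fall back to
# None when the external package is not installed.
try:
    from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention
except ImportError:
    KimiDeltaAttention = None


# Hypothetical skip wiring (assumed, not from the repository): tests that
# need the class are skipped when the import failed.
@pytest.mark.skipif(KimiDeltaAttention is None, reason="apriel_hybrid_ssm not installed")
def test_kimi_delta_attention_available():
    assert KimiDeltaAttention is not None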
