diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py
index 623a30d82..a34aa1b36 100644
--- a/tests/layers/test_lm_head.py
+++ b/tests/layers/test_lm_head.py
@@ -39,7 +39,7 @@ def _reverse_kl_loss(
     loss_per_sample = torch.nn.functional.kl_div(
         teacher_log_probs, student_log_probs, reduction="none", log_target=True
     ).sum(dim=-1)
-    loss = (loss_per_sample * loss_mask.flatten()).sum() / loss_mask.sum()
+    loss = (loss_per_sample * loss_mask.flatten()).mean()
     return loss


diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py
index b371ba086..1d968b7fb 100644
--- a/tests/layers/test_ssm.py
+++ b/tests/layers/test_ssm.py
@@ -19,11 +19,6 @@
 Apriel2GatedDeltaNet = None
 Apriel2Mamba = None

-try:
-    from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention
-except ImportError:
-    KimiDeltaAttention = None
-
 HIDDEN_SIZE = 16
 SEQ_LEN = 65
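
Note on the test_lm_head.py hunk: the old expression averages the masked loss over unmasked positions only (dividing by loss_mask.sum()), while the new .mean() divides by the total number of elements; the two coincide only when loss_mask is all ones. A minimal standalone sketch, using made-up example values that are not from the PR, illustrates the difference:

import torch

# Hypothetical per-sample losses and a partial mask (illustrative values only).
loss_per_sample = torch.tensor([2.0, 4.0, 6.0, 8.0])
loss_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])

# Old form: average over unmasked positions -> (2 + 4 + 8) / 3
masked_avg = (loss_per_sample * loss_mask.flatten()).sum() / loss_mask.sum()

# New form: average over all positions -> (2 + 4 + 8) / 4
plain_mean = (loss_per_sample * loss_mask.flatten()).mean()

print(masked_avg)  # tensor(4.6667)
print(plain_mean)  # tensor(3.5000)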