From fd728a5ac550637d8148cc98d53da94a1e0549a8 Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Tue, 11 Feb 2025 14:38:59 -0700 Subject: [PATCH] fix: liger fails to run loss with new param Signed-off-by: Anh Uong --- .../fused_ops/liger_ce/fused_linear_cross_entropy_loss.py | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/fused_linear_cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/fused_linear_cross_entropy_loss.py index edc655f6..a855024d 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/fused_linear_cross_entropy_loss.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/fused_linear_cross_entropy_loss.py @@ -295,6 +295,7 @@ def lce_forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, + num_items_in_batch = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy