diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/fused_linear_cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/fused_linear_cross_entropy_loss.py
index edc655f6..a855024d 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/fused_linear_cross_entropy_loss.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/fused_linear_cross_entropy_loss.py
@@ -295,6 +295,7 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     num_logits_to_keep: int = 0,
+    num_items_in_batch = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
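
For context: `num_items_in_batch` is the denominator that recent Hugging Face `transformers` Trainers (v4.46+) pass into the model forward so that the loss is normalized over the full gradient-accumulation step rather than averaged per micro-batch. The sketch below is not the patch body and does not use the Liger fused kernel; it is a minimal illustration, with plain `torch.nn.functional.cross_entropy` standing in for the fused linear cross entropy, of the normalization pattern this new parameter enables. The helper name `loss_with_num_items` is hypothetical.

    # Minimal sketch of how num_items_in_batch is typically consumed.
    # Assumes the caller (e.g. the Trainer) supplies the global token
    # count across all accumulated micro-batches.
    import torch
    import torch.nn.functional as F

    def loss_with_num_items(logits, labels, num_items_in_batch=None):
        # Standard causal-LM shift: predict token t+1 from position t.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)
        if num_items_in_batch is not None:
            # Sum over tokens, then divide by the global count so that
            # gradient accumulation matches a single large batch.
            loss = F.cross_entropy(
                flat_logits, flat_labels,
                reduction="sum", ignore_index=-100,
            )
            return loss / num_items_in_batch
        # Fallback: the usual per-micro-batch mean.
        return F.cross_entropy(
            flat_logits, flat_labels,
            reduction="mean", ignore_index=-100,
        )

Accepting the keyword in the signature keeps `lce_forward` compatible with Trainers that forward it as a loss kwarg, while older callers that never pass it see unchanged behavior via the `None` default.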