
Commit f6116e6

fix: liger fail to run loss with new param (#124)
Signed-off-by: Anh Uong <[email protected]>
Parent: 24bdadb

File tree: 1 file changed, +1 −0 lines


plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/fused_linear_cross_entropy_loss.py

Lines changed: 1 addition & 0 deletions
@@ -295,6 +295,7 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     num_logits_to_keep: int = 0,
+    num_items_in_batch = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
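The one-line fix above works because newer versions of transformers pass an extra `num_items_in_batch` keyword into the model's forward; a patched (monkey-patched) forward whose signature lacks that parameter raises a `TypeError`. A minimal sketch of the failure mode, using simplified hypothetical function bodies (the real `lce_forward` computes the Liger fused linear cross entropy loss):

```python
# Hypothetical, simplified forwards illustrating the signature change.
# The real lce_forward takes many more parameters and returns a loss.

def lce_forward_before_fix(input_ids=None, labels=None, num_logits_to_keep=0):
    # Before the fix: no `num_items_in_batch` parameter in the signature.
    return "loss"

def lce_forward_after_fix(input_ids=None, labels=None, num_logits_to_keep=0,
                          num_items_in_batch=None):
    # After the fix: the new keyword is accepted, so a caller that
    # forwards it no longer crashes.
    return "loss"

# Keyword arguments as a newer transformers trainer might pass them:
kwargs = {"input_ids": [1, 2, 3], "labels": [1, 2, 3], "num_items_in_batch": 3}

try:
    lce_forward_before_fix(**kwargs)
    old_signature_failed = False
except TypeError:
    # Unexpected keyword argument 'num_items_in_batch'
    old_signature_failed = True

print(old_signature_failed)            # → True
print(lce_forward_after_fix(**kwargs)) # → loss
```

Accepting the parameter with a default of `None` keeps the patched forward compatible with both older callers (which never pass it) and newer ones (which do).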
