Skip to content

Commit 3b79375

Browse files
jp1924Tcc0403
and authored
gemma3 consider loss_kwargs (#1007)
## Summary When applying the liger-kernel in SFTTrainer of the latest version of TRL (0.26.2), `return_token_accuracy` is also passed to input_data to compute `token_accuracy` alongside compute_loss. However, in Gemma3, `return_token_accuracy` is applied correctly during the loss step in causal_forward but not in multimodal_forward. Therefore, using inspect, I wrote code to separate only the kwagrs that can enter LCE from lm_kwagrs and pass them to loss_kwagrs. Using this, it functions correctly even in the latest version of trl. <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Co-authored-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
1 parent 71ed8ac commit 3b79375

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

src/liger_kernel/transformers/model/gemma3.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from transformers.cache_utils import Cache
99
from transformers.utils import logging
1010

11-
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
1211
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
1312
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
1413
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
@@ -268,23 +267,15 @@ def multimodal_forward(
268267
shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
269268
shift_labels = shift_labels.view(-1).to(hidden_device)
270269

271-
# Extract loss-related kwargs for LigerFusedLinearCrossEntropyLoss
272-
lce_param_keys = {
273-
"ce_weight",
274-
"ignore_index",
275-
"lse_square_scale",
276-
"label_smoothing",
277-
"reduction",
278-
"softcap",
279-
"return_z_loss",
280-
"accum_dtype",
281-
"use_token_scaling",
282-
"return_token_accuracy",
283-
}
284-
lce_kwargs = {k: lm_kwargs.pop(k) for k in lce_param_keys if k in lm_kwargs}
285-
286-
lce = LigerFusedLinearCrossEntropyLoss(**lce_kwargs)
287-
result = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
270+
result = LigerForCausalLMLoss(
271+
hidden_states=shift_hidden_states,
272+
lm_head_weight=self.lm_head.weight,
273+
labels=shift_labels,
274+
hidden_size=self.config.text_config.hidden_size,
275+
shift_labels=shift_labels,
276+
final_logit_softcapping=getattr(self.config.text_config, "final_logit_softcapping", None),
277+
**lm_kwargs,
278+
)
288279
loss, _, token_accuracy = unpack_cross_entropy_result(result)
289280

290281
else:

src/liger_kernel/transformers/model/loss_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import inspect
2+
13
from typing import Optional
24
from typing import Tuple
35

@@ -71,6 +73,10 @@ def LigerForCausalLMLoss(
7173
return_token_accuracy: bool = False,
7274
**kwargs,
7375
):
76+
# Filter out inapplicable kwargs to liger_fused_linear_cross_entropy
77+
applicable_params = inspect.signature(F.liger_fused_linear_cross_entropy).parameters
78+
kwargs = {k: v for k, v in kwargs.items() if k in applicable_params}
79+
7480
# Skip upcast since intermediate values for the loss are all fp32 in kernel
7581
if shift_labels is None:
7682
# Shift so that token < n predict n

0 commit comments

Comments
 (0)