Skip to content

Commit 41d4bcf

Browse files
authored
pass param down to LigerFusedLinearCrossEntropyLoss (#1010)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace A100-80G-PCIe with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent f408175 commit 41d4bcf

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

src/liger_kernel/transformers/model/gemma3.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,22 @@ def multimodal_forward(
268268
shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
269269
shift_labels = shift_labels.view(-1).to(hidden_device)
270270

271-
lce = LigerFusedLinearCrossEntropyLoss()
271+
# Extract loss-related kwargs for LigerFusedLinearCrossEntropyLoss
272+
lce_param_keys = {
273+
"ce_weight",
274+
"ignore_index",
275+
"lse_square_scale",
276+
"label_smoothing",
277+
"reduction",
278+
"softcap",
279+
"return_z_loss",
280+
"accum_dtype",
281+
"use_token_scaling",
282+
"return_token_accuracy",
283+
}
284+
lce_kwargs = {k: lm_kwargs.pop(k) for k in lce_param_keys if k in lm_kwargs}
285+
286+
lce = LigerFusedLinearCrossEntropyLoss(**lce_kwargs)
272287
result = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
273288
loss, _, token_accuracy = unpack_cross_entropy_result(result)
274289

0 commit comments

Comments
 (0)