Commit 3b79375
gemma3 consider loss_kwargs (#1007)
## Summary
When liger-kernel is applied in SFTTrainer with the latest version of
TRL (0.26.2), `return_token_accuracy` is also passed to input_data so
that `token_accuracy` is computed alongside compute_loss.
However, in Gemma3, `return_token_accuracy` is handled correctly during
the loss step in causal_forward but not in multimodal_forward.
Therefore, using `inspect`, I wrote code that separates from lm_kwargs
only the kwargs that LCE accepts and passes them on as loss_kwargs.
With this change, it works correctly even with the latest version of TRL.
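
The filtering described above can be sketched roughly as follows. This is a minimal illustration, not the code from this commit: `split_loss_kwargs` and `toy_loss` are hypothetical names, and `toy_loss` only stands in for the real LCE function, whose actual signature is not shown here.

```python
import inspect


def split_loss_kwargs(loss_fn, lm_kwargs):
    """Split lm_kwargs into the part loss_fn accepts and the rest.

    Hypothetical helper: only explicitly named parameters of loss_fn
    count as accepted, so a catch-all **kwargs does not match everything.
    """
    params = inspect.signature(loss_fn).parameters
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)
    }
    loss_kwargs = {k: v for k, v in lm_kwargs.items() if k in accepted}
    remaining = {k: v for k, v in lm_kwargs.items() if k not in accepted}
    return loss_kwargs, remaining


def toy_loss(hidden_states, labels, return_token_accuracy=False):
    """Stand-in for the loss function (assumed signature, for illustration)."""
    return {"loss": 0.0, "with_accuracy": return_token_accuracy}


# Example: only return_token_accuracy is forwarded to the loss.
lm_kwargs = {"return_token_accuracy": True, "output_attentions": False}
loss_kwargs, remaining = split_loss_kwargs(toy_loss, lm_kwargs)
result = toy_loss(hidden_states=None, labels=None, **loss_kwargs)
```

Inspecting the signature instead of hard-coding a list of names means that a new keyword added by the trainer is forwarded only if the loss function actually declares it.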
## Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
---------
Co-authored-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>

1 parent 71ed8ac commit 3b79375
2 files changed in src/liger_kernel/transformers/model (+15, -18 lines)