Commit 3d075fd

define shift_labels in gemma (#961)
## Summary

Minor fix: define the previously undefined variable `shift_labels` in Gemma3/Paligemma.

<!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? --->

## Testing Done

- Hardware Type: 2x NVIDIA RTX 5880
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 62c0544 commit 3d075fd

2 files changed, +2 -0 lines changed


src/liger_kernel/transformers/model/gemma3.py

Lines changed: 1 addition & 0 deletions

@@ -235,6 +235,7 @@ def multimodal_forward(
         **lm_kwargs,
     )
 
+    shift_labels = lm_kwargs.pop("shift_labels", None)
     hidden_states = outputs[0]
 
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep

src/liger_kernel/transformers/model/paligemma.py

Lines changed: 1 addition & 0 deletions

@@ -330,6 +330,7 @@ def lce_forward(
         **lm_kwargs,
     )
 
+    shift_labels = lm_kwargs.pop("shift_labels", None)
     hidden_states = outputs[0]
 
     loss = None
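
For readers unfamiliar with the failure mode, the following is a minimal sketch of why the added line matters. It is not the Liger Kernel code: `forward_before` and `forward_after` are hypothetical stand-ins, and the assumption is that the real functions read `shift_labels` later in their bodies without ever binding it. Popping the key from `**lm_kwargs` with a default of `None` binds the name and also removes it from the remaining keyword arguments.

```python
def forward_before(labels=None, **lm_kwargs):
    """Hypothetical stand-in: reads `shift_labels` without ever defining it."""
    # When `labels` is None, Python evaluates `shift_labels` and raises
    # NameError, because the name is bound neither locally nor globally.
    if labels is not None or shift_labels is not None:  # noqa: F821
        return "loss"
    return None


def forward_after(labels=None, **lm_kwargs):
    """Hypothetical stand-in: same logic, with the one-line fix applied."""
    # Bind the name up front; default to None when the caller omits it.
    # pop() also removes the key, so it is not left behind in lm_kwargs.
    shift_labels = lm_kwargs.pop("shift_labels", None)
    if labels is not None or shift_labels is not None:
        return "loss", shift_labels, lm_kwargs
    return None, shift_labels, lm_kwargs


if __name__ == "__main__":
    try:
        forward_before()
    except NameError as exc:
        print(f"before the fix: {exc}")

    print("after the fix:", forward_after(shift_labels=[1, 2], use_cache=False))
```

Note that in this sketch the unfixed version only fails when `labels` is `None`, since `or` short-circuits; whether the original bug was similarly path-dependent is not stated in the commit.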
