3 changes: 3 additions & 0 deletions extensions/xla/finetune/adapter.py
@@ -176,6 +176,9 @@ def train(
xm.mark_step()
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
# Remove empty chunks (can happen when last chunk has size 1)
if logits[-1].size(1) == 0:
logits = logits[:-1]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / gradient_accumulation_iters)
xm.mark_step()
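The same three-line guard is applied in each of the finetune scripts below. For context, a minimal standalone sketch of the failure mode, with toy shapes chosen here for illustration and assuming chunked_cross_entropy is importable from litgpt.utils (as the tests do): with lm_head_chunk_size=128 and a sequence length of 129, split produces chunks of sizes [128, 1]; after the shift the last chunk has length 0, and dropping it keeps chunked_cross_entropy from failing on the empty tensor (the IndexError described in the new test).

import torch
from litgpt.utils import chunked_cross_entropy

B, T, V = 2, 129, 100                  # toy batch size, sequence length, vocab size
lm_head_chunk_size = 128
# split along the time dimension -> chunk sizes [128, 1]
logits = list(torch.randn(B, T, V).split(lm_head_chunk_size, dim=1))
targets = torch.randint(0, V, (B, T))

# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]   # last chunk is now (B, 0, V)

# the guard added in this PR: drop the now-empty chunk before computing the loss
if logits[-1].size(1) == 0:
    logits = logits[:-1]

loss = chunked_cross_entropy(logits, targets[..., 1:])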
3 changes: 3 additions & 0 deletions litgpt/finetune/adapter.py
@@ -303,6 +303,9 @@ def fit(
logits = model(input_ids, lm_head_chunk_size=128)
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
# Remove empty chunks (can happen when last chunk has size 1)
if logits[-1].size(1) == 0:
logits = logits[:-1]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))

3 changes: 3 additions & 0 deletions litgpt/finetune/adapter_v2.py
@@ -330,6 +330,9 @@ def fit(
logits = model(input_ids, lm_head_chunk_size=128)
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
# Remove empty chunks (can happen when last chunk has size 1)
if logits[-1].size(1) == 0:
logits = logits[:-1]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))

3 changes: 3 additions & 0 deletions litgpt/finetune/lora.py
@@ -352,6 +352,9 @@ def fit(
logits = model(input_ids, lm_head_chunk_size=128)
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
# Remove empty chunks (can happen when last chunk has size 1)
if logits[-1].size(1) == 0:
logits = logits[:-1]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))

3 changes: 3 additions & 0 deletions litgpt/finetune/lora_legacy.py
@@ -337,6 +337,9 @@ def fit(
logits = model(input_ids, lm_head_chunk_size=128)
# shift the targets such that output n predicts token n+1
logits[-1] = logits[-1][..., :-1, :]
# Remove empty chunks (can happen when last chunk has size 1)
if logits[-1].size(1) == 0:
logits = logits[:-1]
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))

71 changes: 71 additions & 0 deletions tests/test_utils.py
@@ -156,6 +156,77 @@ def test_chunked_cross_entropy(ignore_index, B):
torch.testing.assert_close(chunked_loss, baseline_loss)


def test_chunked_cross_entropy_with_empty_last_chunk():
"""Test that chunked_cross_entropy works when last chunk becomes empty after shift.

This tests the fix for the issue where finetune_lora on Gemma models would fail with
IndexError when the last logit chunk has size 1, which becomes size 0 after applying
the shift operation logits[-1] = logits[-1][..., :-1, :].

The fix removes empty chunks before passing to chunked_cross_entropy.
"""
B, V = 2, 100
lm_head_chunk_size = 128

# Test case 1: Sequence length 129 (results in chunks [128, 1])
# After shift, last chunk becomes size 0 and should be removed
T = 129
regular_logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

# Simulate what happens in finetune: chunk the logits
chunked_logits = list(regular_logits.split(lm_head_chunk_size, dim=1))
assert len(chunked_logits) == 2
assert chunked_logits[-1].size(1) == 1

# Apply the shift operation (this is what finetune does)
chunked_logits[-1] = chunked_logits[-1][..., :-1, :]
assert chunked_logits[-1].size(1) == 0

# Apply the fix: remove empty chunks
if chunked_logits[-1].size(1) == 0:
chunked_logits = chunked_logits[:-1]

assert len(chunked_logits) == 1

# Now compute loss - should work without error
shifted_targets = targets[..., 1:]
loss = chunked_cross_entropy(chunked_logits, shifted_targets, chunk_size=0)
assert loss.numel() == 1

# Compare with baseline (non-chunked)
baseline_logits = regular_logits[..., :-1, :]
baseline_loss = F.cross_entropy(
baseline_logits.reshape(-1, baseline_logits.size(-1)),
shifted_targets.reshape(-1),
)
torch.testing.assert_close(loss, baseline_loss)

# Test case 2: Sequence length 257 (results in chunks [128, 128, 1])
T = 257
regular_logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

chunked_logits = list(regular_logits.split(lm_head_chunk_size, dim=1))
assert len(chunked_logits) == 3

chunked_logits[-1] = chunked_logits[-1][..., :-1, :]
if chunked_logits[-1].size(1) == 0:
chunked_logits = chunked_logits[:-1]

assert len(chunked_logits) == 2

shifted_targets = targets[..., 1:]
loss = chunked_cross_entropy(chunked_logits, shifted_targets, chunk_size=0)

baseline_logits = regular_logits[..., :-1, :]
baseline_loss = F.cross_entropy(
baseline_logits.reshape(-1, baseline_logits.size(-1)),
shifted_targets.reshape(-1),
)
torch.testing.assert_close(loss, baseline_loss)


def test_num_parameters():
model = torch.nn.Linear(2, 2)
assert num_parameters(model) == 6
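A small usage note: to exercise only the new regression test locally, something like pytest tests/test_utils.py -k empty_last_chunk should select it, since pytest's -k option matches against substrings of test names.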