diff --git a/extensions/xla/finetune/adapter.py b/extensions/xla/finetune/adapter.py
index 051baea75f..176debea89 100644
--- a/extensions/xla/finetune/adapter.py
+++ b/extensions/xla/finetune/adapter.py
@@ -176,6 +176,9 @@ def train(
             xm.mark_step()
             # shift the targets such that output n predicts token n+1
             logits[-1] = logits[-1][..., :-1, :]
+            # Remove empty chunks (can happen when last chunk has size 1)
+            if logits[-1].size(1) == 0:
+                logits = logits[:-1]
             loss = chunked_cross_entropy(logits, targets[..., 1:])
             fabric.backward(loss / gradient_accumulation_iters)
             xm.mark_step()
diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py
index 813f2c8226..2de39cc966 100644
--- a/litgpt/finetune/adapter.py
+++ b/litgpt/finetune/adapter.py
@@ -303,6 +303,9 @@ def fit(
             logits = model(input_ids, lm_head_chunk_size=128)
             # shift the targets such that output n predicts token n+1
             logits[-1] = logits[-1][..., :-1, :]
+            # Remove empty chunks (can happen when last chunk has size 1)
+            if logits[-1].size(1) == 0:
+                logits = logits[:-1]
             loss = chunked_cross_entropy(logits, targets[..., 1:])
             fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
 
diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py
index b80f5688a5..31785f419a 100644
--- a/litgpt/finetune/adapter_v2.py
+++ b/litgpt/finetune/adapter_v2.py
@@ -330,6 +330,9 @@ def fit(
             logits = model(input_ids, lm_head_chunk_size=128)
             # shift the targets such that output n predicts token n+1
             logits[-1] = logits[-1][..., :-1, :]
+            # Remove empty chunks (can happen when last chunk has size 1)
+            if logits[-1].size(1) == 0:
+                logits = logits[:-1]
             loss = chunked_cross_entropy(logits, targets[..., 1:])
             fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
 
diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index 1ef450f620..cd3d57fa03 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -352,6 +352,9 @@ def fit(
             logits = model(input_ids, lm_head_chunk_size=128)
             # shift the targets such that output n predicts token n+1
             logits[-1] = logits[-1][..., :-1, :]
+            # Remove empty chunks (can happen when last chunk has size 1)
+            if logits[-1].size(1) == 0:
+                logits = logits[:-1]
             loss = chunked_cross_entropy(logits, targets[..., 1:])
             fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
 
diff --git a/litgpt/finetune/lora_legacy.py b/litgpt/finetune/lora_legacy.py
index fe05896df6..e5df0e124a 100644
--- a/litgpt/finetune/lora_legacy.py
+++ b/litgpt/finetune/lora_legacy.py
@@ -337,6 +337,9 @@ def fit(
             logits = model(input_ids, lm_head_chunk_size=128)
             # shift the targets such that output n predicts token n+1
             logits[-1] = logits[-1][..., :-1, :]
+            # Remove empty chunks (can happen when last chunk has size 1)
+            if logits[-1].size(1) == 0:
+                logits = logits[:-1]
             loss = chunked_cross_entropy(logits, targets[..., 1:])
             fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4f41366ca8..c1ac146588 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -156,6 +156,77 @@ def test_chunked_cross_entropy(ignore_index, B):
     torch.testing.assert_close(chunked_loss, baseline_loss)
 
 
+def test_chunked_cross_entropy_with_empty_last_chunk():
+    """Test that chunked_cross_entropy works when last chunk becomes empty after shift.
+
+    This tests the fix for the issue where finetune_lora on Gemma models would fail with
+    IndexError when the last logit chunk has size 1, which becomes size 0 after applying
+    the shift operation logits[-1] = logits[-1][..., :-1, :].
+
+    The fix removes empty chunks before passing to chunked_cross_entropy.
+    """
+    B, V = 2, 100
+    lm_head_chunk_size = 128
+
+    # Test case 1: Sequence length 129 (results in chunks [128, 1])
+    # After shift, last chunk becomes size 0 and should be removed
+    T = 129
+    regular_logits = torch.randn(B, T, V)
+    targets = torch.randint(0, V, (B, T))
+
+    # Simulate what happens in finetune: chunk the logits
+    chunked_logits = list(regular_logits.split(lm_head_chunk_size, dim=1))
+    assert len(chunked_logits) == 2
+    assert chunked_logits[-1].size(1) == 1
+
+    # Apply the shift operation (this is what finetune does)
+    chunked_logits[-1] = chunked_logits[-1][..., :-1, :]
+    assert chunked_logits[-1].size(1) == 0
+
+    # Apply the fix: remove empty chunks
+    if chunked_logits[-1].size(1) == 0:
+        chunked_logits = chunked_logits[:-1]
+
+    assert len(chunked_logits) == 1
+
+    # Now compute loss - should work without error
+    shifted_targets = targets[..., 1:]
+    loss = chunked_cross_entropy(chunked_logits, shifted_targets, chunk_size=0)
+    assert loss.numel() == 1
+
+    # Compare with baseline (non-chunked)
+    baseline_logits = regular_logits[..., :-1, :]
+    baseline_loss = F.cross_entropy(
+        baseline_logits.reshape(-1, baseline_logits.size(-1)),
+        shifted_targets.reshape(-1),
+    )
+    torch.testing.assert_close(loss, baseline_loss)
+
+    # Test case 2: Sequence length 257 (results in chunks [128, 128, 1])
+    T = 257
+    regular_logits = torch.randn(B, T, V)
+    targets = torch.randint(0, V, (B, T))
+
+    chunked_logits = list(regular_logits.split(lm_head_chunk_size, dim=1))
+    assert len(chunked_logits) == 3
+
+    chunked_logits[-1] = chunked_logits[-1][..., :-1, :]
+    if chunked_logits[-1].size(1) == 0:
+        chunked_logits = chunked_logits[:-1]
+
+    assert len(chunked_logits) == 2
+
+    shifted_targets = targets[..., 1:]
+    loss = chunked_cross_entropy(chunked_logits, shifted_targets, chunk_size=0)
+
+    baseline_logits = regular_logits[..., :-1, :]
+    baseline_loss = F.cross_entropy(
+        baseline_logits.reshape(-1, baseline_logits.size(-1)),
+        shifted_targets.reshape(-1),
+    )
+    torch.testing.assert_close(loss, baseline_loss)
+
+
 def test_num_parameters():
     model = torch.nn.Linear(2, 2)
     assert num_parameters(model) == 6