Fix IndexError in finetune scripts when last logit chunk becomes empty #2141
### Problem

Users encountered an `IndexError` when running `finetune_lora` and other finetune scripts on Gemma models (and potentially other models with certain sequence lengths). The error occurred during training when processing batches with specific sequence lengths.

### Root Cause
When `lm_head_chunk_size=128` is used in the model forward pass, logits are returned as a list of chunks. The finetune code applies a shift operation to the last chunk to align predictions with targets.

The bug: when the last chunk has a sequence length of exactly 1, the slicing operation `[..., :-1, :]` produces a chunk of length 0. This empty chunk then causes `chunked_cross_entropy` to fail, because PyTorch's `split()` function does not accept a split size of 0. The problem occurs whenever the total sequence length is of the form `128*n + 1` (e.g., 1, 129, 257, 385, etc.).
### Solution

Added a simple check after the shift operation to remove empty chunks. This ensures all chunks passed to the loss function have non-zero sequence length.
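The failure mode and the filtering check can be sketched as follows (a minimal, self-contained illustration using random tensors; the variable names are illustrative, not the exact code from the PR):

```python
import torch

# seq_len = 128*2 + 1 is one of the problematic lengths (128*n + 1)
seq_len, chunk_size, vocab = 257, 128, 8
logits = list(torch.randn(1, seq_len, vocab).split(chunk_size, dim=1))
print([c.size(1) for c in logits])  # [128, 128, 1]

# The shift that aligns predictions with targets drops the last position,
# leaving the final chunk empty:
logits[-1] = logits[-1][..., :-1, :]
print([c.size(1) for c in logits])  # [128, 128, 0]

# The fix: drop empty chunks before they reach chunked_cross_entropy
logits = [chunk for chunk in logits if chunk.size(1) > 0]
print([c.size(1) for c in logits])  # [128, 128]
```

Without the final filter, the zero-length chunk reaches the loss computation, whose internal `split()` call cannot handle a split size of 0.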
### Changes

- `litgpt/finetune/lora.py`
- `litgpt/finetune/adapter.py`
- `litgpt/finetune/adapter_v2.py`
- `litgpt/finetune/lora_legacy.py`
- `extensions/xla/finetune/adapter.py`
- Added `test_chunked_cross_entropy_with_empty_last_chunk()` to validate the fix
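As a rough illustration of what such a regression test guards against, here is a hypothetical pure-Python sketch (the helper names are stand-ins, not the actual litgpt test, which exercises real tensors):

```python
def toy_split(n_items, split_size):
    """Stand-in for splitting a dimension of length n_items: like PyTorch's
    split(), a split size of 0 is rejected for a non-empty dimension."""
    if split_size <= 0:
        raise RuntimeError("split_size must be positive")
    return [min(split_size, n_items - i) for i in range(0, n_items, split_size)]

def test_chunked_cross_entropy_with_empty_last_chunk():
    # After the shift, a 257-token batch yields chunk lengths [128, 128, 0].
    shifted_chunks = [128, 128, 0]
    # Without the fix, the empty chunk would be split with size 0 and raise.
    # With the fix, empty chunks are removed before the loss is computed:
    kept = [n for n in shifted_chunks if n > 0]
    assert kept == [128, 128]
    for n in kept:
        assert toy_split(n, n) == [n]  # every surviving chunk splits cleanly

test_chunked_cross_entropy_with_empty_last_chunk()
```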
### Impact

Fixes #[issue_number]
### Warning

Firewall rules blocked me from connecting to one or more addresses. I tried to connect to the following, but was blocked:

- `huggingface.co`
- `python3` (dns block)

If you need me to access, download, or install something from one of these locations, you can either:
### Original prompt

This section details the original issue you should resolve:
<issue_title>finetune_lora on gemma bug</issue_title>
<issue_description>### Bug description

I am trying to use finetune_lora to do PEFT on a Gemma model, and I have tried:

Both encounter an IndexError. I have also tried other model series like QwQ and Llama, and they all look fine.
It seems some people have met a similar bug (but on gemma-7b); I am not sure whether it is the same problem.
What operating system are you using?

Linux

LitGPT Version

litgpt 0.5.7 & litgpt 0.5.8.dev1