Description
When preparing data for training, the code shifts target (the teacher logits) and input_ids one position to the left using padding(tensor, left=False):
EAGLE/eagle/traineagle3/cnets.py
Lines 722 to 730 in 2866b68
target = outs.logits
target = padding(target, left=False)
input_ids = padding(input_ids, left=False)
if target is not None:
    target = target.to(device)
loss_mask = loss_mask[..., None]
loss_mask = loss_mask.to(device)
This left-shift aligns each position with its next token for next-token prediction. However, after shifting, the last position holds a zero-padding value, not a real token. If loss_mask still has a 1 at that position, the model computes loss against a meaningless padded target, which can degrade training quality.
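To make the failure mode concrete, here is a minimal pure-Python sketch. It assumes padding(tensor, left=False) behaves like "drop the first element, append a pad value on the right" (a hypothetical simplification of the repo's helper, for illustration only; the real code operates on batched tensors):

```python
def shift_left(seq, pad=0):
    # Hypothetical stand-in for padding(tensor, left=False):
    # drop the first element, append a pad value on the right.
    return seq[1:] + [pad]

tokens = [11, 12, 13, 14]       # toy input_ids for one sequence
loss_mask = [1, 1, 1, 1]        # every position currently supervised

targets = shift_left(tokens)    # [12, 13, 14, 0] -- last slot is padding
# With the mask unchanged, the pad value 0 is treated as a real target:
supervised = [t for t, m in zip(targets, loss_mask) if m]
# supervised includes the padded 0, so loss is computed on garbage
```

The last supervised target is the pad value, which is exactly the position the issue proposes to mask out.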
Suggested fix
Exclude the last position from loss computation in the dataprepare function:
target = outs.logits
target = padding(target, left=False)
input_ids = padding(input_ids, left=False)
if target is not None:
    target = target.to(device)
loss_mask[..., -1] = 0  # Exclude last position (now contains padding)
loss_mask = loss_mask[..., None]
loss_mask = loss_mask.to(device)
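To sanity-check the effect of the one-line fix, the same toy sketch with the last mask entry zeroed (plain Python lists standing in for the real tensors; loss_mask[-1] = 0 mirrors loss_mask[..., -1] = 0):

```python
targets = [12, 13, 14, 0]       # left-shifted targets; last slot is padding
loss_mask = [1, 1, 1, 1]
loss_mask[-1] = 0               # mirrors loss_mask[..., -1] = 0

# Only real next-token targets remain supervised:
supervised = [t for t, m in zip(targets, loss_mask) if m]  # [12, 13, 14]
```

The padded position is dropped from the loss, while all genuine next-token pairs are still trained on.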