-
Notifications
You must be signed in to change notification settings - Fork 10
Open
Description
effective_llm_alignment/src/trainers/smpo_trainer.py
Lines 845 to 849 in a03cad4
| if special_token_id is not None: | |
| special_token_mask = (labels == special_token_id) & loss_mask | |
| per_token_logps[special_token_mask][chosen_count:] = -per_token_logps[ | |
| special_token_mask | |
| ][chosen_count:] |
Вот в этом куске слайсинг, похоже, работает не так как ожидается - вместо special токенов из rejected сэмплов он выбирает special токены из всего батча начиная с chosen_count.
Когда тензор per_token_logps индексируется булевой маской special_token_mask, то в результате получается 1д тензор, и размерность батча теряется.
per_token_logps shape: torch.Size([4, 972]) # (B, seq)
special_token_mask shape: torch.Size([4, 972]) # (B, seq)
per_token_logps[special_token_mask] shape: torch.Size([8]) # 1д тензор всех special токенов из всего батча
per_token_logps[special_token_mask][chosen_count:] shape: torch.Size([6]) # берём из 1д тензора все токены начиная с chosen_count
Чтобы взять special токены из rejected сэмплов, можно написать, например, так:
rejected_logps = per_token_logps[chosen_count:] # (B, seq)
loss_mask = labels != label_pad_token_id
rejected_mask = loss_mask[chosen_count:]
special_mask_rej = (labels[chosen_count:] == special_token_id) & rejected_mask
rejected_logps[special_mask_rej] = -rejected_logps[special_mask_rej]
Эта же проблема, соответственно, и в других местах, где проводится похожий слайсинг:
| # Winsorize extremal values for rejected tokens |
| # Winsorize extremal values for chosen tokens |
| # Clip minimum logprob for rejected tokens |
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels