DPO padding-free #5141
base: main
```diff
@@ -90,6 +90,63 @@ def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list[str]:
     return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names


+def flatten_batch_for_padding_free(
+    input_ids: torch.Tensor, attention_mask: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Flattens a right-padded batch for padding-free attention.
+
+    Args:
+        input_ids (`torch.Tensor`):
+            Tensor of token IDs with shape `(batch_size, sequence_length)`.
+        attention_mask (`torch.Tensor`):
+            Tensor with shape `(batch_size, sequence_length)` where non-padding tokens are marked with `1`.
+
+    Returns:
+        `tuple[torch.Tensor, torch.Tensor]`:
+            A tuple `(flat_input_ids, flat_position_ids)` where:
+
+            - `flat_input_ids` has shape `(1, total_non_padding_tokens)` and contains all non-padding tokens.
+            - `flat_position_ids` has shape `(1, total_non_padding_tokens)` and resets positions at each sequence
+              boundary.
+    """
+    non_padding_mask = attention_mask.bool()
+    position_ids = attention_mask.cumsum(dim=1) - 1
+    position_ids = position_ids.masked_fill(~non_padding_mask, 0)
+    flat_input_ids = input_ids[non_padding_mask].unsqueeze(0)
+    flat_position_ids = position_ids[non_padding_mask].unsqueeze(0)
+    return flat_input_ids, flat_position_ids
+
+
+def restore_padding_from_flattened(
+    tensor: torch.Tensor, flat_position_ids: torch.Tensor, padding_value: int = 0
+) -> torch.Tensor:
+    """
+    Restores per-example padding from a flattened tensor produced in padding-free mode.
+
+    This helper is designed for shifted next-token tensors (for example, logits computed with `[..., :-1, :]`), so
+    the restored sequence length is derived from `flat_position_ids` and corresponds to `sequence_lengths - 1`.
+
+    Args:
+        tensor (`torch.Tensor`):
+            Flattened tensor with shape `(1, total_non_padding_tokens - 1, ...)`.
+        flat_position_ids (`torch.Tensor`):
+            Flattened position IDs returned by [`flatten_batch_for_padding_free`] with shape
+            `(1, total_non_padding_tokens)`.
+        padding_value (`int`, *optional*, defaults to `0`):
+            Value used to pad restored sequences to a common length.
+
+    Returns:
+        `torch.Tensor`:
+            Restored tensor with shape `(batch_size, max(sequence_lengths - 1), ...)`.
+    """
+    keep_mask = flat_position_ids[:, 1:].ne(0).squeeze(0)
+    tensor = tensor.squeeze(0)[keep_mask]
+    starts = flat_position_ids.squeeze(0).eq(0).nonzero(as_tuple=True)[0]
+    ends = torch.cat((starts[1:], starts.new_tensor([flat_position_ids.size(1)])))
+    split_lengths = (ends - starts - 1).clamp_min(0).tolist()
+    return pad(list(tensor.split(split_lengths, dim=0)), padding_value=padding_value)
+
+
 @dataclass
 class DataCollatorForPreference(DataCollatorMixin):
     """
```

Comment on lines +142 to +147 (Member, Author):
📝 Info: `restore_padding_from_flattened` cross-sequence boundary filtering is correct but subtle.
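The flatten/restore round trip above can be exercised on a toy batch. The sketch below mirrors the tensor logic with plain Python lists (the function names and the list-based padding are illustrative, not the PR's torch API): positions reset at each sequence start, and the restore step drops the one cross-boundary prediction per sequence before re-padding.

```python
# Pure-Python sketch of the padding-free round trip. Mirrors the logic of
# flatten_batch_for_padding_free / restore_padding_from_flattened with lists;
# names are hypothetical stand-ins for the torch implementations in the diff.

def flatten_for_padding_free(input_ids, attention_mask):
    """Drop padding tokens and build position ids that reset per sequence."""
    flat_ids, flat_pos = [], []
    for row_ids, row_mask in zip(input_ids, attention_mask):
        pos = 0
        for tok, m in zip(row_ids, row_mask):
            if m:
                flat_ids.append(tok)
                flat_pos.append(pos)
                pos += 1
    return flat_ids, flat_pos

def restore_padding(values, flat_pos, padding_value=0):
    """Split a *shifted* flat stream (one entry per token except the first)
    back into right-padded rows of length (sequence_length - 1)."""
    # Sequence starts are wherever the position resets to 0.
    starts = [i for i, p in enumerate(flat_pos) if p == 0]
    ends = starts[1:] + [len(flat_pos)]
    # Keep entries whose *next* flat position is non-zero, i.e. drop the
    # cross-boundary prediction at the end of each sequence.
    kept = [v for v, p in zip(values, flat_pos[1:]) if p != 0]
    lengths = [max(e - s - 1, 0) for s, e in zip(starts, ends)]
    rows, idx = [], 0
    for n in lengths:
        rows.append(kept[idx:idx + n])
        idx += n
    max_len = max(lengths)
    return [row + [padding_value] * (max_len - len(row)) for row in rows]

ids = [[5, 6, 7, 0], [8, 9, 0, 0]]          # two right-padded sequences
mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
flat_ids, flat_pos = flatten_for_padding_free(ids, mask)
# flat_ids == [5, 6, 7, 8, 9]; flat_pos == [0, 1, 2, 0, 1]
```

Note how a shifted stream of 4 values (one per flat token except the last) restores to rows of lengths 2 and 1, matching `sequence_lengths - 1`.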
```diff
@@ -558,20 +615,29 @@ def __init__(
         # Data collator
         self.padding_free = args.padding_free
-        if self.padding_free:
+        use_flash_attention = model.config._attn_implementation in FLASH_ATTENTION_VARIANTS
+        if self.padding_free and not use_flash_attention:
             logger.warning(
-                "`padding_free=True` is temporarily unavailable after a refactor and is currently disabled. Falling "
-                "back to standard padding (`padding_free=False`). This feature is planned to return in a future "
-                "update; for now, please set `padding_free=False` explicitly."
+                "Padding-free training is enabled, but the attention implementation is not set to a supported flash "
+                "attention variant. Padding-free training flattens batches into a single sequence, and only the "
+                "following implementations are known to reliably support this: "
+                f"{', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. Using other implementations may lead to unexpected "
+                "behavior. To ensure compatibility, set `attn_implementation` in the model configuration to one of "
+                "these supported options or verify that your attention mechanism can handle flattened sequences."
             )
             self.padding_free = False

         dataset_sample = next(iter(train_dataset))
         self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
         if self._is_vision_dataset and not self._is_vlm:
             raise ValueError(
                 "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
                 "model does not seem to be a vision-language model. Please check your model and dataset."
             )
+        if self.padding_free and self._is_vision_dataset:
+            raise ValueError(
+                "Padding-free training is not supported for vision-language preference data. Please set "
+                "`padding_free=False`."
+            )
         if data_collator is None and not self._is_vision_dataset:
             # Get the pad token: if not provided, use the one from the processing class or the eos token
             # if the processing class does not have a pad token.
```
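The guard above falls back to padded batches whenever the attention implementation cannot handle a flattened stream. A minimal sketch of that fallback logic, assuming an illustrative variant set (TRL's actual `FLASH_ATTENTION_VARIANTS` contents may differ):

```python
# Sketch of the padding-free compatibility guard from the diff. The variant
# set below is an assumption for illustration, not TRL's actual constant.
# Padding-free mode flattens the batch into one long sequence, so the kernel
# must derive sequence boundaries from position_ids rather than a 2D mask.
FLASH_ATTENTION_VARIANTS = {"flash_attention_2", "flash_attention_3"}  # assumed values

def resolve_padding_free(requested: bool, attn_implementation: str) -> bool:
    """Return the effective padding_free flag, disabling it when unsupported."""
    if requested and attn_implementation not in FLASH_ATTENTION_VARIANTS:
        # The trainer logs a warning at this point and falls back.
        return False
    return requested
```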
```diff
@@ -988,11 +1054,17 @@ def compute_ref_log_probs(self, inputs):
         shift_labels = input_ids[..., 1:].contiguous()
         shift_completion_mask = completion_mask[..., 1:].contiguous()

-        model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
+        if self.padding_free:
+            flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask)
+            model_kwargs = {"input_ids": flat_input_ids, "position_ids": flat_position_ids, "use_cache": False}
+        else:
+            model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
         for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes", "token_type_ids"):
             if key in inputs:
-                model_kwargs[key] = inputs[key]
+                if self.padding_free and key == "token_type_ids":
+                    model_kwargs[key] = inputs[key][attention_mask.bool()].unsqueeze(0)
+                else:
+                    model_kwargs[key] = inputs[key]

         with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
             if is_peft_model(self.model) and self.ref_model is None:
```
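The `token_type_ids` branch above applies the same boolean mask used for `input_ids`, so the two flat streams stay aligned token for token. A pure-Python sketch of that masking step (illustrative; the real code indexes a torch tensor with `attention_mask.bool()`):

```python
# Flatten a per-token auxiliary field with the same non-padding mask applied
# to input_ids, keeping the flattened streams aligned position by position.
def mask_flatten(values, attention_mask):
    return [v for row_v, row_m in zip(values, attention_mask)
            for v, m in zip(row_v, row_m) if m]

token_type_ids = [[0, 0, 1, 1], [0, 1, 0, 0]]
attention_mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
flat = mask_flatten(token_type_ids, attention_mask)  # [0, 0, 1, 0, 1]
```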
```diff
@@ -1003,6 +1075,8 @@ def compute_ref_log_probs(self, inputs):
             ref_outputs = self.ref_model(**model_kwargs)

         ref_shift_logits = ref_outputs.logits[..., :-1, :].contiguous()
+        if self.padding_free:
+            ref_shift_logits = restore_padding_from_flattened(ref_shift_logits, flat_position_ids)
         ref_per_token_logps = selective_log_softmax(ref_shift_logits, shift_labels)
         ref_per_token_logps[shift_completion_mask == 0] = 0.0
```
```diff
@@ -1037,9 +1111,17 @@ def _compute_loss_liger(self, model, inputs, return_outputs):
         completion_mask = inputs["completion_mask"]
         input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask)

+        if self.padding_free:
+            flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask)
+            decoder_kwargs = {"input_ids": flat_input_ids, "position_ids": flat_position_ids, "use_cache": False}
+        else:
+            decoder_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
+
         decoder = model.get_decoder()
-        outputs = decoder(input_ids, attention_mask=attention_mask, use_cache=False)
+        outputs = decoder(**decoder_kwargs)
         hidden_states = outputs.last_hidden_state[:, :-1].contiguous()
+        if self.padding_free:
+            hidden_states = restore_padding_from_flattened(hidden_states, flat_position_ids)
         lm_head = model.get_output_embeddings()
         weight = lm_head.weight
         bias = lm_head.bias
```
```diff
@@ -1049,9 +1131,11 @@ def _compute_loss_liger(self, model, inputs, return_outputs):
         else:
             with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
                 ref_decoder = self.ref_model.get_decoder()
-                ref_outputs = ref_decoder(input_ids, attention_mask=attention_mask, use_cache=False)
+                ref_outputs = ref_decoder(**decoder_kwargs)
                 ref_lm_head = self.ref_model.get_output_embeddings()
                 ref_hidden_states = ref_outputs.last_hidden_state[:, :-1].contiguous()
+                if self.padding_free:
+                    ref_hidden_states = restore_padding_from_flattened(ref_hidden_states, flat_position_ids)
                 ref_weight = ref_lm_head.weight
                 ref_bias = ref_lm_head.bias
```
```diff
@@ -1110,13 +1194,22 @@ def _compute_loss(self, model, inputs, return_outputs):
         completion_mask = inputs["completion_mask"]
         input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask)

-        model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
+        if self.padding_free:
+            flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask)
+            model_kwargs = {"input_ids": flat_input_ids, "position_ids": flat_position_ids, "use_cache": False}
+        else:
+            model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
```
Comment on lines +1197 to +1201 (Member, Author):
📝 Info: No `attention_mask` is passed to the model in padding-free mode; correctness relies on the model inferring sequence boundaries from `position_ids`.
```diff
         for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes", "token_type_ids"):
             if key in inputs:
-                model_kwargs[key] = inputs[key]
+                if self.padding_free and key == "token_type_ids":
+                    model_kwargs[key] = inputs[key][attention_mask.bool()].unsqueeze(0)
+                else:
+                    model_kwargs[key] = inputs[key]

         outputs = model(**model_kwargs)
         shift_logits = outputs.logits[..., :-1, :].contiguous()
+        if self.padding_free:
+            shift_logits = restore_padding_from_flattened(shift_logits, flat_position_ids)
         shift_labels = input_ids[..., 1:].contiguous()
         shift_completion_mask = completion_mask[..., 1:].contiguous()
         per_token_logps = selective_log_softmax(shift_logits, shift_labels)
```
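The alignment being relied on here is that restored logits and `shift_labels = input_ids[..., 1:]` agree in width. A pure-Python sanity check of the bookkeeping (illustrative only; the real code operates on torch logit tensors): each sequence of length `n` contributes `n - 1` within-sequence predictions once the cross-boundary logits are dropped.

```python
# Shape bookkeeping for the padding-free shift: each row of length n yields
# n - 1 next-token predictions after dropping cross-boundary logits, which
# matches input_ids[..., 1:] row by row after restoring padding.
seq_lengths = [3, 2]                                    # non-padding tokens per example
flat_pos = [p for n in seq_lengths for p in range(n)]   # [0, 1, 2, 0, 1]
n_shifted = sum(seq_lengths) - 1                        # shifted logits: total - 1 entries
kept = sum(1 for p in flat_pos[1:] if p != 0)           # drop cross-boundary predictions
assert kept == sum(n - 1 for n in seq_lengths)          # within-sequence predictions only
restored_cols = max(seq_lengths) - 1                    # restored width per example
```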
Comment on lines +1211 to 1215 (Member, Author):
🔴 Dimension mismatch between restored logits and `shift_labels`. Consider a batch with 2 sequences of lengths 3 and 2: the restored logits have width `max(sequence_lengths) - 1`, while `shift_labels = input_ids[..., 1:]` keeps the padded width. This affects three methods.
Comment on lines +1211 to 1215 (Member, Author):
📝 Info: The padding value of 0 in restored logits is safe due to downstream masking (refers to lines 1211-1216).
```diff
@@ -1154,6 +1247,8 @@ def _compute_loss(self, model, inputs, return_outputs):
             ref_outputs = self.ref_model(**model_kwargs)

         ref_shift_logits = ref_outputs.logits[..., :-1, :].contiguous()
+        if self.padding_free:
+            ref_shift_logits = restore_padding_from_flattened(ref_shift_logits, flat_position_ids)
         ref_per_token_logps = selective_log_softmax(ref_shift_logits, shift_labels)
         ref_per_token_logps[shift_completion_mask == 0] = 0.0  # mask out non-completion tokens
         if self.ld_alpha is None:
```
Comment (Member, Author):
📝 Info: The test for `restore_padding_from_flattened` only covers multi-token sequences. The test at `tests/test_dpo_trainer.py:164-177` covers sequences of lengths 4, 2, and 3, but does not cover edge cases such as a sequence with exactly 1 token (which would produce 0 shifted tokens and an empty split). Tracing the logic shows that `clamp_min(0)` at `trl/trainer/dpo_trainer.py:146` handles this gracefully (producing an empty tensor that gets padded), and such short sequences are unlikely in DPO training, but an explicit test would increase confidence. Similarly, there is no test for the case where all sequences have the same length.