Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
from transformers.utils import is_peft_available

from trl import DPOConfig, DPOTrainer
from trl.trainer.dpo_trainer import DataCollatorForPreference
from trl.trainer.dpo_trainer import (
DataCollatorForPreference,
flatten_batch_for_padding_free,
restore_padding_from_flattened,
)

from .testing_utils import (
TrlTestCase,
Expand Down Expand Up @@ -132,6 +136,47 @@ def test_with_pad_to_multiple_of(self):
torch.testing.assert_close(result["input_ids"], expected_input_ids)


class TestPaddingFreeHelpers(TrlTestCase):
    """Tests for the padding-free flatten/restore helpers used by the DPO trainer."""

    def test_flatten_batch_for_padding_free(self):
        # Three right-padded sequences of lengths 4, 2, and 3.
        input_ids = torch.tensor(
            [
                [10, 11, 12, 13, 0],
                [20, 21, 0, 0, 0],
                [30, 31, 32, 0, 0],
            ]
        )
        attention_mask = torch.tensor(
            [
                [1, 1, 1, 1, 0],
                [1, 1, 0, 0, 0],
                [1, 1, 1, 0, 0],
            ]
        )

        flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask)

        # Padding tokens are dropped and position IDs restart at each sequence boundary.
        expected_flat_input_ids = torch.tensor([[10, 11, 12, 13, 20, 21, 30, 31, 32]])
        expected_flat_position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2]])

        torch.testing.assert_close(flat_input_ids, expected_flat_input_ids)
        torch.testing.assert_close(flat_position_ids, expected_flat_position_ids)

    def test_restore_padding_from_flattened(self):
        flat_position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2]])
        # 8 shifted elements (9 tokens minus the final shift) with 2 features each.
        flattened = torch.arange(16, dtype=torch.float32).view(1, 8, 2)

        restored = restore_padding_from_flattened(flattened, flat_position_ids, padding_value=-100)

        # Cross-boundary predictions (flattened indices 3 and 5) are discarded; rows are
        # re-padded to the longest shifted length (4 - 1 = 3).
        expected = torch.tensor(
            [
                [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]],
                [[8.0, 9.0], [-100.0, -100.0], [-100.0, -100.0]],
                [[12.0, 13.0], [14.0, 15.0], [-100.0, -100.0]],
            ]
        )
        torch.testing.assert_close(restored, expected)

    def test_flatten_and_restore_with_single_token_sequence(self):
        # Edge case: a 1-token sequence yields zero shifted tokens (an empty split) on restore.
        input_ids = torch.tensor([[10, 11], [20, 0]])
        attention_mask = torch.tensor([[1, 1], [1, 0]])

        flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask)

        torch.testing.assert_close(flat_input_ids, torch.tensor([[10, 11, 20]]))
        torch.testing.assert_close(flat_position_ids, torch.tensor([[0, 1, 0]]))

        # The shifted tensor has 3 - 1 = 2 elements; the cross-boundary prediction (index 1)
        # is discarded and the single-token sequence becomes a fully padded row.
        flattened = torch.tensor([[[1.0], [2.0]]])
        restored = restore_padding_from_flattened(flattened, flat_position_ids, padding_value=-100)
        expected = torch.tensor([[[1.0]], [[-100.0]]])
        torch.testing.assert_close(restored, expected)
Comment on lines +164 to +177
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 Info: Test for restore_padding_from_flattened only covers multi-token sequences

The test at tests/test_dpo_trainer.py:164-177 covers sequences of lengths 4, 2, and 3, but does not test edge cases such as a sequence with exactly 1 token (which would produce 0 shifted tokens and an empty split). While tracing the logic shows clamp_min(0) at trl/trainer/dpo_trainer.py:146 handles this gracefully (producing an empty tensor that gets padded), and such short sequences are unlikely in DPO training, an explicit test would increase confidence. Similarly, there's no test for the case where all sequences have the same length.

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.



class TestDPOTrainer(TrlTestCase):
@pytest.mark.parametrize(
"model_id",
Expand Down
119 changes: 107 additions & 12 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,63 @@ def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list[str]:
return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names


def flatten_batch_for_padding_free(
    input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Flattens a right-padded batch for padding-free attention.

    Args:
        input_ids (`torch.Tensor`):
            Tensor of token IDs with shape `(batch_size, sequence_length)`.
        attention_mask (`torch.Tensor`):
            Tensor with shape `(batch_size, sequence_length)` where non-padding tokens are marked with `1`.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`:
            A tuple `(flat_input_ids, flat_position_ids)` where:
            - `flat_input_ids` has shape `(1, total_non_padding_tokens)` and contains all non-padding tokens.
            - `flat_position_ids` has shape `(1, total_non_padding_tokens)` and resets positions at each sequence
              boundary.
    """
    keep = attention_mask.to(torch.bool)
    # Running count of non-padding tokens gives within-sequence positions; padding slots are
    # zeroed only for tidiness — they are dropped by the masked selection below.
    positions = torch.cumsum(attention_mask, dim=1) - 1
    positions = torch.where(keep, positions, torch.zeros_like(positions))
    flat_input_ids = torch.masked_select(input_ids, keep).unsqueeze(0)
    flat_position_ids = torch.masked_select(positions, keep).unsqueeze(0)
    return flat_input_ids, flat_position_ids


def restore_padding_from_flattened(
    tensor: torch.Tensor, flat_position_ids: torch.Tensor, padding_value: int = 0
) -> torch.Tensor:
    """
    Restores per-example padding from a flattened tensor produced in padding-free mode.

    This helper is designed for shifted next-token tensors (for example, logits computed with `[..., :-1, :]`), so the
    restored sequence length is derived from `flat_position_ids` and corresponds to `sequence_lengths - 1`.

    Args:
        tensor (`torch.Tensor`):
            Flattened tensor with shape `(1, total_non_padding_tokens - 1, ...)`.
        flat_position_ids (`torch.Tensor`):
            Flattened position IDs returned by [`flatten_batch_for_padding_free`] with shape `(1,
            total_non_padding_tokens)`.
        padding_value (`int`, *optional*, defaults to `0`):
            Value used to pad restored sequences to a common length.

    Returns:
        `torch.Tensor`:
            Restored tensor with shape `(batch_size, max(sequence_lengths - 1), ...)`.
    """
    # `tensor` is the shifted ([..., :-1, ...]) view of the flattened forward pass, so element i is the
    # prediction made *from* flattened token i. Shifting `flat_position_ids` by one ([:, 1:]) aligns each
    # element with the token it would predict; a shifted position of 0 means that predicted token starts a
    # new packed sequence, i.e. the element is a cross-sequence boundary prediction and must be discarded.
    # Example: positions [[0,1,2,3,0,1,0,1,2]] shift to [1,2,3,0,1,0,1,2]; .ne(0) drops indices 3 and 5.
    keep_mask = flat_position_ids[:, 1:].ne(0).squeeze(0)
    tensor = tensor.squeeze(0)[keep_mask]
    # Every position-id reset to 0 marks the start of one packed sequence.
    starts = flat_position_ids.squeeze(0).eq(0).nonzero(as_tuple=True)[0]
    ends = torch.cat((starts[1:], starts.new_tensor([flat_position_ids.size(1)])))
    # Shifted sequences are one element shorter than the originals (length - 1); clamp_min(0) guards the
    # degenerate single-token sequence, which contributes an empty split that is then fully padded.
    split_lengths = (ends - starts - 1).clamp_min(0).tolist()
    return pad(list(tensor.split(split_lengths, dim=0)), padding_value=padding_value)
Comment on lines +142 to +147
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 Info: restore_padding_from_flattened cross-sequence boundary filtering is correct but subtle

The keep_mask at trl/trainer/dpo_trainer.py:142 filters elements from the shifted ([..., :-1, :]) tensor using flat_position_ids[:, 1:].ne(0). This removes exactly the cross-sequence boundary predictions (where a token from sequence N predicts the first token of sequence N+1). For example, with flat_position_ids = [[0,1,2,3,0,1,0,1,2]], the shifted position IDs are [1,2,3,0,1,0,1,2], and .ne(0) correctly identifies indices 3 and 5 as boundaries to discard. The split_lengths at line 146 use (ends - starts - 1).clamp_min(0) which correctly yields per-sequence shifted lengths (original_length - 1). This logic is correct but non-obvious — a comment explaining why keep_mask filters cross-boundary predictions would improve maintainability.

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.



@dataclass
class DataCollatorForPreference(DataCollatorMixin):
"""
Expand Down Expand Up @@ -558,20 +615,29 @@ def __init__(

# Data collator
self.padding_free = args.padding_free
if self.padding_free:
use_flash_attention = model.config._attn_implementation in FLASH_ATTENTION_VARIANTS
if self.padding_free and not use_flash_attention:
logger.warning(
"`padding_free=True` is temporarily unavailable after a refactor and is currently disabled. Falling "
"back to standard padding (`padding_free=False`). This feature is planned to return in a future "
"update; for now, please set `padding_free=False` explicitly."
"Padding-free training is enabled, but the attention implementation is not set to a supported flash "
"attention variant. Padding-free training flattens batches into a single sequence, and only the "
"following implementations are known to reliably support this: "
f"{', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. Using other implementations may lead to unexpected "
"behavior. To ensure compatibility, set `attn_implementation` in the model configuration to one of "
"these supported options or verify that your attention mechanism can handle flattened sequences."
)
self.padding_free = False

dataset_sample = next(iter(train_dataset))
self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
if self._is_vision_dataset and not self._is_vlm:
raise ValueError(
"The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
"model does not seem to be a vision-language model. Please check your model and dataset."
)
if self.padding_free and self._is_vision_dataset:
raise ValueError(
"Padding-free training is not supported for vision-language preference data. Please set "
"`padding_free=False`."
)
if data_collator is None and not self._is_vision_dataset:
# Get the pad token: if not provided, use the one from the processing class or the eos token
# if the processing class does not have a pad token.
Expand Down Expand Up @@ -988,11 +1054,17 @@ def compute_ref_log_probs(self, inputs):

shift_labels = input_ids[..., 1:].contiguous()
shift_completion_mask = completion_mask[..., 1:].contiguous()

model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
if self.padding_free:
flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask)
model_kwargs = {"input_ids": flat_input_ids, "position_ids": flat_position_ids, "use_cache": False}
else:
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes", "token_type_ids"):
if key in inputs:
model_kwargs[key] = inputs[key]
if self.padding_free and key == "token_type_ids":
model_kwargs[key] = inputs[key][attention_mask.bool()].unsqueeze(0)
else:
model_kwargs[key] = inputs[key]

with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
if is_peft_model(self.model) and self.ref_model is None:
Expand All @@ -1003,6 +1075,8 @@ def compute_ref_log_probs(self, inputs):
ref_outputs = self.ref_model(**model_kwargs)

ref_shift_logits = ref_outputs.logits[..., :-1, :].contiguous()
if self.padding_free:
ref_shift_logits = restore_padding_from_flattened(ref_shift_logits, flat_position_ids)
ref_per_token_logps = selective_log_softmax(ref_shift_logits, shift_labels)
ref_per_token_logps[shift_completion_mask == 0] = 0.0

Expand Down Expand Up @@ -1037,9 +1111,17 @@ def _compute_loss_liger(self, model, inputs, return_outputs):
completion_mask = inputs["completion_mask"]
input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask)

if self.padding_free:
flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask)
decoder_kwargs = {"input_ids": flat_input_ids, "position_ids": flat_position_ids, "use_cache": False}
else:
decoder_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}

decoder = model.get_decoder()
outputs = decoder(input_ids, attention_mask=attention_mask, use_cache=False)
outputs = decoder(**decoder_kwargs)
hidden_states = outputs.last_hidden_state[:, :-1].contiguous()
if self.padding_free:
hidden_states = restore_padding_from_flattened(hidden_states, flat_position_ids)
lm_head = model.get_output_embeddings()
weight = lm_head.weight
bias = lm_head.bias
Expand All @@ -1049,9 +1131,11 @@ def _compute_loss_liger(self, model, inputs, return_outputs):
else:
with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
ref_decoder = self.ref_model.get_decoder()
ref_outputs = ref_decoder(input_ids, attention_mask=attention_mask, use_cache=False)
ref_outputs = ref_decoder(**decoder_kwargs)
ref_lm_head = self.ref_model.get_output_embeddings()
ref_hidden_states = ref_outputs.last_hidden_state[:, :-1].contiguous()
if self.padding_free:
ref_hidden_states = restore_padding_from_flattened(ref_hidden_states, flat_position_ids)
ref_weight = ref_lm_head.weight
ref_bias = ref_lm_head.bias

Expand Down Expand Up @@ -1110,13 +1194,22 @@ def _compute_loss(self, model, inputs, return_outputs):
completion_mask = inputs["completion_mask"]
input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask)

model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
if self.padding_free:
flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask)
model_kwargs = {"input_ids": flat_input_ids, "position_ids": flat_position_ids, "use_cache": False}
else:
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
Comment on lines +1197 to +1201
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 Info: No attention_mask passed to model in padding_free mode relies on model auto-inferring it from position_ids

When padding_free=True, the model_kwargs at trl/trainer/dpo_trainer.py:1199 include position_ids but omit attention_mask. This relies on the model's flash attention implementation inferring sequence boundaries from position_ids (detecting resets to 0 as new sequence starts). This is the standard convention for HuggingFace flash attention implementations and is correct for the supported flash attention variants listed in FLASH_ATTENTION_VARIANTS. However, if a user bypasses the warning and uses a non-flash attention implementation, the lack of attention_mask would cause incorrect self-attention (all tokens attending to all other tokens across sequence boundaries).

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes", "token_type_ids"):
if key in inputs:
model_kwargs[key] = inputs[key]
if self.padding_free and key == "token_type_ids":
model_kwargs[key] = inputs[key][attention_mask.bool()].unsqueeze(0)
else:
model_kwargs[key] = inputs[key]

outputs = model(**model_kwargs)
shift_logits = outputs.logits[..., :-1, :].contiguous()
if self.padding_free:
shift_logits = restore_padding_from_flattened(shift_logits, flat_position_ids)
shift_labels = input_ids[..., 1:].contiguous()
shift_completion_mask = completion_mask[..., 1:].contiguous()
per_token_logps = selective_log_softmax(shift_logits, shift_labels)
Comment on lines +1211 to 1215
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Dimension mismatch between restored logits and shift_labels when padding_free=True with pad_to_multiple_of

When both padding_free=True and pad_to_multiple_of is set, restore_padding_from_flattened produces a tensor whose sequence dimension is max_unpadded_seq_len - 1, while shift_labels and shift_completion_mask (derived from the original padded input_ids) have sequence dimension padded_seq_len - 1. Because pad_to_multiple_of can make padded_seq_len > max_unpadded_seq_len, these dimensions will not match, causing a runtime crash in selective_log_softmax (or the liger loss).

Detailed explanation with concrete example

Consider a batch with 2 sequences of lengths 3 and 2, and pad_to_multiple_of=4:

  • DataCollator pads to length 4 → input_ids shape: (2, 4)
  • shift_labels = input_ids[..., 1:] → shape (2, 3)
  • Flattening removes padding → 5 non-padding tokens
  • Model forward → (1, 5, vocab), after [..., :-1, :](1, 4, vocab)
  • restore_padding_from_flattened removes cross-boundary elements and pads back → shape (2, 2, vocab) (max unpadded length 3 minus 1 = 2)
  • selective_log_softmax(shift_logits, shift_labels) receives shapes (2, 2, vocab) and (2, 3)crash

This affects three methods: _compute_loss at trl/trainer/dpo_trainer.py:1215, compute_ref_log_probs at trl/trainer/dpo_trainer.py:1080, and _compute_loss_liger at trl/trainer/dpo_trainer.py:1146-1147. There is no validation preventing the user from setting both padding_free=True and pad_to_multiple_of.

Impact: Any user combining padding_free=True with a non-None pad_to_multiple_of will hit a dimension mismatch crash during training whenever pad_to_multiple_of rounds up beyond the longest sequence in the batch.

Prompt for agents
In trl/trainer/dpo_trainer.py, around where the padding_free and vision dataset checks are done (near line 635-638), add a validation that raises an error when both padding_free=True and pad_to_multiple_of is not None, since restore_padding_from_flattened reconstructs tensors based on actual (unpadded) sequence lengths, which will be shorter than the padded_seq_len when pad_to_multiple_of rounds up. For example, add after line 638:

if self.padding_free and args.pad_to_multiple_of is not None:
    raise ValueError(
        "Padding-free training is not compatible with `pad_to_multiple_of`. "
        "Please set `pad_to_multiple_of=None` when using `padding_free=True`."
    )

Alternatively, the fix could be applied inside the loss computation methods (_compute_loss, compute_ref_log_probs, _compute_loss_liger) by truncating shift_labels and shift_completion_mask to match the restored logits' sequence length after restore_padding_from_flattened is called.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Comment on lines +1211 to 1215
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 Info: Padding value of 0 in restored logits is safe due to downstream masking

In _compute_loss at line 1212, restore_padding_from_flattened(shift_logits, flat_position_ids) uses the default padding_value=0, meaning restored padding positions get all-zero logits. While this produces incorrect log-probabilities at those positions (approximately -log(vocab_size) after log_softmax), the subsequent masking at line 1216 (per_token_logps[shift_completion_mask == 0] = 0.0) zeros them out before summation. Similarly in compute_ref_log_probs at line 1081. So the choice of padding_value=0 is safe, though it could be confusing to future readers.

(Refers to lines 1211-1216)

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Expand Down Expand Up @@ -1154,6 +1247,8 @@ def _compute_loss(self, model, inputs, return_outputs):
ref_outputs = self.ref_model(**model_kwargs)

ref_shift_logits = ref_outputs.logits[..., :-1, :].contiguous()
if self.padding_free:
ref_shift_logits = restore_padding_from_flattened(ref_shift_logits, flat_position_ids)
ref_per_token_logps = selective_log_softmax(ref_shift_logits, shift_labels)
ref_per_token_logps[shift_completion_mask == 0] = 0.0 # mask out non-completion tokens
if self.ld_alpha is None:
Expand Down
Loading