class TestPaddingFreeHelpers(TrlTestCase):
    """Unit tests for the padding-free flatten/restore helper pair."""

    def test_flatten_batch_for_padding_free(self):
        # Three right-padded sequences of lengths 4, 2 and 3.
        ids = torch.tensor(
            [
                [10, 11, 12, 13, 0],
                [20, 21, 0, 0, 0],
                [30, 31, 32, 0, 0],
            ]
        )
        mask = torch.tensor(
            [
                [1, 1, 1, 1, 0],
                [1, 1, 0, 0, 0],
                [1, 1, 1, 0, 0],
            ]
        )

        flat_ids, flat_pos = flatten_batch_for_padding_free(ids, mask)

        # All non-padding tokens concatenated into a single row; position ids
        # restart at 0 at every sequence boundary.
        torch.testing.assert_close(flat_ids, torch.tensor([[10, 11, 12, 13, 20, 21, 30, 31, 32]]))
        torch.testing.assert_close(flat_pos, torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2]]))

    def test_restore_padding_from_flattened(self):
        # Position ids for sequences of lengths 4, 2 and 3 (9 tokens total).
        flat_pos = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2]])
        # Shifted per-token tensor: one entry per token except the last,
        # with a trailing feature dimension of size 2.
        shifted = torch.arange(16, dtype=torch.float32).view(1, 8, 2)

        restored = restore_padding_from_flattened(shifted, flat_pos, padding_value=-100)

        # Restored shape is (batch, max(len - 1), 2); shorter sequences are
        # right-padded with the sentinel padding value.
        expected = torch.tensor(
            [
                [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]],
                [[8.0, 9.0], [-100.0, -100.0], [-100.0, -100.0]],
                [[12.0, 13.0], [14.0, 15.0], [-100.0, -100.0]],
            ]
        )
        torch.testing.assert_close(restored, expected)
def flatten_batch_for_padding_free(
    input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Flattens a right-padded batch for padding-free attention.

    Args:
        input_ids (`torch.Tensor`):
            Tensor of token IDs with shape `(batch_size, sequence_length)`.
        attention_mask (`torch.Tensor`):
            Tensor with shape `(batch_size, sequence_length)` where non-padding tokens are marked with `1`.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`:
            A tuple `(flat_input_ids, flat_position_ids)` where:
            - `flat_input_ids` has shape `(1, total_non_padding_tokens)` and contains all non-padding tokens.
            - `flat_position_ids` has shape `(1, total_non_padding_tokens)` and resets positions at each sequence
              boundary.
    """
    keep = attention_mask.bool()
    # Running count of real tokens per row gives 0-based positions; padding
    # slots are zeroed so they never contribute stale position values.
    positions = (attention_mask.cumsum(dim=1) - 1).masked_fill(~keep, 0)
    # Boolean indexing concatenates all rows' kept tokens into one sequence;
    # unsqueeze restores the leading batch dimension of size 1.
    return input_ids[keep].unsqueeze(0), positions[keep].unsqueeze(0)


def restore_padding_from_flattened(
    tensor: torch.Tensor, flat_position_ids: torch.Tensor, padding_value: int = 0
) -> torch.Tensor:
    """
    Restores per-example padding from a flattened tensor produced in padding-free mode.

    This helper is designed for shifted next-token tensors (for example, logits computed with `[..., :-1, :]`), so
    the restored sequence length is derived from `flat_position_ids` and corresponds to `sequence_lengths - 1`.

    Args:
        tensor (`torch.Tensor`):
            Flattened tensor with shape `(1, total_non_padding_tokens - 1, ...)`.
        flat_position_ids (`torch.Tensor`):
            Flattened position IDs returned by [`flatten_batch_for_padding_free`] with shape `(1,
            total_non_padding_tokens)`.
        padding_value (`int`, *optional*, defaults to `0`):
            Value used to pad restored sequences to a common length.

    Returns:
        `torch.Tensor`:
            Restored tensor with shape `(batch_size, max(sequence_lengths - 1), ...)`.
    """
    positions = flat_position_ids.squeeze(0)
    # After the next-token shift, the slot that follows each sequence's last
    # token actually belongs to the *next* sequence (its position id is 0);
    # drop those cross-boundary entries.
    kept = tensor.squeeze(0)[positions[1:].ne(0)]
    # Sequence boundaries are wherever the position counter restarts at 0.
    seq_starts = positions.eq(0).nonzero(as_tuple=True)[0]
    seq_ends = torch.cat((seq_starts[1:], seq_starts.new_tensor([positions.numel()])))
    # Each restored length is len - 1 (clamped for degenerate 1-token rows).
    lengths = (seq_ends - seq_starts - 1).clamp_min(0).tolist()
    return pad(list(kept.split(lengths, dim=0)), padding_value=padding_value)
) - self.padding_free = False + dataset_sample = next(iter(train_dataset)) self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample if self._is_vision_dataset and not self._is_vlm: @@ -572,6 +633,11 @@ def __init__( "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided " "model does not seem to be a vision-language model. Please check your model and dataset." ) + if self.padding_free and self._is_vision_dataset: + raise ValueError( + "Padding-free training is not supported for vision-language preference data. Please set " + "`padding_free=False`." + ) if data_collator is None and not self._is_vision_dataset: # Get the pad token: if not provided, use the one from the processing class or the eos token # if the processing class does not have a pad token. @@ -988,11 +1054,17 @@ def compute_ref_log_probs(self, inputs): shift_labels = input_ids[..., 1:].contiguous() shift_completion_mask = completion_mask[..., 1:].contiguous() - - model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} + if self.padding_free: + flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask) + model_kwargs = {"input_ids": flat_input_ids, "position_ids": flat_position_ids, "use_cache": False} + else: + model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes", "token_type_ids"): if key in inputs: - model_kwargs[key] = inputs[key] + if self.padding_free and key == "token_type_ids": + model_kwargs[key] = inputs[key][attention_mask.bool()].unsqueeze(0) + else: + model_kwargs[key] = inputs[key] with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): if is_peft_model(self.model) and self.ref_model is None: @@ -1003,6 +1075,8 @@ def compute_ref_log_probs(self, inputs): ref_outputs = 
self.ref_model(**model_kwargs) ref_shift_logits = ref_outputs.logits[..., :-1, :].contiguous() + if self.padding_free: + ref_shift_logits = restore_padding_from_flattened(ref_shift_logits, flat_position_ids) ref_per_token_logps = selective_log_softmax(ref_shift_logits, shift_labels) ref_per_token_logps[shift_completion_mask == 0] = 0.0 @@ -1037,9 +1111,17 @@ def _compute_loss_liger(self, model, inputs, return_outputs): completion_mask = inputs["completion_mask"] input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask) + if self.padding_free: + flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask) + decoder_kwargs = {"input_ids": flat_input_ids, "position_ids": flat_position_ids, "use_cache": False} + else: + decoder_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} + decoder = model.get_decoder() - outputs = decoder(input_ids, attention_mask=attention_mask, use_cache=False) + outputs = decoder(**decoder_kwargs) hidden_states = outputs.last_hidden_state[:, :-1].contiguous() + if self.padding_free: + hidden_states = restore_padding_from_flattened(hidden_states, flat_position_ids) lm_head = model.get_output_embeddings() weight = lm_head.weight bias = lm_head.bias @@ -1049,9 +1131,11 @@ def _compute_loss_liger(self, model, inputs, return_outputs): else: with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): ref_decoder = self.ref_model.get_decoder() - ref_outputs = ref_decoder(input_ids, attention_mask=attention_mask, use_cache=False) + ref_outputs = ref_decoder(**decoder_kwargs) ref_lm_head = self.ref_model.get_output_embeddings() ref_hidden_states = ref_outputs.last_hidden_state[:, :-1].contiguous() + if self.padding_free: + ref_hidden_states = restore_padding_from_flattened(ref_hidden_states, flat_position_ids) ref_weight = ref_lm_head.weight ref_bias = ref_lm_head.bias @@ -1110,13 
+1194,22 @@ def _compute_loss(self, model, inputs, return_outputs): completion_mask = inputs["completion_mask"] input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask) - model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} + if self.padding_free: + flat_input_ids, flat_position_ids = flatten_batch_for_padding_free(input_ids, attention_mask) + model_kwargs = {"input_ids": flat_input_ids, "position_ids": flat_position_ids, "use_cache": False} + else: + model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes", "token_type_ids"): if key in inputs: - model_kwargs[key] = inputs[key] + if self.padding_free and key == "token_type_ids": + model_kwargs[key] = inputs[key][attention_mask.bool()].unsqueeze(0) + else: + model_kwargs[key] = inputs[key] outputs = model(**model_kwargs) shift_logits = outputs.logits[..., :-1, :].contiguous() + if self.padding_free: + shift_logits = restore_padding_from_flattened(shift_logits, flat_position_ids) shift_labels = input_ids[..., 1:].contiguous() shift_completion_mask = completion_mask[..., 1:].contiguous() per_token_logps = selective_log_softmax(shift_logits, shift_labels) @@ -1154,6 +1247,8 @@ def _compute_loss(self, model, inputs, return_outputs): ref_outputs = self.ref_model(**model_kwargs) ref_shift_logits = ref_outputs.logits[..., :-1, :].contiguous() + if self.padding_free: + ref_shift_logits = restore_padding_from_flattened(ref_shift_logits, flat_position_ids) ref_per_token_logps = selective_log_softmax(ref_shift_logits, shift_labels) ref_per_token_logps[shift_completion_mask == 0] = 0.0 # mask out non-completion tokens if self.ld_alpha is None: