Skip to content

Commit 614845d

Browse files
authored
Adds support for the pixel_position_ids vision key (#5374)
1 parent 05eac2c commit 614845d

File tree

4 files changed

+26
-8
lines changed

trl/trainer/dpo_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ class DataCollatorForVisionPreference(DataCollatorMixin):
232232
- `"completion_mask"`: Tensor indicating which tokens correspond to completions.
233233
- `"pixel_values"`: Tensor representing image pixel values.
234234
235-
Additional keys may be present depending on the processor, such as `"image_grid_thw"`.
235+
Additional keys may be present depending on the processor, such as `"image_grid_thw"` or `"pixel_position_ids"`.
236236
237237
Args:
238238
processor ([`~transformers.ProcessorMixin`]):
@@ -1041,6 +1041,7 @@ def compute_ref_log_probs(self, inputs):
10411041
"pixel_attention_mask",
10421042
"image_grid_thw",
10431043
"image_sizes",
1044+
"pixel_position_ids",
10441045
):
10451046
if key in inputs:
10461047
model_kwargs[key] = inputs[key]
@@ -1166,6 +1167,7 @@ def _compute_loss(self, model, inputs, return_outputs):
11661167
"pixel_attention_mask",
11671168
"image_grid_thw",
11681169
"image_sizes",
1170+
"pixel_position_ids",
11691171
):
11701172
if key in inputs:
11711173
model_kwargs[key] = inputs[key]

trl/trainer/grpo_trainer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,7 @@ def _get_last_hidden_state(
930930
image_grid_thw=None,
931931
pixel_attention_mask=None,
932932
image_sizes=None,
933+
pixel_position_ids=None,
933934
):
934935
if is_peft_model(unwrapped_model):
935936
unwrapped_model = unwrapped_model.base_model.model
@@ -949,6 +950,8 @@ def _get_last_hidden_state(
949950
# For LLaVa-Next
950951
if image_sizes is not None:
951952
model_inputs["image_sizes"] = image_sizes
953+
if pixel_position_ids is not None:
954+
model_inputs["pixel_position_ids"] = pixel_position_ids
952955

953956
# Only add logits_to_keep if the model supports it
954957
if "logits_to_keep" in self.model_kwarg_keys:
@@ -1018,6 +1021,7 @@ def _get_per_token_logps_and_entropies(
10181021
image_sizes=None,
10191022
token_type_ids=None,
10201023
mm_token_type_ids=None,
1024+
pixel_position_ids=None,
10211025
) -> dict[str, torch.Tensor | None]:
10221026
"""Compute log-probs and (optionally) entropies for each token."""
10231027
batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak
@@ -1049,6 +1053,8 @@ def _get_per_token_logps_and_entropies(
10491053
model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size]
10501054
if mm_token_type_ids is not None:
10511055
model_inputs["mm_token_type_ids"] = mm_token_type_ids[start : start + batch_size]
1056+
if pixel_position_ids is not None:
1057+
model_inputs["pixel_position_ids"] = pixel_position_ids[start : start + batch_size]
10521058

10531059
# Only add logits_to_keep if the model supports it
10541060
if "logits_to_keep" in self.model_kwarg_keys:
@@ -1875,7 +1881,7 @@ def _generate_and_score_completions(
18751881
logits_to_keep,
18761882
batch_size,
18771883
num_images=num_images,
1878-
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
1884+
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids
18791885
)
18801886
else:
18811887
old_per_token_logps = None
@@ -1921,7 +1927,7 @@ def _generate_and_score_completions(
19211927
logits_to_keep,
19221928
batch_size=batch_size,
19231929
num_images=num_images,
1924-
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
1930+
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids
19251931
)
19261932
else:
19271933
# When training a PEFT adapter, how we obtain the reference depends on the setup:
@@ -1936,7 +1942,7 @@ def _generate_and_score_completions(
19361942
logits_to_keep,
19371943
batch_size=batch_size,
19381944
num_images=num_images,
1939-
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
1945+
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids
19401946
)
19411947
else:
19421948
ref_per_token_logps = None
@@ -2115,6 +2121,8 @@ def _generate_and_score_completions(
21152121
output["token_type_ids"] = forward_kwargs["token_type_ids"]
21162122
if "mm_token_type_ids" in forward_kwargs:
21172123
output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]
2124+
if "pixel_position_ids" in forward_kwargs:
2125+
output["pixel_position_ids"] = forward_kwargs["pixel_position_ids"]
21182126
if images is not None:
21192127
output["num_images"] = num_images
21202128
if tool_mask is not None:
@@ -2139,6 +2147,7 @@ def compute_liger_loss(self, unwrapped_model, inputs):
21392147
inputs.get("image_grid_thw"),
21402148
inputs.get("pixel_attention_mask"),
21412149
inputs.get("image_sizes"),
2150+
inputs.get("pixel_position_ids"),
21422151
)
21432152

21442153
# Apply tool_mask (from env_mask) for loss computation in multi-turn training scenarios
@@ -2274,6 +2283,7 @@ def _compute_loss(self, model, inputs):
22742283
image_sizes=inputs.get("image_sizes"),
22752284
token_type_ids=inputs.get("token_type_ids"),
22762285
mm_token_type_ids=inputs.get("mm_token_type_ids"),
2286+
pixel_position_ids=inputs.get("pixel_position_ids"),
22772287
)
22782288

22792289
if self.top_entropy_quantile < 1.0:

trl/trainer/rloo_trainer.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ def _get_per_token_logps_and_entropies(
680680
image_sizes=None,
681681
token_type_ids=None,
682682
mm_token_type_ids=None,
683+
pixel_position_ids=None,
683684
) -> dict[str, torch.Tensor | None]:
684685
"""Compute log-probs and (optionally) entropies for each token."""
685686
batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak
@@ -711,6 +712,8 @@ def _get_per_token_logps_and_entropies(
711712
model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size]
712713
if mm_token_type_ids is not None:
713714
model_inputs["mm_token_type_ids"] = mm_token_type_ids[start : start + batch_size]
715+
if pixel_position_ids is not None:
716+
model_inputs["pixel_position_ids"] = pixel_position_ids[start : start + batch_size]
714717

715718
# Only add logits_to_keep if the model supports it
716719
if "logits_to_keep" in self.model_kwarg_keys:
@@ -1211,7 +1214,7 @@ def _generate_and_score_completions(
12111214
logits_to_keep,
12121215
batch_size,
12131216
num_images=num_images,
1214-
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
1217+
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids
12151218
)
12161219
old_logps = (old_per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS
12171220

@@ -1225,7 +1228,7 @@ def _generate_and_score_completions(
12251228
logits_to_keep,
12261229
batch_size=batch_size,
12271230
num_images=num_images,
1228-
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
1231+
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids
12291232
)
12301233
else:
12311234
# When training a PEFT adapter, how we obtain the reference depends on the setup:
@@ -1240,7 +1243,7 @@ def _generate_and_score_completions(
12401243
logits_to_keep,
12411244
batch_size=batch_size,
12421245
num_images=num_images,
1243-
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
1246+
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids
12441247
)
12451248
else:
12461249
ref_per_token_logps = None
@@ -1363,6 +1366,8 @@ def _generate_and_score_completions(
13631366
output["token_type_ids"] = forward_kwargs["token_type_ids"]
13641367
if "mm_token_type_ids" in forward_kwargs:
13651368
output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]
1369+
if "pixel_position_ids" in forward_kwargs:
1370+
output["pixel_position_ids"] = forward_kwargs["pixel_position_ids"]
13661371
if images is not None:
13671372
output["num_images"] = num_images
13681373
return output
@@ -1395,6 +1400,7 @@ def _compute_loss(self, model, inputs):
13951400
image_sizes=inputs.get("image_sizes"),
13961401
token_type_ids=inputs.get("token_type_ids"),
13971402
mm_token_type_ids=inputs.get("mm_token_type_ids"),
1403+
pixel_position_ids=inputs.get("pixel_position_ids"),
13981404
)
13991405

14001406
logps = (per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS

trl/trainer/sft_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
308308
- `"pixel_values"`: Tensor representing image pixel values.
309309
- `"labels"`: Tensor for training labels.
310310
311-
Additional keys may be present depending on the processor, such as `"image_grid_thw"`.
311+
Additional keys may be present depending on the processor, such as `"image_grid_thw"` or `"pixel_position_ids"`.
312312
313313
Args:
314314
processor ([`~transformers.ProcessorMixin`]):

0 commit comments

Comments (0)