@@ -930,6 +930,7 @@ def _get_last_hidden_state(
         image_grid_thw=None,
         pixel_attention_mask=None,
         image_sizes=None,
+        pixel_position_ids=None,
     ):
         if is_peft_model(unwrapped_model):
             unwrapped_model = unwrapped_model.base_model.model
@@ -949,6 +950,8 @@ def _get_last_hidden_state(
         # For LLaVa-Next
         if image_sizes is not None:
             model_inputs["image_sizes"] = image_sizes
+        if pixel_position_ids is not None:
+            model_inputs["pixel_position_ids"] = pixel_position_ids
 
         # Only add logits_to_keep if the model supports it
         if "logits_to_keep" in self.model_kwarg_keys:
@@ -1018,6 +1021,7 @@ def _get_per_token_logps_and_entropies(
         image_sizes=None,
         token_type_ids=None,
         mm_token_type_ids=None,
+        pixel_position_ids=None,
     ) -> dict[str, torch.Tensor | None]:
         """Compute log-probs and (optionally) entropies for each token."""
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
@@ -1049,6 +1053,8 @@ def _get_per_token_logps_and_entropies(
                 model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size]
             if mm_token_type_ids is not None:
                 model_inputs["mm_token_type_ids"] = mm_token_type_ids[start : start + batch_size]
+            if pixel_position_ids is not None:
+                model_inputs["pixel_position_ids"] = pixel_position_ids[start : start + batch_size]
 
             # Only add logits_to_keep if the model supports it
             if "logits_to_keep" in self.model_kwarg_keys:
@@ -1875,7 +1881,7 @@ def _generate_and_score_completions(
                     logits_to_keep,
                     batch_size,
                     num_images=num_images,
-                    **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
+                    **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids
                 )
             else:
                 old_per_token_logps = None
@@ -1921,7 +1927,7 @@ def _generate_and_score_completions(
                         logits_to_keep,
                         batch_size=batch_size,
                         num_images=num_images,
-                        **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
+                        **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids
                     )
                 else:
                     # When training a PEFT adapter, how we obtain the reference depends on the setup:
@@ -1936,7 +1942,7 @@ def _generate_and_score_completions(
                             logits_to_keep,
                             batch_size=batch_size,
                             num_images=num_images,
-                            **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
+                            **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids
                         )
         else:
             ref_per_token_logps = None
@@ -2115,6 +2121,8 @@ def _generate_and_score_completions(
             output["token_type_ids"] = forward_kwargs["token_type_ids"]
         if "mm_token_type_ids" in forward_kwargs:
             output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]
+        if "pixel_position_ids" in forward_kwargs:
+            output["pixel_position_ids"] = forward_kwargs["pixel_position_ids"]
         if images is not None:
             output["num_images"] = num_images
         if tool_mask is not None:
@@ -2139,6 +2147,7 @@ def compute_liger_loss(self, unwrapped_model, inputs):
             inputs.get("image_grid_thw"),
             inputs.get("pixel_attention_mask"),
             inputs.get("image_sizes"),
+            inputs.get("pixel_position_ids"),
         )
 
         # Apply tool_mask (from env_mask) for loss computation in multi-turn training scenarios
@@ -2274,6 +2283,7 @@ def _compute_loss(self, model, inputs):
             image_sizes=inputs.get("image_sizes"),
             token_type_ids=inputs.get("token_type_ids"),
             mm_token_type_ids=inputs.get("mm_token_type_ids"),
+            pixel_position_ids=inputs.get("pixel_position_ids"),
         )
 
         if self.top_entropy_quantile < 1.0:
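For readers following the change, the sketch below (not part of the commit) illustrates the pattern the hunks share: an optional multimodal kwarg such as `pixel_position_ids` is forwarded to the model only when it is present, and sliced to the current micro-batch exactly like `token_type_ids` and `mm_token_type_ids`. The helper name `build_chunk_inputs` and the tensor shapes are illustrative assumptions, not TRL API.

```python
# Illustrative sketch only; build_chunk_inputs is a hypothetical helper, not TRL API.
import torch


def build_chunk_inputs(input_ids, attention_mask, start, chunk, **optional):
    """Assemble per-chunk model inputs, adding optional keys only when provided."""
    model_inputs = {
        "input_ids": input_ids[start : start + chunk],
        "attention_mask": attention_mask[start : start + chunk],
    }
    for key in ("token_type_ids", "mm_token_type_ids", "pixel_position_ids"):
        value = optional.get(key)
        if value is not None:  # skip keys the processor did not produce
            model_inputs[key] = value[start : start + chunk]
    return model_inputs


# Example: pixel_position_ids is included only because the caller provided it.
input_ids = torch.randint(0, 100, (4, 8))
attention_mask = torch.ones_like(input_ids)
pixel_position_ids = torch.zeros(4, 16, dtype=torch.long)  # placeholder shape
inputs = build_chunk_inputs(input_ids, attention_mask, 0, 2, pixel_position_ids=pixel_position_ids)
assert "pixel_position_ids" in inputs and "token_type_ids" not in inputs
```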