diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 9703a43d605c..1017afb4567e 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -151,11 +151,19 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() - h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype) - w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype) + step_h = 1.0 / nb_patches_h + step_w = 1.0 / nb_patches_w - fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6) - fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6) + h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32) + w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32) + fractional_coords_h = h_indices * step_h + fractional_coords_w = w_indices * step_w + + fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6)) + fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6)) + + fractional_coords_h = fractional_coords_h.to(pixel_values.dtype) + fractional_coords_w = fractional_coords_w.to(pixel_values.dtype) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 89bbd931fadc..0829de53385f 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -149,11 +149,19 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() - h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype) - w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype) + step_h = 1.0 / nb_patches_h + step_w = 1.0 / nb_patches_w - fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6) - fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6) + h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32) + w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32) + fractional_coords_h = h_indices * step_h + fractional_coords_w = w_indices * step_w + + fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6)) + fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6)) + + fractional_coords_h = fractional_coords_h.to(pixel_values.dtype) + fractional_coords_w = fractional_coords_w.to(pixel_values.dtype) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 5ff2b041dd2d..8ae07bb3d8b1 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -147,11 +147,19 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() - h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype) - w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype) + step_h = 1.0 / nb_patches_h + step_w = 1.0 / nb_patches_w - fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6) - fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6) + h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32) + w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32) + fractional_coords_h = h_indices * step_h + fractional_coords_w = w_indices * step_w + + fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6)) + fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6)) + + fractional_coords_h = fractional_coords_h.to(pixel_values.dtype) + fractional_coords_w = fractional_coords_w.to(pixel_values.dtype) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)