Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/transformers/models/idefics2/modeling_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,19 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()

h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype)
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype)
step_h = 1.0 / nb_patches_h
step_w = 1.0 / nb_patches_w

fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32)
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32)
fractional_coords_h = h_indices * step_h
fractional_coords_w = w_indices * step_w

fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))

fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
Expand Down
16 changes: 12 additions & 4 deletions src/transformers/models/idefics3/modeling_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,19 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()

h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype)
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype)
step_h = 1.0 / nb_patches_h
step_w = 1.0 / nb_patches_w

fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32)
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32)
fractional_coords_h = h_indices * step_h
fractional_coords_w = w_indices * step_w

fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))

fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
Expand Down
16 changes: 12 additions & 4 deletions src/transformers/models/smolvlm/modeling_smolvlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,19 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()

h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype)
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype)
step_h = 1.0 / nb_patches_h
step_w = 1.0 / nb_patches_w

fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32)
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32)
fractional_coords_h = h_indices * step_h
fractional_coords_w = w_indices * step_w

fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))

fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
Expand Down