Skip to content

Commit 4763b8c

Browse files
authored
Correct numerical regression in vision embeddings (#41374)
created modeling file
1 parent caa14e7 commit 4763b8c

File tree

3 files changed

+36
-12
lines changed

3 files changed

+36
-12
lines changed

src/transformers/models/idefics2/modeling_idefics2.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -151,11 +151,19 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
151 151
nb_patches_h = p_attn_mask[:, 0].sum()
152 152
nb_patches_w = p_attn_mask[0].sum()
153 153

154-
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype)
155-
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype)
154+
step_h = 1.0 / nb_patches_h
155+
step_w = 1.0 / nb_patches_w
156 156

157-
fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
158-
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
157+
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32)
158+
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32)
159+
fractional_coords_h = h_indices * step_h
160+
fractional_coords_w = w_indices * step_w
161+
162+
fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
163+
fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))
164+
165+
fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
166+
fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)
159 167

160 168
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
161 169
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

src/transformers/models/idefics3/modeling_idefics3.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -149,11 +149,19 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
149 149
nb_patches_h = p_attn_mask[:, 0].sum()
150 150
nb_patches_w = p_attn_mask[0].sum()
151 151

152-
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype)
153-
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype)
152+
step_h = 1.0 / nb_patches_h
153+
step_w = 1.0 / nb_patches_w
154 154

155-
fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
156-
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
155+
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32)
156+
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32)
157+
fractional_coords_h = h_indices * step_h
158+
fractional_coords_w = w_indices * step_w
159+
160+
fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
161+
fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))
162+
163+
fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
164+
fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)
157 165

158 166
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
159 167
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

src/transformers/models/smolvlm/modeling_smolvlm.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -147,11 +147,19 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
147 147
nb_patches_h = p_attn_mask[:, 0].sum()
148 148
nb_patches_w = p_attn_mask[0].sum()
149 149

150-
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype)
151-
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype)
150+
step_h = 1.0 / nb_patches_h
151+
step_w = 1.0 / nb_patches_w
152 152

153-
fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
154-
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
153+
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=torch.float32)
154+
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=torch.float32)
155+
fractional_coords_h = h_indices * step_h
156+
fractional_coords_w = w_indices * step_w
157+
158+
fractional_coords_h = torch.clamp(fractional_coords_h, max=(1.0 - 1e-6))
159+
fractional_coords_w = torch.clamp(fractional_coords_w, max=(1.0 - 1e-6))
160+
161+
fractional_coords_h = fractional_coords_h.to(pixel_values.dtype)
162+
fractional_coords_w = fractional_coords_w.to(pixel_values.dtype)
155 163

156 164
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
157 165
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

0 commit comments

Comments
 (0)