Skip to content

Commit 555cbf5

Browse files
authored
[Idefics] fix device mismatch (#39981)
fix
1 parent 597ed1a commit 555cbf5

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

src/transformers/models/idefics2/modeling_idefics2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
147147
nb_patches_h = p_attn_mask[:, 0].sum()
148148
nb_patches_w = p_attn_mask[0].sum()
149149

150-
h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype)
151-
w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype)
150+
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=position_ids.dtype)
151+
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=position_ids.dtype)
152152

153153
fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
154154
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)

src/transformers/models/idefics3/modeling_idefics3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
147147
nb_patches_h = p_attn_mask[:, 0].sum()
148148
nb_patches_w = p_attn_mask[0].sum()
149149

150-
h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype)
151-
w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype)
150+
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=position_ids.dtype)
151+
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=position_ids.dtype)
152152

153153
fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
154154
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)

src/transformers/models/smolvlm/modeling_smolvlm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
142142
nb_patches_h = p_attn_mask[:, 0].sum()
143143
nb_patches_w = p_attn_mask[0].sum()
144144

145-
h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype)
146-
w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype)
145+
h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=position_ids.dtype)
146+
w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=position_ids.dtype)
147147

148148
fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
149149
fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)

0 commit comments

Comments
 (0)