
Commit 849c377

[bugfix] Fix tensor device in Idefics2, Idefics3, and SmolVLM (#39975)
* [bugfix] ensure correct tensor device in Idefics2, Idefics3, and SmolVLM models
* to cuda
1 parent 85d536a commit 849c377
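
In short, the fix builds `boundaries` and `position_ids` directly on `pixel_values.device` instead of on CPU, which lets the `.cpu()` indexing workaround and the trailing `.to(self.position_embedding.weight.device)` move be removed. Below is a minimal, self-contained sketch of that pattern; the helper name `make_position_ids` and the toy shapes are illustrative only and not part of the commit.

import torch


def make_position_ids(pixel_values: torch.Tensor,
                      patch_attention_mask: torch.Tensor,
                      num_patches: int) -> torch.Tensor:
    """Hypothetical helper illustrating the device-placement pattern used by the fix."""
    device = pixel_values.device
    batch_size = pixel_values.shape[0]

    # Auxiliary tensors are built on the activation device up front, so the
    # boolean-mask assignment below needs no .cpu() detour and no trailing .to(...).
    position_ids = torch.full((batch_size, num_patches), 0, device=device)

    for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
        mask = p_attn_mask.view(-1)  # bool mask on the same device as position_ids
        n_valid = int(mask.sum())
        position_ids[batch_idx][mask] = torch.arange(n_valid, device=device)
    return position_ids


# Usage sketch with toy shapes (runs on CPU; works unchanged when inputs live on CUDA):
pixels = torch.randn(2, 3, 32, 32)
patch_mask = torch.ones(2, 4, 4, dtype=torch.bool)
print(make_position_ids(pixels, patch_mask, num_patches=16).shape)  # torch.Size([2, 16])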

File tree: 3 files changed (+21 −12 lines)

src/transformers/models/idefics2/modeling_idefics2.py

Lines changed: 7 additions & 4 deletions
@@ -141,8 +141,12 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
 
         max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
-        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
-        position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
+        boundaries = torch.arange(
+            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
+        )
+        position_ids = torch.full(
+            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
+        )
 
         for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
             nb_patches_h = p_attn_mask[:, 0].sum()
@@ -158,9 +162,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
             bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
 
             pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
-            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+            position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
 
-        position_ids = position_ids.to(self.position_embedding.weight.device)
         embeddings = embeddings + self.position_embedding(position_ids)
         return embeddings

src/transformers/models/idefics3/modeling_idefics3.py

Lines changed: 7 additions & 4 deletions
@@ -140,8 +140,12 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
 
         max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
-        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
-        position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
+        boundaries = torch.arange(
+            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
+        )
+        position_ids = torch.full(
+            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
+        )
 
         for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
             nb_patches_h = p_attn_mask[:, 0].sum()
@@ -157,9 +161,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
             bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
 
             pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
-            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+            position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
 
-        position_ids = position_ids.to(self.position_embedding.weight.device)
         embeddings = embeddings + self.position_embedding(position_ids)
         return embeddings

src/transformers/models/smolvlm/modeling_smolvlm.py

Lines changed: 7 additions & 4 deletions
@@ -135,8 +135,12 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
 
         max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
-        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
-        position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
+        boundaries = torch.arange(
+            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
+        )
+        position_ids = torch.full(
+            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
+        )
 
         for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            nb_patches_h = p_attn_mask[:, 0].sum()
@@ -152,9 +156,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B
             bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
 
             pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
-            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+            position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
 
-        position_ids = position_ids.to(self.position_embedding.weight.device)
         embeddings = embeddings + self.position_embedding(position_ids)
         return embeddings
