Skip to content

Commit a4d506a

Browse files
committed
[bugfix] fix ovis2_5 (#5803)
1 parent 52bd260 commit a4d506a

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

swift/llm/template/template/qwen.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -816,11 +816,11 @@ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, An
816816
visual_embeds = model.vte(visual_tokens).to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
817817
inputs_embeds[input_ids == VISUAL_ATOM_ID] = visual_embeds
818818
elif is_deepspeed_enabled():
819-
media_inputs = model.visual_tokenizer.preprocess(
819+
pixel_values, grid_thws = model.visual_tokenizer.preprocess(
820820
Image.new('RGB', (32, 32), (0, 0, 0)), min_pixels=self.min_pixels, max_pixels=self.max_pixels)
821-
media_inputs = to_device(media_inputs, input_ids.device)
822-
pixel_values = media_inputs['pixel_values'].type(inputs_embeds.dtype)
823-
visual_tokens = model.visual_tokenizer(pixel_values, media_inputs['grid_thws'])
821+
pixel_values = pixel_values.to(device=inputs_embeds.device)
822+
grid_thws = grid_thws.to(device=inputs_embeds.device)
823+
visual_tokens = model.visual_tokenizer(pixel_values, grid_thws)
824824
visual_embeds = model.vte(visual_tokens).to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
825825
inputs_embeds = inputs_embeds + visual_embeds.mean() * 0.
826826

swift/megatron/model/mm_gpt/qwen.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,11 +207,11 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs):
207207
for i, indicator_id in enumerate(INDICATOR_IDS):
208208
inputs_embeds[input_ids == indicator_id] = visual_indicator_embeds[i]
209209
if pixel_values is None:
210-
media_inputs = self.visual_tokenizer.preprocess(
210+
pixel_values, grid_thws = self.visual_tokenizer.preprocess(
211211
Image.new('RGB', (32, 32), (0, 0, 0)), min_pixels=self.min_pixels, max_pixels=self.max_pixels)
212-
media_inputs = to_device(media_inputs, input_ids.device)
213-
pixel_values = media_inputs['pixel_values'].type(inputs_embeds.dtype)
214-
visual_tokens = self.visual_tokenizer(pixel_values, media_inputs['grid_thws'])
212+
pixel_values = pixel_values.to(device=inputs_embeds.device)
213+
grid_thws = grid_thws.to(device=inputs_embeds.device)
214+
visual_tokens = self.visual_tokenizer(pixel_values, grid_thws)
215215
visual_embeds = self.vte(visual_tokens).to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
216216
inputs_embeds += visual_embeds.mean() * 0.
217217
else:

0 commit comments

Comments
 (0)