Skip to content

Commit 52bd260

Browse files
committed
[bugfix] fix qwen2_5_vl device_map8 (#5800)
1 parent 07ef6bf commit 52bd260

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

swift/llm/template/base.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1955,6 +1955,8 @@ def _flash_attention_forward(*args, **kwargs):
19551955
else:
19561956
flash_attention_forward = _origin_flash_attention_forward
19571957
kwargs['position_ids'] = position_ids
1958+
if args and isinstance(args[0], torch.Tensor):
1959+
kwargs['position_ids'] = kwargs['position_ids'].to(args[0].device)
19581960
return flash_attention_forward(*args, **kwargs)
19591961

19601962
modeling_module._flash_attention_forward = _flash_attention_forward
@@ -1977,7 +1979,7 @@ def _get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config):
19771979
media_inputs = to_device(media_inputs, input_ids.device)
19781980
pixel_values = media_inputs['pixel_values'].type(dtype)
19791981
image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
1980-
inputs_embeds = inputs_embeds + image_embeds.mean() * 0.
1982+
inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0.
19811983
else:
19821984
if pixel_values is None:
19831985
pixel_values_mixed = pixel_values_videos
@@ -2005,11 +2007,13 @@ def _get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config):
20052007
if image_embeds is not None:
20062008
image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
20072009
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
2010+
image_mask = image_mask.to(inputs_embeds.device)
20082011
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
20092012

20102013
if video_embeds is not None:
20112014
video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
20122015
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
2016+
video_mask = video_mask.to(inputs_embeds.device)
20132017
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
20142018
return inputs_embeds
20152019

0 commit comments

Comments
 (0)