@@ -1955,6 +1955,8 @@ def _flash_attention_forward(*args, **kwargs):
         else:
             flash_attention_forward = _origin_flash_attention_forward
         kwargs['position_ids'] = position_ids
+        if args and isinstance(args[0], torch.Tensor):
+            kwargs['position_ids'] = kwargs['position_ids'].to(args[0].device)
         return flash_attention_forward(*args, **kwargs)
 
     modeling_module._flash_attention_forward = _flash_attention_forward
@@ -1977,7 +1979,7 @@ def _get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config):
         media_inputs = to_device(media_inputs, input_ids.device)
         pixel_values = media_inputs['pixel_values'].type(dtype)
         image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
-        inputs_embeds = inputs_embeds + image_embeds.mean() * 0.
+        inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0.
     else:
         if pixel_values is None:
             pixel_values_mixed = pixel_values_videos
@@ -2005,11 +2007,13 @@ def _get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config):
     if image_embeds is not None:
         image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
         image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+        image_mask = image_mask.to(inputs_embeds.device)
         inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
 
     if video_embeds is not None:
         video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
         video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+        video_mask = video_mask.to(inputs_embeds.device)
         inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
     return inputs_embeds
20152019
0 commit comments