@@ -3870,6 +3870,52 @@ def get_model_tokenizer_yi_vl(model_dir: str,
38703870 return model , tokenizer
38713871
38723872
3873+ def _patch_minicpm_v_device_map (model ) -> None :
3874+ if not hasattr (model , 'hf_device_map' ) or len (model .hf_device_map .values ()) == 1 :
3875+ return
3876+ if hasattr (model .llm , '__old_forward' ):
3877+ # avoid double patching
3878+ return
3879+ device = list (model .hf_device_map .values ())[0 ]
3880+ if hasattr (model , 'get_vision_embedding' ): # minicpm-v-v2-chat
3881+ _old_get_vision_embedding = model .get_vision_embedding
3882+
3883+ def _get_vision_embedding (pixel_values ):
3884+ if len (pixel_values ) == 0 :
3885+ return _old_get_vision_embedding (pixel_values )
3886+ output = _old_get_vision_embedding (pixel_values )
3887+ return output .to (device = device )
3888+
3889+ model ._old_get_vision_embedding = _old_get_vision_embedding
3890+ model .get_vision_embedding = _get_vision_embedding
3891+
3892+ if hasattr (model , 'resampler' ): # minicpm-v-v2_5-chat
3893+ __old_resampler_forward = model .resampler .forward
3894+
3895+ def _new_resampler_forward (* args , ** kwargs ) -> Tensor :
3896+ output = __old_resampler_forward (* args , ** kwargs )
3897+ return output .to (device = device )
3898+
3899+ model .resampler .forward = _new_resampler_forward
3900+
3901+ __old_forward = model .llm .forward
3902+
3903+ def _new_forward (* args , ** kwargs ) -> Tensor :
3904+ inputs = kwargs .get ('inputs_embeds' )
3905+ if inputs is None :
3906+ inputs = kwargs .get ('input_ids' )
3907+ device = inputs .device
3908+ output = __old_forward (* args , ** kwargs )
3909+ if output .logits is not None :
3910+ output .logits = output .logits .to (device )
3911+ if output .loss is not None :
3912+ output .loss = output .loss .to (device )
3913+ return output
3914+
3915+ model .llm .forward = _new_forward
3916+ model .llm .__old_forward = __old_forward
3917+
3918+
38733919@register_model (
38743920 ModelType .minicpm_v_3b_chat ,
38753921 'OpenBMB/MiniCPM-V' ,
@@ -3904,6 +3950,7 @@ def get_model_tokenizer_minicpm_v(model_dir: str,
39043950 model , tokenizer = get_model_tokenizer_with_flash_attn (model_dir , torch_dtype , model_kwargs , load_model , ** kwargs )
39053951 if load_model :
39063952 model .resampler .to (torch_dtype ) # fix float32
3953+ _patch_minicpm_v_device_map (model )
39073954 func_list = ['generate' , 'get_input_embeddings' , 'forward' ]
39083955 _use_submodel_func (model , 'llm' , func_list )
39093956 if patching_embedding :
0 commit comments