@@ -2583,6 +2583,7 @@ def get_model_tokenizer_deepseek2(model_dir: str,
     model, tokenizer = get_model_tokenizer_from_repo(
         model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs)
     if model is not None:
+        model.generation_config.pad_token_id = model.generation_config.eos_token_id
         # fix dtype bug
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
         mlp_cls = model.model.layers[1].mlp.__class__
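
For context, pointing `generation_config.pad_token_id` at `eos_token_id` is the standard fallback when a checkpoint ships without a pad token. A minimal sketch of the same pattern in isolation (the checkpoint name is only illustrative):

from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any causal LM whose generation config lacks a
# pad token works the same way.
model = AutoModelForCausalLM.from_pretrained('deepseek-ai/DeepSeek-V2-Lite-Chat', trust_remote_code=True)
if model.generation_config.pad_token_id is None:
    # Reuse EOS as the pad token so batched generate() calls can pad inputs.
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
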
@@ -2640,11 +2641,19 @@ def get_model_tokenizer_internvl(model_dir: str,
                                  model_kwargs: Dict[str, Any],
                                  load_model: bool = True,
                                  **kwargs):
-
     model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
     use_flash_attn = kwargs.pop('use_flash_attn', False)
     model_config.vision_config.use_flash_attn = use_flash_attn
     model_config.llm_config.attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
+    model_quant_config = getattr(model_config, 'quantization_config', None)
+
+    use_bnb = False
+    if model_quant_config is not None:
+        use_bnb = model_quant_config.get('quant_method', None) == 'bitsandbytes'
+    quantization_config = model_kwargs.get('quantization_config', None)
+    if isinstance(quantization_config, BitsAndBytesConfig):
+        use_bnb = True
+
     model, tokenizer = get_model_tokenizer_from_repo(
         model_dir,
         torch_dtype,
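
The added detection covers both bitsandbytes routes: a checkpoint saved pre-quantized (its config carries a `quantization_config` dict whose `quant_method` is `'bitsandbytes'`) and on-the-fly quantization requested via a `BitsAndBytesConfig` in `model_kwargs`. A standalone sketch of the same logic, with a hypothetical helper name:

from transformers import AutoConfig, BitsAndBytesConfig

def _uses_bnb(model_dir, model_kwargs):
    # Hypothetical helper mirroring the hunk above.
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    # Route 1: the checkpoint itself was saved quantized with bitsandbytes.
    quant = getattr(config, 'quantization_config', None)
    if quant is not None and quant.get('quant_method', None) == 'bitsandbytes':
        return True
    # Route 2: the caller requests bnb quantization at load time.
    return isinstance(model_kwargs.get('quantization_config', None), BitsAndBytesConfig)
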
@@ -2654,6 +2663,11 @@ def get_model_tokenizer_internvl(model_dir: str,
         automodel_class=AutoModel,
         **kwargs)

+    if use_bnb and kwargs.get('is_training'):
+        # patch: bnb backward shape mismatch bug
+        if model is not None and model.language_model is not None:
+            model.language_model.output.state.force_no_igemmlt = True
+
     if model is not None:
         _use_submodel_func(model, 'language_model', ['get_input_embeddings'])
         fix_internvl_inplace_bug(model)
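
The patch flips `force_no_igemmlt` on the LM head (`language_model.output`), which under bnb int8 quantization is a `Linear8bitLt` layer whose `state` controls kernel selection; forcing the flag makes bitsandbytes skip the igemmlt kernel whose backward pass triggers the shape mismatch during training. A more generic sketch of the same idea, assuming a recent bitsandbytes where `Linear8bitLt.state.force_no_igemmlt` is available:

import bitsandbytes as bnb
import torch.nn as nn

def disable_igemmlt(module: nn.Module):
    # Hypothetical helper generalizing the patch above: disable the igemmlt
    # kernel on every 8-bit linear so training uses the fallback matmul path.
    for sub in module.modules():
        if isinstance(sub, bnb.nn.Linear8bitLt):
            sub.state.force_no_igemmlt = True
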
@@ -2685,7 +2699,7 @@ def _new_generate(*args, **kwargs):

     @wraps(extract_feature)
     def _new_extract_feature(pixel_values):
-        return extract_feature(pixel_values).to(pixel_values.device)
+        return extract_feature(pixel_values).to(pixel_values.device).to(pixel_values.dtype)

     model.extract_feature = _new_extract_feature
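
A note on the fix itself: the extra cast keeps the extracted vision features on both the device and the dtype of `pixel_values`. Since `torch.Tensor.to` accepts device and dtype together, the two chained calls could equivalently be fused (same wrapper as above, rewritten):

from functools import wraps

@wraps(extract_feature)
def _new_extract_feature(pixel_values):
    # Same effect as .to(device).to(dtype), in a single cast.
    return extract_feature(pixel_values).to(device=pixel_values.device, dtype=pixel_values.dtype)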