Commit cef448b

fix Internvl-int8 sft bug (#932)
1 parent bdc8f54 commit cef448b

File tree

1 file changed (+16, -2 lines changed)


swift/llm/utils/model.py

Lines changed: 16 additions & 2 deletions
@@ -2583,6 +2583,7 @@ def get_model_tokenizer_deepseek2(model_dir: str,
     model, tokenizer = get_model_tokenizer_from_repo(
         model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs)
     if model is not None:
+        model.generation_config.pad_token_id = model.generation_config.eos_token_id
         # fix dtype bug
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
         mlp_cls = model.model.layers[1].mlp.__class__
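
Note: the DeepSeek-V2 hunk adds a generation default. Checkpoints that ship without pad_token_id make transformers fall back (with a warning) during batched generate(), so the loader reuses eos_token_id. A minimal sketch of the same default outside ms-swift, using the public DeepSeek-V2 checkpoint id only as a stand-in:

    # Sketch only; any causal-LM checkpoint without a pad token behaves the same way.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained('deepseek-ai/DeepSeek-V2-Chat', trust_remote_code=True)
    if model.generation_config.pad_token_id is None:
        # mirrors the added line above: reuse eos as the padding id for batched generation
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
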
@@ -2640,11 +2641,19 @@ def get_model_tokenizer_internvl(model_dir: str,
                                  model_kwargs: Dict[str, Any],
                                  load_model: bool = True,
                                  **kwargs):
-
     model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
     use_flash_attn = kwargs.pop('use_flash_attn', False)
     model_config.vision_config.use_flash_attn = use_flash_attn
     model_config.llm_config.attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
+    model_quant_config = getattr(model_config, 'quantization_config', None)
+
+    use_bnb = False
+    if model_quant_config is not None:
+        use_bnb = model_quant_config.get('quant_method', None) == 'bitsandbytes'
+    quantization_config = model_kwargs.get('quantization_config', None)
+    if isinstance(quantization_config, BitsAndBytesConfig):
+        use_bnb = True
+
     model, tokenizer = get_model_tokenizer_from_repo(
         model_dir,
         torch_dtype,
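
Note: this first InternVL change is pure detection. bitsandbytes can be involved either because the checkpoint was saved with a bnb quantization_config, or because the caller passes a BitsAndBytesConfig at load time through model_kwargs. Pulled out as a standalone helper (the name _uses_bnb is ours, not ms-swift's), the same check looks roughly like this:

    # Hedged sketch of the two detection branches added above.
    from transformers import BitsAndBytesConfig

    def _uses_bnb(model_config, model_kwargs) -> bool:
        # Branch 1: checkpoint already quantized, so its config carries a quantization_config dict.
        quant_cfg = getattr(model_config, 'quantization_config', None)
        if quant_cfg is not None and quant_cfg.get('quant_method', None) == 'bitsandbytes':
            return True
        # Branch 2: load-time quantization, a BitsAndBytesConfig arrives via model_kwargs.
        return isinstance(model_kwargs.get('quantization_config', None), BitsAndBytesConfig)
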
@@ -2654,6 +2663,11 @@ def get_model_tokenizer_internvl(model_dir: str,
         automodel_class=AutoModel,
         **kwargs)
 
+    if use_bnb and kwargs.get('is_training'):
+        # patch: bnb backward shape mismatch bug
+        if model is not None and model.language_model is not None:
+            model.language_model.output.state.force_no_igemmlt = True
+
     if model is not None:
         _use_submodel_func(model, 'language_model', ['get_input_embeddings'])
         fix_internvl_inplace_bug(model)
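
Note: this is the actual int8 SFT fix. During training on a bnb-quantized InternVL, the 8-bit lm head of the InternLM2 language model (its head attribute is named output) hits the backward-pass shape mismatch the comment mentions; setting force_no_igemmlt on that layer's MatmulLtState makes bitsandbytes skip the igemmlt kernel for it. A hedged sketch of applying the same patch by hand to an already loaded int8 model (helper name is ours):

    # Assumes `model` was loaded with load_in_8bit=True and its language model is InternLM2.
    import bitsandbytes as bnb

    def force_no_igemmlt_on_lm_head(model) -> None:
        lm_head = getattr(model.language_model, 'output', None)
        if isinstance(lm_head, bnb.nn.Linear8bitLt):
            # `state` is the layer's MatmulLtState; the flag routes its matmuls off the
            # igemmlt path that produces the shape mismatch in backward during SFT.
            lm_head.state.force_no_igemmlt = True
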
@@ -2685,7 +2699,7 @@ def _new_generate(*args, **kwargs):
 
         @wraps(extract_feature)
         def _new_extract_feature(pixel_values):
-            return extract_feature(pixel_values).to(pixel_values.device)
+            return extract_feature(pixel_values).to(pixel_values.device).to(pixel_values.dtype)
 
         model.extract_feature = _new_extract_feature
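
Note: the last hunk extends the existing extract_feature wrapper. With an int8-quantized language model the vision features can come back in a different floating dtype than pixel_values, so the result is now cast to pixel_values.dtype in addition to its device before the features are handed back to the language model. A toy illustration of the cast (shapes and dtypes are invented):

    import torch

    pixel_values = torch.randn(1, 3, 448, 448, dtype=torch.bfloat16)
    vit_embeds = torch.randn(1, 256, 4096, dtype=torch.float16)  # simulated dtype drift
    vit_embeds = vit_embeds.to(pixel_values.device).to(pixel_values.dtype)
    assert vit_embeds.dtype == pixel_values.dtype and vit_embeds.device == pixel_values.device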

0 commit comments
