@@ -489,8 +489,10 @@ def main():
     logger.info('Loading InternVLChatModel...')
     config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
     config.vision_config.drop_path_rate = model_args.drop_path_rate
-    config.llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
-    config.llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
+    if 'internlm' in model_args.model_name_or_path.lower():
+        config.llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
+    else:
+        config.llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
     config.template = data_args.conv_style
     config.select_layer = model_args.vision_select_layer
     config.dynamic_image_size = data_args.dynamic_image_size
@@ -508,8 +510,10 @@ def main():
         model_args.vision_path, torch_dtype=torch.bfloat16, config=vision_config)
     logger.info('Loading LLaMA...')
     llm_config = AutoConfig.from_pretrained(model_args.llm_path, trust_remote_code=True)
-    llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
-    llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
+    if 'internlm' in model_args.llm_path.lower():
+        llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
+    else:
+        llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
     llm = AutoModelForCausalLM.from_pretrained(
         model_args.llm_path, torch_dtype=torch.bfloat16,
         config=llm_config, trust_remote_code=True)
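Both hunks apply the same pattern: instead of setting both attention flags unconditionally, the patch picks which config attribute to set based on the LLM backbone, since the InternLM config is accessed via `attn_implementation` while the standard Hugging Face LLaMA config goes through the private `_attn_implementation` attribute (per the original inline comments). Below is a minimal standalone sketch of that selection logic; the helper name `enable_flash_attn` is an illustrative assumption and not part of the patch.

```python
# Minimal sketch of the pattern used in both hunks above.
# `enable_flash_attn` is a hypothetical helper, not part of the actual change.
def enable_flash_attn(llm_config, name_or_path: str) -> None:
    """Request FlashAttention-2 on the LLM config.

    InternLM-style configs are assumed to use `attn_implementation`,
    while LLaMA-style (standard Hugging Face) configs use the private
    `_attn_implementation` attribute, so only the flag matching the
    detected backbone is set.
    """
    if 'internlm' in name_or_path.lower():
        llm_config.attn_implementation = 'flash_attention_2'    # InternLM path
    else:
        llm_config._attn_implementation = 'flash_attention_2'   # LLaMA / HF path
```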