
Commit eb7e551

Fix flash attention bug
1 parent e401bc7 commit eb7e551

2 files changed: +16 −8 lines


internvl_chat/internvl/train/internvl_chat_finetune.py (8 additions, 4 deletions)

@@ -489,8 +489,10 @@ def main():
     logger.info('Loading InternVLChatModel...')
     config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
     config.vision_config.drop_path_rate = model_args.drop_path_rate
-    config.llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
-    config.llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
+    if 'internlm' in model_args.model_name_or_path.lower():
+        config.llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
+    else:
+        config.llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
     config.template = data_args.conv_style
     config.select_layer = model_args.vision_select_layer
     config.dynamic_image_size = data_args.dynamic_image_size
@@ -508,8 +510,10 @@ def main():
         model_args.vision_path, torch_dtype=torch.bfloat16, config=vision_config)
     logger.info('Loading LLaMA...')
     llm_config = AutoConfig.from_pretrained(model_args.llm_path, trust_remote_code=True)
-    llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
-    llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
+    if 'internlm' in model_args.llm_path.lower():
+        llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
+    else:
+        llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
     llm = AutoModelForCausalLM.from_pretrained(
         model_args.llm_path, torch_dtype=torch.bfloat16,
         config=llm_config, trust_remote_code=True)
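
Both hunks apply the same fix: rather than setting both config attributes unconditionally, the flash-attention flag is now written only to the attribute the loaded model actually reads (attn_implementation for InternLM's remote-code config, _attn_implementation for LLaMA-style configs), selected by checking the checkpoint path. Below is a minimal standalone sketch of that pattern; the enable_flash_attention helper and the example path are illustrative assumptions, not part of this commit.

# Sketch of the attribute-selection pattern used in this commit.
from transformers import AutoConfig

def enable_flash_attention(llm_path: str):
    """Request FlashAttention-2 via the attribute matching the model family."""
    llm_config = AutoConfig.from_pretrained(llm_path, trust_remote_code=True)
    if 'internlm' in llm_path.lower():
        # InternLM's remote-code config reads the public attribute.
        llm_config.attn_implementation = 'flash_attention_2'
    else:
        # LLaMA-style configs in transformers use the private attribute.
        llm_config._attn_implementation = 'flash_attention_2'
    return llm_config

# Example usage (assumed local checkpoint directory):
# config = enable_flash_attention('./pretrained/internlm2-chat-7b')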

internvl_chat/internvl/train/internvl_chat_pretrain.py (8 additions, 4 deletions)

@@ -509,8 +509,10 @@ def main():
     logger.info('Loading InternVLChatModel...')
     config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
     config.vision_config.drop_path_rate = model_args.drop_path_rate
-    config.llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
-    config.llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
+    if 'internlm' in model_args.model_name_or_path.lower():
+        config.llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
+    else:
+        config.llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
     config.template = data_args.conv_style
     config.select_layer = model_args.vision_select_layer
     config.dynamic_image_size = data_args.dynamic_image_size
@@ -528,8 +530,10 @@ def main():
         model_args.vision_path, torch_dtype=torch.bfloat16, config=vision_config)
     logger.info('Loading LLaMA...')
     llm_config = AutoConfig.from_pretrained(model_args.llm_path, trust_remote_code=True)
-    llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
-    llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
+    if 'internlm' in model_args.llm_path.lower():
+        llm_config.attn_implementation = 'flash_attention_2'  # for InternLM
+    else:
+        llm_config._attn_implementation = 'flash_attention_2'  # for LLaMA
     llm = AutoModelForCausalLM.from_pretrained(
         model_args.llm_path, torch_dtype=torch.bfloat16,
         config=llm_config, trust_remote_code=True)
