
Commit 5ed5913 (parent 541223c)

[bugfix] fix megatron flash_attn (flash_attention_3) (#5837)

File tree

1 file changed: +2 −1 lines changed

swift/megatron/model/mm_gpt/utils.py

Lines changed: 2 additions & 1 deletion
@@ -63,7 +63,8 @@ def __init__(self, config, ignore_init_model_cls=None):
         super().__init__(config)
         args = get_args()
         model_dir = args.model_info.model_dir
-        kwargs = {'attn_impl': 'flash_attn'} if args.attention_backend.name == 'flash' else {}
+        attn_impl = getattr(args, 'attn_impl', None) or 'flash_attn'
+        kwargs = {'attn_impl': attn_impl} if args.attention_backend.name == 'flash' else {}
         ignore_init_model_cls = ignore_init_model_cls or []
         if not isinstance(ignore_init_model_cls, list):
             ignore_init_model_cls = [ignore_init_model_cls]
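Why this fixes flash_attention_3: the old branch hard-coded attn_impl='flash_attn' whenever the attention backend was 'flash', so a user-requested flash_attention_3 implementation was silently discarded. Below is a minimal before/after sketch of the changed logic; the SimpleNamespace stand-in for Megatron's args object is an assumption for illustration only, not the real get_args() structure.

    # Hypothetical stand-in for Megatron's args; real code gets this from get_args().
    from types import SimpleNamespace

    def build_kwargs_old(args):
        # Before the fix: the 'flash' backend always forced 'flash_attn',
        # ignoring any user-specified implementation such as 'flash_attention_3'.
        return {'attn_impl': 'flash_attn'} if args.attention_backend.name == 'flash' else {}

    def build_kwargs_new(args):
        # After the fix: honor args.attn_impl when present and truthy,
        # falling back to 'flash_attn' otherwise.
        attn_impl = getattr(args, 'attn_impl', None) or 'flash_attn'
        return {'attn_impl': attn_impl} if args.attention_backend.name == 'flash' else {}

    args = SimpleNamespace(
        attention_backend=SimpleNamespace(name='flash'),
        attn_impl='flash_attention_3',
    )
    print(build_kwargs_old(args))  # {'attn_impl': 'flash_attn'} -- the bug
    print(build_kwargs_new(args))  # {'attn_impl': 'flash_attention_3'}

Note the `or 'flash_attn'` fallback: it preserves the previous default when attn_impl is missing or None, so existing flash_attn users see no behavior change.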
