
Commit 893aaa5

assert Flash Attention doesn't get arbitrary mask
1 parent c988cf2 commit 893aaa5

File tree

1 file changed (+4, -0 lines)


megatron/arguments.py

Lines changed: 4 additions & 0 deletions
@@ -378,6 +378,10 @@ def validate_args(args, defaults={}):
     if args.sequence_parallel:
         args.async_tensor_model_parallel_allreduce = False

+    if args.use_flash_attn:
+        assert not args.reset_attention_mask, \
+            "Flash Attention doesn't support arbitrary attention masks. Please turn off reset-attention-mask"
+
     _print_args(args)
     return args
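The check added by this commit can be exercised in isolation. Below is a minimal sketch, assuming an argparse-style namespace with `use_flash_attn` and `reset_attention_mask` attributes (both names come from the diff); the helper name `check_flash_attn_mask` and the surrounding `validate_args` logic are hypothetical stand-ins, not Megatron's actual code.

```python
from types import SimpleNamespace

def check_flash_attn_mask(args):
    # Mirrors the assertion added in validate_args: Flash Attention kernels
    # assume a causal (or otherwise fixed) mask, so a per-sample mask reset
    # produced by --reset-attention-mask is rejected up front.
    if args.use_flash_attn:
        assert not args.reset_attention_mask, \
            "Flash Attention doesn't support arbitrary attention masks. " \
            "Please turn off reset-attention-mask"

# Compatible combination: passes silently.
check_flash_attn_mask(
    SimpleNamespace(use_flash_attn=True, reset_attention_mask=False))

# Incompatible combination: fails fast at argument-validation time
# instead of producing wrong attention results later in training.
try:
    check_flash_attn_mask(
        SimpleNamespace(use_flash_attn=True, reset_attention_mask=True))
except AssertionError as e:
    print("rejected:", e)
```

Failing during argument validation is the design choice here: the incompatibility is caught once at startup rather than deep inside the attention forward pass.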
