 
 import deepspeed
 import torch
+from parameterized import parameterized
 from torch import nn
 import torch.nn.functional as F
 
+from megatron.enums import AttnMaskType
 from megatron.model.fused_layer_norm import MixedFusedLayerNorm
 from packaging import version
 
 from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
-from megatron.model.fused_softmax import ScaledMaskedSoftmax
+from megatron.model.fused_softmax import ScaledMaskedSoftmax, FusedScaleMaskSoftmax
+from megatron.model.utils import attention_mask_func
 from megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments, torch_assert_equal, \
     torch_assert_close, require_torch_bf16
 from megatron.training import setup_model_and_optimizer
@@ -366,7 +369,8 @@ def test_fused_layer_norm(self):
 
                 torch_assert_equal(mfln_output, torch_layer_norm_output)
 
-    def test_fused_masked_softmax(self):
+    @parameterized.expand([(attn_mask_type,) for attn_mask_type in AttnMaskType])
+    def test_fused_masked_softmax(self, attn_mask_type: AttnMaskType):
         command_args = get_default_args(self.test_file_dir_str)
 
         with patch('sys.argv', flatten_arguments(command_args)):
@@ -382,30 +386,54 @@ def test_fused_masked_softmax(self):
                     device="cuda",
                     dtype=args.params_dtype
                 )
-                dummy_attention_mask = torch.randn(
-                    args.micro_batch_size,
-                    1, # `args.num_attention_heads` not implemented in our cuda kernel
-                    args.seq_length,
-                    args.seq_length,
-                    device="cuda",
-                    dtype=args.params_dtype
-                ) < 0
+                if attn_mask_type == AttnMaskType.causal:
+                    dummy_attention_mask = None
+                else:
+                    dummy_attention_mask = torch.randn(
+                        args.micro_batch_size,
+                        1, # `args.num_attention_heads` not implemented in our cuda kernel
+                        args.seq_length,
+                        args.seq_length,
+                        device="cuda",
+                        dtype=args.params_dtype
+                    ) < 0
                 scale = torch.rand(())
 
-                fused_scaled_softmax = ScaledMaskedSoftmax
-
-                fused_output = fused_scaled_softmax.apply(dummy_input, dummy_attention_mask, scale)
+                fused_scaled_softmax = FusedScaleMaskSoftmax(
+                    input_in_fp16=args.params_dtype == torch.float16,
+                    input_in_bf16=args.params_dtype == torch.bfloat16,
+                    attn_mask_type=attn_mask_type,
+                    scaled_masked_softmax_fusion=True,
+                    mask_func=attention_mask_func,
+                    softmax_in_fp32=True,
+                    scale=scale,
+                )
+                unfused_scaled_softmax = FusedScaleMaskSoftmax(
+                    input_in_fp16=args.params_dtype == torch.float16,
+                    input_in_bf16=args.params_dtype == torch.bfloat16,
+                    attn_mask_type=attn_mask_type,
+                    scaled_masked_softmax_fusion=False,
+                    mask_func=attention_mask_func,
+                    softmax_in_fp32=True,
+                    scale=scale,
+                )
 
-                # mimick the same via torch
-                output = scale * dummy_input
-                output = output.masked_fill(dummy_attention_mask, torch.finfo(args.params_dtype).min)
-                output = F.softmax(output, dim=-1)
+                self.assertTrue(fused_scaled_softmax.is_kernel_available(dummy_attention_mask, *dummy_input.size()))
+                fused_output = fused_scaled_softmax(dummy_input, dummy_attention_mask)
+                self.assertFalse(unfused_scaled_softmax.is_kernel_available(dummy_attention_mask, *dummy_input.size()))
+                unfused_output = unfused_scaled_softmax(dummy_input, dummy_attention_mask)
 
                 # Test that the nonzeros are the same with the mask
                 for i in range(args.num_attention_heads):
-                    torch_assert_equal(torch.nonzero(fused_output[:, i]), torch.nonzero(~dummy_attention_mask[:, 0]))
+                    if dummy_attention_mask is None:
+                        # Make sure it's causal: values in the lower triangle should not be zero.
+                        non_zero_values = torch.tril(torch.ones_like(fused_output[:, i]))
+                        torch_assert_equal(torch.nonzero(fused_output[:, i]), torch.nonzero(non_zero_values))
+                    else:
+                        torch_assert_equal(torch.nonzero(fused_output[:, i]), torch.nonzero(~dummy_attention_mask[:, 0]))
+
                 # Cuda kernel produces slightly different results
-                torch_assert_close(fused_output, output)
+                torch_assert_close(fused_output, unfused_output)
 
 
     def test_non_causal_decoder_model_with_packed_input_passed_with_attention_mask_is_not_causal_across_segments(self):