@@ -777,6 +777,33 @@ def test_export_core_attention(
     validate_result(fname, inp, model, atol=1e-2)
 
 
+def set_mha_scales(module,
+                   scale_factor_qkv: list = [448, 448],
+                   scale_factor_query: list = [112, 112],
+                   scale_factor_kv: list = [224, 224],
+                   scale_factor_proj: list = [448, 448]
+                   ):
+    if module.attention_type == "self":
+        if module.input_layernorm:
+            # LayerNormLinear layer scale init
+            set_layer_scale(module.layernorm_qkv, scale_factor_qkv)
+        else:
+            # Linear layer scale init
+            set_layer_scale(module.qkv, scale_factor_qkv)
+    else:
+        if module.input_layernorm:
+            # LayerNormLinear layer scale init
+            set_layer_scale(module.layernorm_query, scale_factor_query)
+        else:
+            # Linear layer scale init
+            set_layer_scale(module.query_layer, scale_factor_query)
+
+        # Linear layer scale init
+        set_layer_scale(module.key_value, scale_factor_kv)
+
+    # Linear layer scale init
+    set_layer_scale(module.proj, scale_factor_proj)
+
 test_configs_multihead_attention = [
     #"use_mask, attn_mask_type"
     (False, "causal"),  # calls ScaledUpperTriangMaskedSoftmax
@@ -802,6 +829,10 @@ def test_export_core_attention(
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16])
 @pytest.mark.parametrize("return_layernorm_output", [False])
 @pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type)
+@pytest.mark.parametrize("scale_factor_qkv", [[448, 448]])
+@pytest.mark.parametrize("scale_factor_query", [[112, 112]])
+@pytest.mark.parametrize("scale_factor_kv", [[224, 224]])
+@pytest.mark.parametrize("scale_factor_proj", [[448, 448]])
 def test_export_multihead_attention(
     use_fp8: bool,
     use_mask: bool,
@@ -810,7 +841,11 @@ def test_export_multihead_attention(
     return_layernorm_output: bool,
     input_layernorm: bool,
     attention_type: str,
-    fuse_qkv_params: bool
+    fuse_qkv_params: bool,
+    scale_factor_qkv: list,
+    scale_factor_query: list,
+    scale_factor_kv: list,
+    scale_factor_proj: list,
 ):
     hidden_size = 256
     sequence_length = 128
@@ -851,21 +886,39 @@ def test_export_multihead_attention(
     input_ln_str = "_input-ln" if input_layernorm else ""
     fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"
 
-    model = te.transformer.MultiHeadAttention(
-        *attention_args,
-        attn_mask_type=attn_mask_type,
-        params_dtype=precision,
-        return_layernorm_output=return_layernorm_output,
-        input_layernorm=input_layernorm,
-        attention_type=attention_type,
-        fuse_qkv_params=fuse_qkv_params,
-    ).to(device='cuda')
-    do_export(model, inp, fname, use_fp8, input_names=input_names)
-    if not use_fp8:
-        validate_result(fname, inp, model, atol=1e-3)
-    elif precision != torch.float16:
-        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8)
+    with te.fp8_autocast(enabled=use_fp8, fp8_recipe=create_fp8_recipe()):
+        model = te.transformer.MultiHeadAttention(
+            *attention_args,
+            attn_mask_type=attn_mask_type,
+            params_dtype=precision,
+            return_layernorm_output=return_layernorm_output,
+            input_layernorm=input_layernorm,
+            attention_type=attention_type,
+            fuse_qkv_params=fuse_qkv_params,
+        ).to(device='cuda')
+        if use_fp8:
+            set_mha_scales(model,
+                           scale_factor_qkv,
+                           scale_factor_query,
+                           scale_factor_kv,
+                           scale_factor_proj)
 
+    do_export(model, inp, fname, use_fp8, input_names=input_names)
+    if not use_fp8:
+        validate_result(fname, inp, model, atol=1e-3)
+    elif precision != torch.float16:
+        validate_result(fname, inp, model, atol=5e-3, is_fp8=use_fp8)
+
+def set_transformer_layer_scales(module,
+                                 scales_self_attn: list,
+                                 scales_inter_attn: list,
+                                 scales_layernorm_mlp: list = [224, 224, 448, 448]):
+    # set mha scales
+    set_mha_scales(module.self_attention, *scales_self_attn)
+    if module.layer_type == "decoder":
+        set_mha_scales(module.inter_attention, *scales_inter_attn)
+    # set layernorm mlp scales
+    set_layer_scale(module.layernorm_mlp, scales_layernorm_mlp, num_gemms=2)
 
 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@@ -876,14 +929,24 @@ def test_export_multihead_attention(
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16])
 @pytest.mark.parametrize("fuse_qkv_params", [False, True])
 @pytest.mark.parametrize("apply_query_key_layer_scaling", [True, False])
+@pytest.mark.parametrize("scale_factor_qkv", [[448, 448]])
+@pytest.mark.parametrize("scale_factor_query", [[112, 112]])
+@pytest.mark.parametrize("scale_factor_kv", [[224, 224]])
+@pytest.mark.parametrize("scale_factor_proj", [[448, 448]])
+@pytest.mark.parametrize("scale_factor_layernorm_mlp", [[224, 224, 448, 448]])
 def test_export_transformer_layer(
     use_fp8: bool,
     use_mask: bool,
     attn_mask_type: str,
     output_layernorm: bool,
     precision: torch.dtype,
     fuse_qkv_params: bool,
-    apply_query_key_layer_scaling: bool
+    apply_query_key_layer_scaling: bool,
+    scale_factor_qkv: list,
+    scale_factor_query: list,
+    scale_factor_kv: list,
+    scale_factor_proj: list,
+    scale_factor_layernorm_mlp: list,
 ):
     # Layer configuration
     hidden_size = 64
@@ -909,17 +972,30 @@ def test_export_transformer_layer(
     attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
     fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{qk_scaling_str}{high_prec_str}.onnx"
 
-    model = te.TransformerLayer(
-        hidden_size,
-        ffn_hidden_size,
-        num_attention_heads,
-        self_attn_mask_type=attn_mask_type,
-        output_layernorm=output_layernorm,
-        params_dtype=precision,
-        fuse_qkv_params=fuse_qkv_params,
-        apply_query_key_layer_scaling=apply_query_key_layer_scaling).to(device='cuda')
-    do_export(model, inp, fname, use_fp8)
-    if not use_fp8:
-        validate_result(fname, inp, model, atol=1e-3)
-    elif precision != torch.float16:
-        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8)
+    with te.fp8_autocast(enabled=use_fp8, fp8_recipe=create_fp8_recipe()):
+        model = te.TransformerLayer(
+            hidden_size,
+            ffn_hidden_size,
+            num_attention_heads,
+            self_attn_mask_type=attn_mask_type,
+            output_layernorm=output_layernorm,
+            params_dtype=precision,
+            fuse_qkv_params=fuse_qkv_params,
+            apply_query_key_layer_scaling=apply_query_key_layer_scaling).to(device='cuda')
+        if use_fp8:
+            mha_scales = [
+                scale_factor_qkv,
+                scale_factor_query,
+                scale_factor_kv,
+                scale_factor_proj
+            ]
+            set_transformer_layer_scales(model,
+                                         scales_self_attn=mha_scales,
+                                         scales_inter_attn=mha_scales,
+                                         scales_layernorm_mlp=scale_factor_layernorm_mlp)
+
+    do_export(model, inp, fname, use_fp8)
+    if not use_fp8:
+        validate_result(fname, inp, model, atol=1e-3)
+    elif precision != torch.float16:
+        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8)
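
Reviewer note: to spot-check that the scale factors written by set_mha_scales() and set_transformer_layer_scales() actually reach an exported graph, the scale initializers feeding the quantize/dequantize nodes can be dumped with the onnx package. This is only an illustrative sketch, not part of the test file: the file name is a placeholder for one of the files produced by do_export(), and the op-type matching is an assumption about how the exporter names its Q/DQ ops.

import onnx
from onnx import numpy_helper

# Illustrative only: print the scale tensors that feed quantize/dequantize
# nodes of an exported model, to eyeball whether the overridden FP8 scales
# survived the export. The file name below is a placeholder.
model_proto = onnx.load("te.multihead_attention_fp8.onnx")
initializers = {init.name: init for init in model_proto.graph.initializer}
for node in model_proto.graph.node:
    if "QuantizeLinear" in node.op_type and len(node.input) > 1:
        scale_name = node.input[1]  # by ONNX convention the scale is the second input
        if scale_name in initializers:
            print(node.op_type, scale_name, numpy_helper.to_array(initializers[scale_name]))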