
Commit 486db6b

add scales to mha and transformer layer submodules
Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
1 parent 58920bb commit 486db6b

File tree: 1 file changed, +106 −30 lines

tests/test_onnx_export.py

Lines changed: 106 additions & 30 deletions
@@ -777,6 +777,33 @@ def test_export_core_attention(
     validate_result(fname, inp, model, atol=1e-2)
 
 
+def set_mha_scales(module,
+                   scale_factor_qkv: list=[448, 448],
+                   scale_factor_query: list=[112, 112],
+                   scale_factor_kv: list=[224, 224],
+                   scale_factor_proj: list=[448, 448]
+):
+    if module.attention_type == "self":
+        if module.input_layernorm:
+            # LayernormLinear layer scale init
+            set_layer_scale(module.layernorm_qkv, scale_factor_qkv)
+        else:
+            # Linear layer scale init
+            set_layer_scale(module.qkv, scale_factor_qkv)
+    else:
+        if module.input_layernorm:
+            # LayernormLinear layer scale init
+            set_layer_scale(module.layernorm_query, scale_factor_query)
+        else:
+            # Linear layer scale init
+            set_layer_scale(module.query_layer, scale_factor_query)
+
+        # Linear layer scale init
+        set_layer_scale(module.key_value, scale_factor_kv)
+
+    # Linear layer scale init
+    set_layer_scale(module.proj, scale_factor_proj)
+
 test_configs_multihead_attention = [
     #"use_mask, attn_mask_type"
     (False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
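Note: the `set_layer_scale` helper called above is not added by this diff, so it presumably already exists in tests/test_onnx_export.py. Purely for orientation, a minimal sketch of what such a helper could look like is given below, assuming it copies the supplied scale factors into the module's forward FP8 scaling metadata (`fp8_meta["scaling_fwd"]`); this is a hypothetical reconstruction, not the helper's actual body.

import torch

def set_layer_scale(module, scale_factors, num_gemms=1):
    # Hypothetical sketch of the existing helper: write the given FP8 scale
    # factors into the module's forward scaling metadata. Assumes two scaled
    # tensors (input and weight) per GEMM, and that fp8_meta has already been
    # initialized (e.g. the module was built or run under te.fp8_autocast()).
    assert len(scale_factors) == 2 * num_gemms
    fwd_meta = module.fp8_meta["scaling_fwd"]
    scales = torch.tensor(scale_factors, dtype=torch.float32,
                          device=fwd_meta.scale.device)
    fwd_meta.scale[:scales.numel()].copy_(scales)
    fwd_meta.scale_inv[:scales.numel()].copy_(1.0 / scales)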
@@ -802,6 +829,10 @@ def test_export_core_attention(
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16])
 @pytest.mark.parametrize("return_layernorm_output", [False])
 @pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type)
+@pytest.mark.parametrize("scale_factor_qkv", [[448, 448]])
+@pytest.mark.parametrize("scale_factor_query", [[112, 112]])
+@pytest.mark.parametrize("scale_factor_kv", [[224, 224]])
+@pytest.mark.parametrize("scale_factor_proj", [[448, 448]])
 def test_export_multihead_attention(
     use_fp8: bool,
     use_mask: bool,
@@ -810,7 +841,11 @@ def test_export_multihead_attention(
     return_layernorm_output: bool,
     input_layernorm: bool,
     attention_type: str,
-    fuse_qkv_params: bool
+    fuse_qkv_params: bool,
+    scale_factor_qkv: list,
+    scale_factor_query: list,
+    scale_factor_kv: list,
+    scale_factor_proj: list,
 ):
     hidden_size = 256
     sequence_length = 128
@@ -851,21 +886,39 @@ def test_export_multihead_attention(
     input_ln_str = "_input-ln" if input_layernorm else ""
     fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"
 
-    model = te.transformer.MultiHeadAttention(
-        *attention_args,
-        attn_mask_type=attn_mask_type,
-        params_dtype=precision,
-        return_layernorm_output=return_layernorm_output,
-        input_layernorm=input_layernorm,
-        attention_type=attention_type,
-        fuse_qkv_params=fuse_qkv_params,
-    ).to(device='cuda')
-    do_export(model, inp, fname, use_fp8, input_names=input_names)
-    if not use_fp8:
-        validate_result(fname, inp, model, atol=1e-3)
-    elif precision != torch.float16:
-        validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8)
+    with te.fp8_autocast(enabled=use_fp8, fp8_recipe=create_fp8_recipe()):
+        model = te.transformer.MultiHeadAttention(
+            *attention_args,
+            attn_mask_type=attn_mask_type,
+            params_dtype=precision,
+            return_layernorm_output=return_layernorm_output,
+            input_layernorm=input_layernorm,
+            attention_type=attention_type,
+            fuse_qkv_params=fuse_qkv_params,
+        ).to(device='cuda')
+        if use_fp8:
+            set_mha_scales(model,
+                           scale_factor_qkv,
+                           scale_factor_query,
+                           scale_factor_kv,
+                           scale_factor_proj)
 
+        do_export(model, inp, fname, use_fp8, input_names=input_names)
+        if not use_fp8:
+            validate_result(fname, inp, model, atol=1e-3)
+        elif precision != torch.float16:
+            validate_result(fname, inp, model, atol=5e-3, is_fp8=use_fp8)
+
+def set_transformer_layer_scales(module,
+                                 scales_self_attn: list,
+                                 scales_inter_attn: list,
+                                 scales_layernorm_mlp: list=[224, 224, 448, 448]):
+    # set mha scales
+    set_mha_scales(module.self_attention, *scales_self_attn)
+    if module.layer_type == "decoder":
+        set_mha_scales(module.inter_attention, *scales_inter_attn)
+    # set layernorm mlp scales
+    set_layer_scale(module.layernorm_mlp, scales_layernorm_mlp, num_gemms=2)
 
 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@@ -876,14 +929,24 @@ def test_export_multihead_attention(
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16])
 @pytest.mark.parametrize("fuse_qkv_params", [False, True])
 @pytest.mark.parametrize("apply_query_key_layer_scaling", [True, False])
+@pytest.mark.parametrize("scale_factor_qkv", [[448, 448]])
+@pytest.mark.parametrize("scale_factor_query", [[112, 112]])
+@pytest.mark.parametrize("scale_factor_kv", [[224, 224]])
+@pytest.mark.parametrize("scale_factor_proj", [[448, 448]])
+@pytest.mark.parametrize("scale_factor_layernorm_mlp", [[224, 224, 448, 448]])
 def test_export_transformer_layer(
     use_fp8: bool,
     use_mask: bool,
     attn_mask_type: str,
     output_layernorm: bool,
     precision: torch.dtype,
     fuse_qkv_params: bool,
-    apply_query_key_layer_scaling: bool
+    apply_query_key_layer_scaling: bool,
+    scale_factor_qkv: list,
+    scale_factor_query: list,
+    scale_factor_kv: list,
+    scale_factor_proj: list,
+    scale_factor_layernorm_mlp: list,
 ):
     # Layer configuration
     hidden_size = 64
@@ -909,17 +972,30 @@ def test_export_transformer_layer(
     attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
     fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{qk_scaling_str}{high_prec_str}.onnx"
 
-    model = te.TransformerLayer(
-        hidden_size,
-        ffn_hidden_size,
-        num_attention_heads,
-        self_attn_mask_type=attn_mask_type,
-        output_layernorm=output_layernorm,
-        params_dtype=precision,
-        fuse_qkv_params=fuse_qkv_params,
-        apply_query_key_layer_scaling=apply_query_key_layer_scaling).to(device='cuda')
-    do_export(model, inp, fname, use_fp8)
-    if not use_fp8:
-        validate_result(fname, inp, model, atol=1e-3)
-    elif precision != torch.float16:
-        validate_result(fname, inp, model, atol=5e-1, is_fp8=use_fp8)
+    with te.fp8_autocast(enabled=use_fp8, fp8_recipe=create_fp8_recipe()):
+        model = te.TransformerLayer(
+            hidden_size,
+            ffn_hidden_size,
+            num_attention_heads,
+            self_attn_mask_type=attn_mask_type,
+            output_layernorm=output_layernorm,
+            params_dtype=precision,
+            fuse_qkv_params=fuse_qkv_params,
+            apply_query_key_layer_scaling=apply_query_key_layer_scaling).to(device='cuda')
+        if use_fp8:
+            mha_scales = [
+                scale_factor_qkv,
+                scale_factor_query,
+                scale_factor_kv,
+                scale_factor_proj
+            ]
+            set_transformer_layer_scales(model,
+                                         scales_self_attn=mha_scales,
+                                         scales_inter_attn=mha_scales,
+                                         scales_layernorm_mlp=scale_factor_layernorm_mlp)
+
+        do_export(model, inp, fname, use_fp8)
+        if not use_fp8:
+            validate_result(fname, inp, model, atol=1e-3)
+        elif precision != torch.float16:
+            validate_result(fname, inp, model, atol=1e-2, is_fp8=use_fp8)
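For orientation, the pattern the updated tests now follow can be sketched outside pytest roughly as below. The module and helper names (te.TransformerLayer, te.fp8_autocast, create_fp8_recipe, set_transformer_layer_scales) and the scale values come from this diff and the surrounding test file; the layer sizes and dtype are illustrative only.

import torch
import transformer_engine.pytorch as te

# Illustrative only: build a TransformerLayer under fp8_autocast and seed its
# submodule scales with the same defaults the parametrized tests use.
# create_fp8_recipe() and set_transformer_layer_scales() are helpers from
# tests/test_onnx_export.py, not part of the te package itself.
hidden_size, ffn_hidden_size, num_attention_heads = 64, 256, 4
mha_scales = [[448, 448], [112, 112], [224, 224], [448, 448]]

with te.fp8_autocast(enabled=True, fp8_recipe=create_fp8_recipe()):
    model = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        params_dtype=torch.float32,
    ).to(device='cuda')
    set_transformer_layer_scales(model,
                                 scales_self_attn=mha_scales,
                                 scales_inter_attn=mha_scales,
                                 scales_layernorm_mlp=[224, 224, 448, 448])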
