 import onnxruntime as ort
 import torch
 from torch import nn as nn
-from typing import Union, Tuple
+from typing import Union, Tuple, List
 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 import transformer_engine_extensions as tex
@@ -80,17 +80,19 @@ def to_numpy(tensor):
     return tensor.cpu().numpy()


-def set_layer_scale(module: torch.nn.Module, scales: float, num_gemms: int = 1):
-    module.fp8_init(num_gemms=num_gemms)
-    assert len(scales) == num_gemms * 2, "Each gemm should be accompanied by 2 scales"
+def set_layer_scale(module: torch.nn.Module, scales: List[float]):
+    module.fp8_init()
     num_fp8_tensors = len(scales)
     scale = torch.ones(num_fp8_tensors, dtype=torch.float32, device="cuda")
     scale_inv = torch.ones(num_fp8_tensors, dtype=torch.float32, device="cuda")
+    amax_history_len = module.fp8_meta["recipe"].amax_history_len
+    amax_history = torch.zeros(amax_history_len, num_fp8_tensors, dtype=torch.float32, device="cuda")
     for i, s in enumerate(scales):
         scale[i] *= s
         scale_inv[i] /= s
     module.fp8_meta["scaling_fwd"].scale = scale
     module.fp8_meta["scaling_fwd"].scale_inv = scale_inv
+    module.fp8_meta["scaling_fwd"].amax_history = amax_history


 def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool):
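
Editorial note, not part of the change set: the reworked set_layer_scale drops num_gemms and sizes everything from len(scales) (the removed assert documented the old two-scales-per-GEMM contract), and it now also seeds amax_history with shape (amax_history_len, num_fp8_tensors) so the forward FP8 metadata is fully populated. A minimal usage sketch follows; the te.Linear construction is assumed for illustration and is not taken from this diff:

    # Sketch: pin FP8 forward scales on a single-GEMM te.Linear before export.
    # The [448, 448] pair mirrors the test parametrization below (presumably one
    # scale for the GEMM input and one for the weight).
    linear = te.Linear(64, 64, bias=True).to(device="cuda")
    set_layer_scale(linear, [448, 448])
    # A two-GEMM module such as te.LayerNormMLP now simply receives four scales,
    # e.g. [224, 224, 448, 448], instead of passing num_gemms=2.
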
@@ -546,7 +548,7 @@ def forward(self, inp, mask):
     validate_result(fname, inp, model, atol=1e-3)


-@pytest.mark.parametrize("scale_factor", [[448, 448]])
+@pytest.mark.parametrize("scale_factors", [[448, 448]])
 @pytest.mark.parametrize("use_fp8", [False, True])
 # Returning the bias is a TE fusion optimization we don't care about.
 @pytest.mark.parametrize("return_bias", [False])
@@ -562,7 +564,7 @@ def forward(self, inp, mask):
     # (torch.bfloat16, True),
 ])
 def test_export_linear(
-    scale_factor: list,
+    scale_factors: List[float],
     use_fp8: bool,
     use_bias: bool,
     return_bias: bool,
@@ -609,7 +611,7 @@ def forward(self, inp):
         precision
     ).to(device='cuda')
     if use_fp8:
-        set_layer_scale(model.linear, scale_factor)
+        set_layer_scale(model.linear, scale_factors)
     do_export(model, inp, fname, use_fp8)

     if precision in (torch.bfloat16, ):
@@ -620,7 +622,7 @@ def forward(self, inp):
         validate_result(fname, inp, model, atol=5e-4, is_fp8=use_fp8)


-@pytest.mark.parametrize("scale_factor", [[448, 448]])
+@pytest.mark.parametrize("scale_factors", [[448, 448]])
 @pytest.mark.parametrize("use_fp8", [False, True])
 # Returning the bias is a TE fusion optimization we don't care about.
 @pytest.mark.parametrize("return_bias", [False])
@@ -633,7 +635,7 @@ def forward(self, inp):
     (torch.float16, False),
 ])
 def test_export_layernorm_linear(
-    scale_factor: list,
+    scale_factors: List[float],
     use_fp8: bool,
     use_bias: bool,
     return_bias: bool,
@@ -660,15 +662,15 @@ def test_export_layernorm_linear(
         params_dtype=precision,
     ).to(device='cuda')
     if use_fp8:
-        set_layer_scale(model, scale_factor)
+        set_layer_scale(model, scale_factors)
     do_export(model, inp, fname, use_fp8)
     if not use_fp8:
         validate_result(fname, inp, model, atol=1e-3)
     elif precision not in (torch.bfloat16,):
         validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8)


-@pytest.mark.parametrize("scale_factor", [[224, 224, 448, 448]])
+@pytest.mark.parametrize("scale_factors", [[224, 224, 448, 448]])
 @pytest.mark.parametrize("use_fp8", [False, True])
 # Returning the bias is a TE fusion optimization we don't care about.
 @pytest.mark.parametrize("return_bias", [False])
@@ -681,7 +683,7 @@ def test_export_layernorm_linear(
     (torch.float16, False),
 ])
 def test_export_layernorm_mlp(
-    scale_factor: list,
+    scale_factors: List[float],
     use_fp8: bool,
     use_bias: bool,
     return_bias: bool,
@@ -709,7 +711,7 @@ def test_export_layernorm_mlp(
         params_dtype=precision,
     ).to(device='cuda')
     if use_fp8:
-        set_layer_scale(model, scale_factor, num_gemms=2)
+        set_layer_scale(model, scale_factors)
     do_export(model, inp, fname, use_fp8)
     if not use_fp8:
         validate_result(fname, inp, model, atol=5e-4)
@@ -778,10 +780,10 @@ def test_export_core_attention(


 def set_mha_scales(module,
-    scale_factor_qkv: list = [448, 448],
-    scale_factor_query: list = [112, 112],
-    scale_factor_kv: list = [224, 224],
-    scale_factor_proj: list = [448, 448]
+    scale_factor_qkv: List[float] = [448, 448],
+    scale_factor_query: List[float] = [112, 112],
+    scale_factor_kv: List[float] = [224, 224],
+    scale_factor_proj: List[float] = [448, 448]
 ):
     if module.attention_type == "self":
         if module.input_layernorm:
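
Editorial note, not part of the change set: judging by the parameter names and the branches on attention_type and input_layernorm, set_mha_scales fans these per-projection scale pairs out to the attention sub-modules (the transformer-layer helper below calls it as set_mha_scales(module.inter_attention, *scales_inter_attn)). A hedged calling sketch, where mha stands for an already-constructed TE multi-head attention module (its construction is assumed, not shown in this diff):

    # Sketch only; `mha` is assumed to be a TE attention module built elsewhere.
    # The values mirror the defaults introduced above.
    set_mha_scales(
        mha,
        scale_factor_qkv=[448, 448],    # fused QKV projection (self-attention path)
        scale_factor_query=[112, 112],  # query projection (cross-attention path)
        scale_factor_kv=[224, 224],     # key/value projection (cross-attention path)
        scale_factor_proj=[448, 448],   # output projection
    )
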
@@ -842,10 +844,10 @@ def test_export_multihead_attention(
     input_layernorm: bool,
     attention_type: str,
     fuse_qkv_params: bool,
-    scale_factor_qkv: list,
-    scale_factor_query: list,
-    scale_factor_kv: list,
-    scale_factor_proj: list,
+    scale_factor_qkv: List[float],
+    scale_factor_query: List[float],
+    scale_factor_kv: List[float],
+    scale_factor_proj: List[float],
 ):
     hidden_size = 256
     sequence_length = 128
@@ -918,7 +920,7 @@ def set_transformer_layer_scales(module,
     if module.layer_type == "decoder":
         set_mha_scales(module.inter_attention, *scales_inter_attn)
     # set layernorm mlp scales
-    set_layer_scale(module.layernorm_mlp, scales_layernorm_mlp, num_gemms=2)
+    set_layer_scale(module.layernorm_mlp, scales_layernorm_mlp)

 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@@ -942,11 +944,11 @@ def test_export_transformer_layer(
     precision: torch.dtype,
     fuse_qkv_params: bool,
     apply_query_key_layer_scaling: bool,
-    scale_factor_qkv: list,
-    scale_factor_query: list,
-    scale_factor_kv: list,
-    scale_factor_proj: list,
-    scale_factor_layernorm_mlp: list,
+    scale_factor_qkv: List[float],
+    scale_factor_query: List[float],
+    scale_factor_kv: List[float],
+    scale_factor_proj: List[float],
+    scale_factor_layernorm_mlp: List[float],
 ):
     # Layer configuration
     hidden_size = 64