Commit e83246e

address comments
1. Replace variable scale_factor with scale_factors.
2. Update type hints for scale_factors to be List[float].
3. Remove use of the num_gemms param and add an amax_history assignment.

Signed-off-by: Asfiya Baig <asfiyab@nvidia.com>
1 parent 486db6b commit e83246e
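
For quick reference, the net effect on callers (a minimal before/after sketch; the module and scale values are taken from the test parametrizations in the diff below):

# before: callers had to state the GEMM count explicitly
set_layer_scale(model.layernorm_mlp, [224, 224, 448, 448], num_gemms=2)

# after: the number of FP8 tensors is inferred from len(scales),
# and amax_history is initialized alongside scale/scale_inv
set_layer_scale(model.layernorm_mlp, [224, 224, 448, 448])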

1 file changed: tests/test_onnx_export.py (29 additions, 27 deletions)
@@ -14,7 +14,7 @@
 import onnxruntime as ort
 import torch
 from torch import nn as nn
-from typing import Union, Tuple
+from typing import Union, Tuple, List
 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 import transformer_engine_extensions as tex
@@ -80,17 +80,19 @@ def to_numpy(tensor):
     return tensor.cpu().numpy()
 
 
-def set_layer_scale(module: torch.nn.Module, scales: float, num_gemms: int=1):
-    module.fp8_init(num_gemms=num_gemms)
-    assert len(scales) == num_gemms * 2, "Each gemm should be accompanied by 2 scales"
+def set_layer_scale(module: torch.nn.Module, scales: List[float]):
+    module.fp8_init()
     num_fp8_tensors = len(scales)
     scale = torch.ones(num_fp8_tensors, dtype=torch.float32, device="cuda")
     scale_inv = torch.ones(num_fp8_tensors, dtype=torch.float32, device="cuda")
+    amax_history_len = module.fp8_meta["recipe"].amax_history_len
+    amax_history = torch.zeros(amax_history_len, num_fp8_tensors, dtype=torch.float32, device="cuda")
     for i, s in enumerate(scales):
         scale[i] *= s
         scale_inv[i] /= s
     module.fp8_meta["scaling_fwd"].scale = scale
     module.fp8_meta["scaling_fwd"].scale_inv = scale_inv
+    module.fp8_meta["scaling_fwd"].amax_history = amax_history
 
 
 def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool):
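
A minimal usage sketch of the updated helper, assuming a CUDA device, that set_layer_scale as defined above is in scope, and that te.Linear exposes the fp8_meta layout the helper writes to (the layer shape and scale values here are illustrative only):

import torch
import transformer_engine.pytorch as te

layer = te.Linear(64, 64).to(device="cuda")
set_layer_scale(layer, [448.0, 448.0])   # one GEMM -> two scales

fwd = layer.fp8_meta["scaling_fwd"]
assert fwd.scale.shape == (2,)           # one entry per FP8 tensor
assert torch.allclose(fwd.scale * fwd.scale_inv, torch.ones(2, device="cuda"))
# amax_history is pre-sized to [recipe.amax_history_len, num_fp8_tensors]
assert fwd.amax_history.shape == (layer.fp8_meta["recipe"].amax_history_len, 2)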
@@ -546,7 +548,7 @@ def forward(self, inp, mask):
     validate_result(fname, inp, model, atol=1e-3)
 
 
-@pytest.mark.parametrize("scale_factor", [[448, 448]])
+@pytest.mark.parametrize("scale_factors", [[448, 448]])
 @pytest.mark.parametrize("use_fp8", [False, True])
 # Returning the bias is a TE fusion optimization we don't care about.
 @pytest.mark.parametrize("return_bias", [False])
@@ -562,7 +564,7 @@ def forward(self, inp, mask):
     # (torch.bfloat16, True),
 ])
 def test_export_linear(
-    scale_factor: list,
+    scale_factors: List[float],
     use_fp8: bool,
     use_bias: bool,
     return_bias: bool,
@@ -609,7 +611,7 @@ def forward(self, inp):
         precision
     ).to(device='cuda')
     if use_fp8:
-        set_layer_scale(model.linear, scale_factor)
+        set_layer_scale(model.linear, scale_factors)
     do_export(model, inp, fname, use_fp8)
 
     if precision in (torch.bfloat16, ):
@@ -620,7 +622,7 @@ def forward(self, inp):
     validate_result(fname, inp, model, atol=5e-4, is_fp8=use_fp8)
 
 
-@pytest.mark.parametrize("scale_factor", [[448, 448]])
+@pytest.mark.parametrize("scale_factors", [[448, 448]])
 @pytest.mark.parametrize("use_fp8", [False, True])
 # Returning the bias is a TE fusion optimization we don't care about.
 @pytest.mark.parametrize("return_bias", [False])
@@ -633,7 +635,7 @@ def forward(self, inp):
     (torch.float16, False),
 ])
 def test_export_layernorm_linear(
-    scale_factor: list,
+    scale_factors: List[float],
    use_fp8: bool,
     use_bias: bool,
     return_bias: bool,
@@ -660,15 +662,15 @@ def test_export_layernorm_linear(
         params_dtype=precision,
     ).to(device='cuda')
     if use_fp8:
-        set_layer_scale(model, scale_factor)
+        set_layer_scale(model, scale_factors)
     do_export(model, inp, fname, use_fp8)
     if not use_fp8:
         validate_result(fname, inp, model, atol=1e-3)
     elif precision not in (torch.bfloat16,):
         validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8)
 
 
-@pytest.mark.parametrize("scale_factor", [[224, 224, 448, 448]])
+@pytest.mark.parametrize("scale_factors", [[224, 224, 448, 448]])
 @pytest.mark.parametrize("use_fp8", [False, True])
 # Returning the bias is a TE fusion optimization we don't care about.
 @pytest.mark.parametrize("return_bias", [False])
@@ -681,7 +683,7 @@ def test_export_layernorm_linear(
     (torch.float16, False),
 ])
 def test_export_layernorm_mlp(
-    scale_factor: list,
+    scale_factors: List[float],
     use_fp8: bool,
     use_bias: bool,
     return_bias: bool,
@@ -709,7 +711,7 @@ def test_export_layernorm_mlp(
         params_dtype=precision,
     ).to(device='cuda')
     if use_fp8:
-        set_layer_scale(model, scale_factor, num_gemms=2)
+        set_layer_scale(model, scale_factors)
     do_export(model, inp, fname, use_fp8)
     if not use_fp8:
         validate_result(fname, inp, model, atol=5e-4)
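
With num_gemms gone, the two-GEMM LayerNormMLP case is carried entirely by the length of the parametrized list; a short sketch of the implied convention (values copied from the parametrizations above, two scales per GEMM as the removed assertion stated):

# Linear / LayerNormLinear: 1 GEMM  -> 2 scales, e.g. [448, 448]
# LayerNormMLP:             2 GEMMs -> 4 scales, e.g. [224, 224, 448, 448]
scale_factors = [224, 224, 448, 448]
num_fp8_tensors = len(scale_factors)   # replaces the old `num_gemms * 2` check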
@@ -778,10 +780,10 @@ def test_export_core_attention(
 
 
 def set_mha_scales(module,
-                   scale_factor_qkv: list=[448, 448],
-                   scale_factor_query: list=[112, 112],
-                   scale_factor_kv: list=[224, 224],
-                   scale_factor_proj: list=[448, 448]
+                   scale_factor_qkv: List[float]=[448, 448],
+                   scale_factor_query: List[float]=[112, 112],
+                   scale_factor_kv: List[float]=[224, 224],
+                   scale_factor_proj: List[float]=[448, 448]
 ):
     if module.attention_type == "self":
         if module.input_layernorm:
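
For the attention helpers, a hypothetical call with the default values shown above (assuming `mha` is the attention module under test, exposing the attention_type and input_layernorm attributes this helper inspects):

set_mha_scales(
    mha,
    scale_factor_qkv=[448, 448],      # fused QKV projection path
    scale_factor_query=[112, 112],    # separate query projection path
    scale_factor_kv=[224, 224],       # separate key/value projection path
    scale_factor_proj=[448, 448],     # output projection
)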
@@ -842,10 +844,10 @@ def test_export_multihead_attention(
     input_layernorm: bool,
     attention_type: str,
     fuse_qkv_params: bool,
-    scale_factor_qkv: list,
-    scale_factor_query: list,
-    scale_factor_kv: list,
-    scale_factor_proj: list,
+    scale_factor_qkv: List[float],
+    scale_factor_query: List[float],
+    scale_factor_kv: List[float],
+    scale_factor_proj: List[float],
 ):
     hidden_size = 256
     sequence_length = 128
@@ -918,7 +920,7 @@ def set_transformer_layer_scales(module,
     if module.layer_type == "decoder":
         set_mha_scales(module.inter_attention, *scales_inter_attn)
     # set layernorm mlp scales
-    set_layer_scale(module.layernorm_mlp, scales_layernorm_mlp, num_gemms=2)
+    set_layer_scale(module.layernorm_mlp, scales_layernorm_mlp)
 
 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@@ -942,11 +944,11 @@ def test_export_transformer_layer(
     precision: torch.dtype,
     fuse_qkv_params: bool,
     apply_query_key_layer_scaling: bool,
-    scale_factor_qkv: list,
-    scale_factor_query: list,
-    scale_factor_kv: list,
-    scale_factor_proj: list,
-    scale_factor_layernorm_mlp: list,
+    scale_factor_qkv: List[float],
+    scale_factor_query: List[float],
+    scale_factor_kv: List[float],
+    scale_factor_proj: List[float],
+    scale_factor_layernorm_mlp: List[float],
 ):
     # Layer configuration
     hidden_size = 64
