1
1
import argparse
2
+
2
3
import logging
3
4
4
5
from typing import Any , Callable , List , Optional
7
8
import torch ._inductor .config as inductor_config
8
9
import triton
9
10
11
+ from torch ._inductor .kernel .mm import ScalingMode
12
+
10
13
from tritonbench .operators .fp8_gemm .persistent import blackwell_persistent_tma
11
14
from tritonbench .utils .env_utils import get_nvidia_gpu_model , is_cuda
12
15
46
49
def parse_args (args ):
47
50
parser = argparse .ArgumentParser (description = "TritonBench fp8_gemm" )
48
51
parser .add_argument ("--llama" , action = "store_true" )
49
- parser .add_argument ("--scaling_rowwise " , action = "store_true " )
52
+ parser .add_argument ("--scaling-mode " , type = str , default = "tensor " )
50
53
parser .add_argument ("--m" , type = int )
51
54
parser .add_argument ("--k" , type = int )
52
55
parser .add_argument ("--n" , type = int )
@@ -55,6 +58,15 @@ def parse_args(args):
55
58
return parser .parse_args (args )
56
59
57
60
61
+ def get_scaling_mode_int (scaling_mode : str ) -> int :
62
+ if scaling_mode == "tensor" :
63
+ return ScalingMode .TENSOR
64
+ elif scaling_mode == "row" :
65
+ return ScalingMode .ROW
66
+ else :
67
+ raise ValueError (f"Invalid scaling mode: { scaling_mode } " )
68
+
69
+
58
70
class Operator (BenchmarkOperator ):
59
71
DEFAULT_METRICS = ["tflops" , "gbps" , "latency" ]
60
72
DEFAULT_PRECISION = "fp8"
@@ -66,11 +78,12 @@ def __init__(
66
78
super ().__init__ (tb_args , extra_args )
67
79
self .extra_args = parse_args (extra_args )
68
80
81
+ self .scaling_mode_int = get_scaling_mode_int (self .extra_args .scaling_mode ).value
82
+
69
83
def _get_dtype (self ):
70
- if self .extra_args .scaling_rowwise :
71
- return torch .bfloat16
72
- else :
84
+ if self .scaling_mode_int == ScalingMode .TENSOR :
73
85
return torch .float16
86
+ return torch .bfloat16
74
87
75
88
def get_input_iter (self ):
76
89
def _get_scale_per_tensor (
@@ -103,10 +116,10 @@ def args(m, n, k):
103
116
a = torch .randn (m , k , device = self .device ).to (self ._get_dtype ())
104
117
b = torch .randn (n , k , device = self .device ).to (self ._get_dtype ())
105
118
106
- if self .extra_args . scaling_rowwise :
119
+ if self .scaling_mode_int == ScalingMode . ROW :
107
120
scale_a = _get_scale_per_row (a )
108
121
scale_b = _get_scale_per_row (b )
109
- else :
122
+ else : # self.scaling_mode_int == ScalingMode.TENSOR
110
123
scale_a = _get_scale_per_tensor (
111
124
a , custom_scale = self .extra_args .per_tensor_scale_a
112
125
)
@@ -192,7 +205,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
192
205
scale_a ,
193
206
scale_b ,
194
207
self ._get_dtype (),
195
- self .extra_args . scaling_rowwise ,
208
+ self .scaling_mode_int ,
196
209
)
197
210
198
211
@register_benchmark (enabled = True )
0 commit comments