1
1
import argparse
2
+
2
3
import logging
3
4
4
5
from typing import Any , Callable , List , Optional
7
8
import torch ._inductor .config as inductor_config
8
9
import triton
9
10
11
+ from torch ._inductor .kernel .mm import ScalingMode
12
+
10
13
from tritonbench .operators .fp8_gemm .persistent import blackwell_persistent_tma
11
14
from tritonbench .utils .env_utils import get_nvidia_gpu_model , is_cuda
12
15
46
49
def parse_args (args ):
47
50
parser = argparse .ArgumentParser (description = "TritonBench fp8_gemm" )
48
51
parser .add_argument ("--llama" , action = "store_true" )
49
- parser .add_argument ("--scaling_rowwise " , action = "store_true " )
52
+ parser .add_argument ("--scaling-mode " , type = str , default = "tensor " )
50
53
parser .add_argument ("--m" , type = int )
51
54
parser .add_argument ("--k" , type = int )
52
55
parser .add_argument ("--n" , type = int )
@@ -55,6 +58,15 @@ def parse_args(args):
55
58
return parser .parse_args (args )
56
59
57
60
61
def get_scaling_mode_int(scaling_mode: str) -> "ScalingMode":
    """Map a CLI ``--scaling-mode`` string to its inductor ``ScalingMode`` member.

    Note: despite the function name, this returns the enum *member*, not its
    integer value — callers take ``.value`` themselves.  (The previous ``-> int``
    annotation was wrong.)

    Args:
        scaling_mode: either ``"tensor"`` (per-tensor scaling) or ``"row"``
            (rowwise scaling).

    Returns:
        The matching ``ScalingMode`` enum member.

    Raises:
        ValueError: if ``scaling_mode`` is not one of the recognized names.
    """
    if scaling_mode == "tensor":
        return ScalingMode.TENSOR
    if scaling_mode == "row":
        return ScalingMode.ROW
    raise ValueError(f"Invalid scaling mode: {scaling_mode}")
68
+
69
+
58
70
class Operator (BenchmarkOperator ):
59
71
DEFAULT_METRICS = ["tflops" , "gbps" , "latency" ]
60
72
DEFAULT_PRECISION = "fp8"
@@ -65,11 +77,12 @@ def __init__(
65
77
super ().__init__ (tb_args , extra_args )
66
78
self .extra_args = parse_args (extra_args )
67
79
80
+ self .scaling_mode_int = get_scaling_mode_int (self .extra_args .scaling_mode ).value
81
+
68
82
def _get_dtype(self):
    """Return the working dtype implied by the configured scaling mode.

    Per-tensor scaling benchmarks in float16; any other mode (rowwise)
    uses bfloat16.
    """
    # self.scaling_mode_int holds ScalingMode.<MODE>.value (an int, set in
    # __init__).  Compare against the enum's .value rather than the enum
    # member itself: if ScalingMode is a plain Enum (not IntEnum), an
    # int == EnumMember comparison is always False, which would silently
    # make every mode fall through to bfloat16.
    if self.scaling_mode_int == ScalingMode.TENSOR.value:
        return torch.float16
    return torch.bfloat16
73
86
74
87
def get_input_iter (self ):
75
88
def _get_scale_per_tensor (
@@ -102,10 +115,10 @@ def args(m, n, k):
102
115
a = torch .randn (m , k , device = self .device ).to (self ._get_dtype ())
103
116
b = torch .randn (n , k , device = self .device ).to (self ._get_dtype ())
104
117
105
- if self .extra_args . scaling_rowwise :
118
+ if self .scaling_mode_int == ScalingMode . ROW :
106
119
scale_a = _get_scale_per_row (a )
107
120
scale_b = _get_scale_per_row (b )
108
- else :
121
+ else : # self.scaling_mode_int == ScalingMode.TENSOR
109
122
scale_a = _get_scale_per_tensor (
110
123
a , custom_scale = self .extra_args .per_tensor_scale_a
111
124
)
@@ -191,7 +204,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
191
204
scale_a ,
192
205
scale_b ,
193
206
self ._get_dtype (),
194
- self .extra_args . scaling_rowwise ,
207
+ self .scaling_mode_int ,
195
208
)
196
209
197
210
@register_benchmark (enabled = True )
0 commit comments