@@ -1,12 +1,15 @@
 import argparse
+
 import logging

 from typing import Any, Callable, List, Optional

 import torch
 import torch._inductor.config as inductor_config
 import triton

+from torch._inductor.kernel.mm import scaling_pairs, ScalingType
+
 from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
 from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda

@@ -46,7 +49,7 @@
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
-    parser.add_argument("--scaling_rowwise", action="store_true")
+    parser.add_argument("--scaling-pair", type=str, default="TensorWise,TensorWise")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
@@ -55,6 +58,54 @@ def parse_args(args):
     return parser.parse_args(args)


+def get_scaling_recipe_int(scaling_recipe: str) -> ScalingType:
+    if scaling_recipe == "TensorWise":
+        return ScalingType.TensorWise
+    elif scaling_recipe == "RowWise":
+        return ScalingType.RowWise
+    else:
+        raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")
+
+
+def get_scale(
+    x: torch.Tensor,
+    scaling_recipe_int: int,
+    transpose: bool = False,
+    custom_scale: Optional[float] = None,
+):
+    def _get_scale_per_tensor(
+        x: torch.Tensor, custom_scale: Optional[float] = None
+    ) -> torch.Tensor:
+        # For tensor-wise scaling, kernel requires a float32 scale tensor
+        if custom_scale:
+            return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+        scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+        return scale.to(torch.float32)
+
+    def _get_scale_per_row(x: torch.Tensor, transpose: bool = False) -> torch.Tensor:
+        if transpose:  # scale_b.shape should be [1, N]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=0, keepdim=True).values
+            )
+        else:  # scale_a.shape should be [M, 1]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=1, keepdim=True).values
+            )
+        return scale.to(
+            torch.float32
+        )  # For row-wise scaling, kernel requires a float32 scale tensor
+
+    match scaling_recipe_int:
+        case ScalingType.TensorWise:
+            return _get_scale_per_tensor(x, custom_scale=custom_scale)
+        case ScalingType.RowWise:
+            return _get_scale_per_row(x, transpose=transpose)
+        case _:
+            raise AssertionError(f"Unsupported scaling type {scaling_recipe_int}")
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
@@ -66,53 +117,39 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)

+        scaling_recipe_a, scaling_recipe_b = self.extra_args.scaling_pair.split(",")
+        if (scaling_recipe_a, scaling_recipe_b) not in [
+            (a.name, b.name) for a, b in scaling_pairs
+        ]:
+            raise ValueError(
+                f"Invalid scaling pair: {scaling_recipe_a}, {scaling_recipe_b}. See torch/_inductor/kernel/mm.py::scaling_pairs for valid pairs."
+            )
+        self.scaling_recipe_a_int = get_scaling_recipe_int(scaling_recipe_a).value
+        self.scaling_recipe_b_int = get_scaling_recipe_int(scaling_recipe_b).value
+
     def _get_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
+        if (
+            self.scaling_recipe_a_int == ScalingType.TensorWise
+            and self.scaling_recipe_b_int == ScalingType.TensorWise
+        ):
             return torch.float16
+        return torch.bfloat16

     def get_input_iter(self):
-        def _get_scale_per_tensor(
-            x: torch.Tensor, custom_scale: float = None
-        ) -> torch.Tensor:
-            # For tensor-wise scaling, kernel requires a float32 scale tensor
-            if custom_scale:
-                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
-            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
-            return scale.to(torch.float32)
-
-        def _get_scale_per_row(
-            x: torch.Tensor, transpose: bool = False
-        ) -> torch.Tensor:
-            if transpose:  # scale_b.shape should be [1, N]
-                scale = (
-                    torch.finfo(torch.float8_e4m3fn).max
-                    / x.abs().max(dim=0, keepdim=True).values
-                )
-            else:  # scale_a.shape should be [M, 1]
-                scale = (
-                    torch.finfo(torch.float8_e4m3fn).max
-                    / x.abs().max(dim=1, keepdim=True).values
-                )
-            return scale.to(
-                torch.float32
-            )  # For row-wise scaling, kernel requires a float32 scale tensor
-
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())

-            if self.extra_args.scaling_rowwise:
-                scale_a = _get_scale_per_row(a)
-                scale_b = _get_scale_per_row(b)
-            else:
-                scale_a = _get_scale_per_tensor(
-                    a, custom_scale=self.extra_args.per_tensor_scale_a
-                )
-                scale_b = _get_scale_per_tensor(
-                    b, custom_scale=self.extra_args.per_tensor_scale_b
-                )
+            scale_a = get_scale(
+                a,
+                self.scaling_recipe_a_int,
+                custom_scale=self.extra_args.per_tensor_scale_a,
+            )
+            scale_b = get_scale(
+                b,
+                self.scaling_recipe_b_int,
+                custom_scale=self.extra_args.per_tensor_scale_b,
+            )

             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
@@ -192,7 +229,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
             scale_a,
             scale_b,
             self._get_dtype(),
-            self.extra_args.scaling_rowwise,
+            0 if self.scaling_recipe_a_int == self.scaling_recipe_b_int == 0 else 1,  # rowwise flag: 0 when both recipes are tensor-wise, else 1
         )

     @register_benchmark(enabled=True)
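
For reference, a minimal standalone sketch (not part of this diff) of the two scaling recipes the new get_scale helper implements; the operand shape is hypothetical and chosen only for illustration:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for float8_e4m3fn

x = torch.randn(64, 32)  # a hypothetical [M, K] operand

# TensorWise: a single float32 scale for the whole tensor.
tensorwise_scale = (FP8_MAX / x.abs().max()).to(torch.float32)
print(tensorwise_scale.shape)  # torch.Size([])

# RowWise: one float32 scale per row, shape [M, 1]; get_scale(..., transpose=True)
# instead reduces over dim=0 and yields a [1, K]-shaped scale.
rowwise_scale = (FP8_MAX / x.abs().max(dim=1, keepdim=True).values).to(torch.float32)
print(rowwise_scale.shape)  # torch.Size([64, 1])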
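
And a small sketch of how the new --scaling-pair flag round-trips through argparse (standalone; the real parser is the one defined in parse_args above):

import argparse

parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
parser.add_argument("--scaling-pair", type=str, default="TensorWise,TensorWise")

args = parser.parse_args(["--scaling-pair", "RowWise,RowWise"])
# argparse maps the hyphenated flag to the scaling_pair attribute,
# which __init__ then splits into the two per-operand recipes.
recipe_a, recipe_b = args.scaling_pair.split(",")
print(recipe_a, recipe_b)  # RowWise RowWise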