@@ -1,4 +1,5 @@
 import argparse
+
 import logging

 from typing import Any, Callable, List, Optional
@@ -7,6 +8,8 @@
 import torch._inductor.config as inductor_config
 import triton

+from torch._inductor.kernel.mm import scaling_pairs, ScalingType
+
 from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
 from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda

@@ -46,7 +49,7 @@
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
-    parser.add_argument("--scaling_rowwise", action="store_true")
+    parser.add_argument("--scaling-pair", type=str, default="TensorWise,TensorWise")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
@@ -55,6 +58,58 @@ def parse_args(args):
     return parser.parse_args(args)


+def get_scaling_recipe_int(scaling_recipe: str) -> int:
+    if scaling_recipe == "TensorWise":
+        return ScalingType.TensorWise
+    elif scaling_recipe == "RowWise":
+        return ScalingType.RowWise
+    else:
+        raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")
+
+
+def get_scale(
+    x: torch.Tensor,
+    scaling_recipe_int: int,
+    transpose: bool = False,
+    custom_scale: float = None,
+) -> (torch.Tensor, torch.Tensor):
+    def _get_scale_per_tensor(
+        x: torch.Tensor, custom_scale: float = None
+    ) -> (torch.Tensor, torch.Tensor):
+        # For tensor-wise scaling, kernel requires a float32 scale tensor
+        if custom_scale:
+            return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+        scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+        x *= scale
+        return x, scale.to(torch.float32)
+
+    def _get_scale_per_row(
+        x: torch.Tensor, transpose: bool = False
+    ) -> (torch.Tensor, torch.Tensor):
+        if transpose:  # scale_b.shape should be [1, N]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=0, keepdim=True).values
+            )
+        else:  # scale_a.shape should be [M, 1]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=1, keepdim=True).values
+            )
+        x = x.mul(scale)
+        return x, scale.to(
+            torch.float32
+        )  # For row-wise scaling, kernel requires a float32 scale tensor
+
+    match scaling_recipe_int:
+        case ScalingType.TensorWise:
+            return _get_scale_per_tensor(x, custom_scale=custom_scale)
+        case ScalingType.RowWise:
+            return _get_scale_per_row(x, transpose=transpose)
+        case _:
+            raise AssertionError(f"Unsupported scaling type {scaling_recipe_int}")
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
@@ -66,53 +121,39 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)

+        scaling_recipe_a, scaling_recipe_b = self.extra_args.scaling_pair.split(",")
+        if (scaling_recipe_a, scaling_recipe_b) not in [
+            (a.name, b.name) for a, b in scaling_pairs
+        ]:
+            raise ValueError(
+                f"Invalid scaling pair: {scaling_recipe_a}, {scaling_recipe_b}. See torch/_inductor/kernel/mm.py::scaling_pairs for valid pairs."
+            )
+        self.scaling_recipe_a_int = get_scaling_recipe_int(scaling_recipe_a).value
+        self.scaling_recipe_b_int = get_scaling_recipe_int(scaling_recipe_b).value
+
     def _get_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
+        if (
+            self.scaling_recipe_a_int == ScalingType.TensorWise
+            and self.scaling_recipe_b_int == ScalingType.TensorWise
+        ):
             return torch.float16
+        return torch.bfloat16

     def get_input_iter(self):
-        def _get_scale_per_tensor(
-            x: torch.Tensor, custom_scale: float = None
-        ) -> torch.Tensor:
-            # For tensor-wise scaling, kernel requires a float32 scale tensor
-            if custom_scale:
-                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
-            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
-            return scale.to(torch.float32)
-
-        def _get_scale_per_row(
-            x: torch.Tensor, transpose: bool = False
-        ) -> torch.Tensor:
-            if transpose:  # scale_b.shape should be [1, N]
-                scale = (
-                    torch.finfo(torch.float8_e4m3fn).max
-                    / x.abs().max(dim=0, keepdim=True).values
-                )
-            else:  # scale_a.shape should be [M, 1]
-                scale = (
-                    torch.finfo(torch.float8_e4m3fn).max
-                    / x.abs().max(dim=1, keepdim=True).values
-                )
-            return scale.to(
-                torch.float32
-            )  # For row-wise scaling, kernel requires a float32 scale tensor
-
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())

-            if self.extra_args.scaling_rowwise:
-                scale_a = _get_scale_per_row(a)
-                scale_b = _get_scale_per_row(b)
-            else:
-                scale_a = _get_scale_per_tensor(
-                    a, custom_scale=self.extra_args.per_tensor_scale_a
-                )
-                scale_b = _get_scale_per_tensor(
-                    b, custom_scale=self.extra_args.per_tensor_scale_b
-                )
+            a, scale_a = get_scale(
+                a,
+                self.scaling_recipe_a_int,
+                custom_scale=self.extra_args.per_tensor_scale_a,
+            )
+            b, scale_b = get_scale(
+                b,
+                self.scaling_recipe_b_int,
+                custom_scale=self.extra_args.per_tensor_scale_b,
+            )

             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
@@ -192,7 +233,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
             scale_a,
             scale_b,
             self._get_dtype(),
-            self.extra_args.scaling_rowwise,
+            0 if self.scaling_recipe_a_int == self.scaling_recipe_b_int == 0 else 1,
         )

     @register_benchmark(enabled=True)
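
As a reference, the scale math applied by the new module-level get_scale helper can be reproduced standalone. The sketch below is illustrative only and not part of this diff; it assumes nothing beyond stock PyTorch (no TritonBench imports) and mirrors the TensorWise and RowWise branches: scale each operand into the float8_e4m3fn range and return a float32 scale whose shape depends on the recipe.

# Illustrative sketch (not part of this diff): standalone reproduction of the
# TensorWise and RowWise scale computation performed by get_scale().
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def tensorwise_scale(x: torch.Tensor):
    # One scalar scale for the whole tensor (mirrors _get_scale_per_tensor).
    scale = FP8_MAX / x.abs().max()
    return x * scale, scale.to(torch.float32)


def rowwise_scale(x: torch.Tensor, transpose: bool = False):
    # One scale per row, or per column when transpose=True
    # (mirrors _get_scale_per_row: scale_a is [M, 1], scale_b is [1, N]).
    dim = 0 if transpose else 1
    scale = FP8_MAX / x.abs().max(dim=dim, keepdim=True).values
    return x * scale, scale.to(torch.float32)


if __name__ == "__main__":
    m, k = 4, 8
    a = torch.randn(m, k)
    a_tw, scale_tw = tensorwise_scale(a.clone())
    a_rw, scale_rw = rowwise_scale(a.clone())
    assert scale_tw.ndim == 0        # TensorWise: a single scalar scale
    assert scale_rw.shape == (m, 1)  # RowWise on operand a: one scale per row
    # The scaled operands now fit the e4m3 range and can be cast to
    # torch.float8_e4m3fn, as get_input_iter does before calling the kernels.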
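Assuming the usual TritonBench entry point (an assumption, not shown in this diff), the new flag would be exercised along the lines of: python run.py --op fp8_gemm --scaling-pair RowWise,RowWise --m 4096 --k 4096 --n 4096. Only the combinations listed in torch/_inductor/kernel/mm.py::scaling_pairs are accepted, and the default of TensorWise,TensorWise keeps per-tensor scaling as the out-of-the-box behavior while replacing the old --scaling_rowwise switch.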