@@ -41,6 +41,8 @@ def parse_args(args):
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
+    parser.add_argument("--per-tensor-scale-a", type=float, default=None)
+    parser.add_argument("--per-tensor-scale-b", type=float, default=None)
     return parser.parse_args(args)


@@ -54,18 +56,58 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)

+    def _get_dtype(self):
+        if self.extra_args.scaling_rowwise:
+            return torch.bfloat16
+        else:
+            return torch.float16
+
     def get_input_iter(self):
+        def _get_scale_per_tensor(
+            x: torch.Tensor, custom_scale: float = None
+        ) -> torch.Tensor:
+            # For tensor-wise scaling, the kernel requires a float32 scale tensor
+            if custom_scale:
+                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+            return scale.to(torch.float32)
+
+        def _get_scale_per_row(
+            x: torch.Tensor, transpose: bool = False
+        ) -> torch.Tensor:
+            if transpose:  # scale_b.shape should be [1, N]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=0, keepdim=True).values
+                )
+            else:  # scale_a.shape should be [M, 1]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=1, keepdim=True).values
+                )
+            return scale.to(
+                torch.float32
+            )  # For row-wise scaling, the kernel requires a float32 scale tensor
+
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float16)
-            b = torch.randn(k, n, device=self.device).to(torch.float16).T.contiguous().T
+            a = torch.randn(m, k, device=self.device).to(self._get_dtype())
+            b = (
+                torch.randn(k, n, device=self.device)
+                .to(self._get_dtype())
+                .T.contiguous()
+                .T
+            )

             if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+                scale_a = _get_scale_per_row(a)
+                scale_b = _get_scale_per_row(b, transpose=True)
             else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=a.device)
+                scale_a = _get_scale_per_tensor(
+                    a, custom_scale=self.extra_args.per_tensor_scale_a
+                )
+                scale_b = _get_scale_per_tensor(
+                    b, custom_scale=self.extra_args.per_tensor_scale_b
+                )

             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
@@ -103,16 +145,10 @@ def get_x_val(self, example_inputs) -> float:
         _, n = b.size()
         return (m, n, k)

-    def _get_out_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
-            return torch.float16
-
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
         )

     @register_benchmark()
@@ -124,12 +160,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
             autotune_fallback_to_aten=False,
         ):
             f = lambda a, b: torch._scaled_mm(
-                a,
-                b,
-                scale_a,
-                scale_b,
-                use_fast_accum=True,
-                out_dtype=self._get_out_dtype(),
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
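
For reference, a minimal standalone sketch (not part of this commit) of the per-tensor scaled fp8 GEMM path that the benchmark exercises, mirroring the scale computation introduced above. It assumes a CUDA device with fp8 support and a PyTorch build that exposes the private torch._scaled_mm API; shapes and dtypes are illustrative only.

import torch

m, k, n = 16, 32, 64
a = torch.randn(m, k, device="cuda", dtype=torch.float16)
# _scaled_mm expects the second operand in column-major layout, hence .T.contiguous().T
b = torch.randn(k, n, device="cuda", dtype=torch.float16).T.contiguous().T

# Per-tensor float32 scales, computed the same way as _get_scale_per_tensor above
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = (fp8_max / a.abs().max()).to(torch.float32)
scale_b = (fp8_max / b.abs().max()).to(torch.float32)

out = torch._scaled_mm(
    a.to(torch.float8_e4m3fn),
    b.to(torch.float8_e4m3fn),
    scale_a,
    scale_b,
    use_fast_accum=True,
    out_dtype=torch.float16,
)
print(out.shape)  # torch.Size([16, 64])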