
Commit ae2cb71

jananisriram authored and facebook-github-bot committed
Add amax as default per-row scaling factor for fp8_gemm benchmark (#341)
Summary: Add `amax` (absolute maximum) as the default scaling factor for per-row scaling of fp8 GEMMs, as is used in practice.

Reviewed By: xuzhao9

Test Plan: Imported from GitHub, without a `Test Plan:` line.

Rollback Plan:

Differential Revision: D80590746

Pulled By: jananisriram
1 parent 49ce565 commit ae2cb71
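
For context, a minimal sketch (not part of this commit, assuming PyTorch 2.1+ for the fp8 dtypes) of the amax-based scaling the summary describes: the scale is the maximum representable fp8 e4m3 value divided by the absolute maximum of the data, computed either once per tensor or once per row.

import torch

# Illustrative only: mirrors the scaling scheme this commit makes the default.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

x = torch.randn(128, 64, dtype=torch.float16)

# Per-tensor: a single scalar scale for the whole matrix.
scale_tensor = (FP8_MAX / x.abs().max()).to(torch.float32)

# Per-row ("amax" row-wise): one scale per row, shaped [M, 1].
scale_row = (FP8_MAX / x.abs().max(dim=1, keepdim=True).values).to(torch.float32)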

File tree

1 file changed: +44 -13 lines


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 44 additions & 13 deletions
@@ -41,6 +41,8 @@ def parse_args(args):
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
+    parser.add_argument("--per-tensor-scale-a", type=float, default=None)
+    parser.add_argument("--per-tensor-scale-b", type=float, default=None)
     return parser.parse_args(args)


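The two new flags let a caller pin the per-tensor scales instead of deriving them from amax. A standalone sketch of how argparse handles them (hypothetical values, not the benchmark's full parser): the dashed flag names become underscored attributes.

import argparse

# Minimal reproduction of the parsing added above, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument("--per-tensor-scale-a", type=float, default=None)
parser.add_argument("--per-tensor-scale-b", type=float, default=None)

args = parser.parse_args(["--per-tensor-scale-a", "0.5"])
print(args.per_tensor_scale_a)  # 0.5
print(args.per_tensor_scale_b)  # None -> the benchmark falls back to amax-based scaling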
@@ -54,18 +56,53 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)

+    def _get_dtype(self):
+        if self.extra_args.scaling_rowwise:
+            return torch.bfloat16
+        else:
+            return torch.float16
+
     def get_input_iter(self):
+        def _get_scale_per_tensor(
+            x: torch.Tensor, custom_scale: float = None
+        ) -> torch.Tensor:
+            # For tensor-wise scaling, kernel requires a float32 scale tensor
+            if custom_scale:
+                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+            return scale.to(torch.float32)
+
+        def _get_scale_per_row(
+            x: torch.Tensor, transpose: bool = False
+        ) -> torch.Tensor:
+            if transpose:  # scale_b.shape should be [1, N]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=0, keepdim=True).values
+                )
+            else:  # scale_a.shape should be [M, 1]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=1, keepdim=True).values
+                )
+            return scale.to(
+                torch.float32
+            )  # For row-wise scaling, kernel requires a float32 scale tensor
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = torch.randn(k, n, device=self.device).to(torch.float16).T.contiguous().T

             if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+                scale_a = _get_scale_per_row(a)
+                scale_b = _get_scale_per_row(b, transpose=True)
             else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=a.device)
+                scale_a = _get_scale_per_tensor(
+                    a, custom_scale=self.extra_args.per_tensor_scale_a
+                )
+                scale_b = _get_scale_per_tensor(
+                    b, custom_scale=self.extra_args.per_tensor_scale_b
+                )

             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
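The helpers above encode a shape contract: scale_a is [M, 1] (one scale per row of a) and scale_b is [1, N] (one scale per column of b). A small check, with the helper rewritten as a free function purely for illustration (the commit keeps it nested inside get_input_iter):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def scale_per_row(x: torch.Tensor, transpose: bool = False) -> torch.Tensor:
    # transpose=False -> max over each row,    shape [M, 1] (scale for a)
    # transpose=True  -> max over each column, shape [1, N] (scale for b)
    dim = 0 if transpose else 1
    return (FP8_MAX / x.abs().max(dim=dim, keepdim=True).values).to(torch.float32)

m, k, n = 8, 16, 4
a = torch.randn(m, k, dtype=torch.float16)
b = torch.randn(k, n, dtype=torch.float16)

assert scale_per_row(a).shape == (m, 1)
assert scale_per_row(b, transpose=True).shape == (1, n)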
@@ -103,16 +140,10 @@ def get_x_val(self, example_inputs) -> float:
         _, n = b.size()
         return (m, n, k)

-    def _get_out_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
-            return torch.float16
-
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
         )

     @register_benchmark()
@@ -129,7 +160,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
             scale_a,
             scale_b,
             use_fast_accum=True,
-            out_dtype=self._get_out_dtype(),
+            out_dtype=self._get_dtype(),
         )
         compiled = torch.compile(f, dynamic=False)
         compiled(a, b)
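
To tie it together, a sketch of the call the benchmarks above time, using row-wise scales. This assumes an fp8-capable GPU and a recent PyTorch build; torch._scaled_mm's signature has shifted between releases (the benchmark passes the scales positionally), so treat this as illustrative rather than canonical.

import torch

device = "cuda"  # fp8 matmul needs a GPU with fp8 support (e.g. H100-class)
m, k, n = 128, 256, 64
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

a = torch.randn(m, k, device=device, dtype=torch.float16)
# _scaled_mm expects the second operand column-major, hence the .T.contiguous().T trick
b = torch.randn(k, n, device=device, dtype=torch.float16).T.contiguous().T

scale_a = (FP8_MAX / a.abs().max(dim=1, keepdim=True).values).to(torch.float32)  # [M, 1]
scale_b = (FP8_MAX / b.abs().max(dim=0, keepdim=True).values).to(torch.float32)  # [1, N]

out = torch._scaled_mm(
    a.to(torch.float8_e4m3fn),
    b.to(torch.float8_e4m3fn),
    scale_a=scale_a,
    scale_b=scale_b,
    out_dtype=torch.bfloat16,  # the row-wise path above uses bf16 output
    use_fast_accum=True,
)

The pt2_fp8_gemm variant wraps the same call in torch.compile(..., dynamic=False) before timing it.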
