
Commit 0e0d6f6

jananisriram authored and facebook-github-bot committed
Add amax as default per-tensor scaling factor for fp8_gemm benchmark (#339)
Summary:
Pull Request resolved: #339

Add `amax` (absolute maximum) as the default scaling factor for per-tensor scaling in fp8 workloads, as is used in practice. Also add command-line arguments that let the user supply a custom scaling factor for per-tensor scaling.

Reviewed By: NikhilAPatel

Differential Revision: D80577628
1 parent 3c10556 commit 0e0d6f6
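
As context for the change, here is a minimal, illustrative sketch (not taken from this commit) of amax-based per-tensor scaling feeding `torch._scaled_mm`. The scale is computed the same way as the new default in this diff: the `float8_e4m3fn` maximum divided by the tensor's absolute maximum. The sizes, device, and `float16` output dtype below are assumptions for the example, and it requires a GPU and PyTorch build with fp8 support.

```python
import torch

# Illustrative sketch only (not from this commit): amax-based per-tensor scales
# for an fp8 GEMM via torch._scaled_mm. Assumes a CUDA device with fp8 support;
# sizes and dtypes are arbitrary.
device = "cuda"
m, k, n = 1024, 512, 2048

a = torch.randn(m, k, device=device, dtype=torch.float16)
# _scaled_mm expects the second operand in column-major layout.
b = torch.randn(k, n, device=device, dtype=torch.float16).T.contiguous().T

def amax_scale(x: torch.Tensor) -> torch.Tensor:
    # Per-tensor scale from the absolute maximum, mapped onto the fp8 e4m3
    # representable range, mirroring the default added in this commit.
    return (torch.finfo(torch.float8_e4m3fn).max / x.abs().max()).to(torch.float32)

scale_a = amax_scale(a)
scale_b = amax_scale(b)

out = torch._scaled_mm(
    a.to(torch.float8_e4m3fn),
    b.to(torch.float8_e4m3fn),
    scale_a,
    scale_b,
    use_fast_accum=True,
    out_dtype=torch.float16,
)
```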

File tree

1 file changed: +21 -12 lines changed


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 21 additions & 12 deletions
@@ -41,6 +41,8 @@ def parse_args(args):
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
+    parser.add_argument("--per-tensor-scale-a", type=float, default=None)
+    parser.add_argument("--per-tensor-scale-b", type=float, default=None)
     return parser.parse_args(args)


@@ -54,12 +56,25 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)

+    def _get_dtype(self):
+        if self.extra_args.scaling_rowwise:
+            return torch.bfloat16
+        else:
+            return torch.float16
+
     def get_input_iter(self):
+        def _get_scale_per_tensor(x: torch.Tensor, custom_scale: float = None) -> torch.Tensor:
+            # For tensor-wise scaling, kernel requires a float32 scale tensor
+            if custom_scale:
+                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+            return scale.to(torch.float32)
+
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float16)
+            a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float16)
+                .to(self._get_dtype())
                 .T.contiguous()
                 .T
             )
@@ -69,8 +84,8 @@ def args(m, n, k):
                 scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
                 scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
             else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=a.device)
+                scale_a = _get_scale_per_tensor(a, custom_scale=self.extra_args.per_tensor_scale_a)
+                scale_b = _get_scale_per_tensor(b, custom_scale=self.extra_args.per_tensor_scale_b)

             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
@@ -108,16 +123,10 @@ def get_x_val(self, example_inputs) -> float:
         _, n = b.size()
         return (m, n, k)

-    def _get_out_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
-            return torch.float16
-
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
         )

     @register_benchmark()
@@ -129,7 +138,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
             autotune_fallback_to_aten=False,
         ):
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
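
With these flags in place, an explicit per-tensor scale can be supplied on the command line instead of the amax default, e.g. by appending `--per-tensor-scale-a 1.0 --per-tensor-scale-b 1.0` to the fp8_gemm operator's extra arguments (for instance via something like `python run.py --op fp8_gemm ...`, depending on how TritonBench is invoked in your setup); leaving both unset falls back to the amax-based scale.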
