
Commit 249902a

Fix fp8_gemm on H100
1 parent: d69a57f

File tree

1 file changed: +10 −1 lines changed


auto_fp8/quantize.py

Lines changed: 10 additions & 1 deletion
@@ -65,14 +65,23 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
         torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
     )
     if native_fp8_support:
+        need_reshape = A.dim() == 3
+        if need_reshape:
+            batch_size = A.shape[0]
+            A_input = A.reshape(-1, A.shape[-1])
+        else:
+            batch_size = None
+            A_input = A
         output, _ = torch._scaled_mm(
-            A,
+            A_input,
             B.t(),
             out_dtype=out_dtype,
             scale_a=A_scale,
             scale_b=B_scale,
             bias=bias,
         )
+        if need_reshape:
+            output = output.reshape((batch_size, *output.shape))
     else:
         output = torch.nn.functional.linear(
             A.to(out_dtype) * A_scale,
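
For context: torch._scaled_mm, the native FP8 GEMM path used on Hopper GPUs (compute capability 9.0, e.g. H100), only accepts 2-D operands, while transformer activations commonly arrive as 3-D (batch, seq_len, hidden) tensors. The patch therefore flattens the leading dimensions before the matmul and restores them afterwards. The sketch below illustrates that pattern; the parameter names and the torch._scaled_mm keyword arguments are taken from the diff above, the helper name fp8_gemm_native is hypothetical, the dequantize fallback branch is omitted, and the final reshape is written in a simplified form rather than copied verbatim from the commit.

import torch

def fp8_gemm_native(A, A_scale, B, B_scale, bias, out_dtype):
    # Sketch of the patched native-FP8 path (hypothetical helper name).
    # torch._scaled_mm only works on 2-D matrices, so a 3-D activation of
    # shape (batch, seq, hidden) is flattened to (batch * seq, hidden).
    need_reshape = A.dim() == 3
    if need_reshape:
        batch_size = A.shape[0]
        A_input = A.reshape(-1, A.shape[-1])
    else:
        batch_size = None
        A_input = A
    # The commit unpacks a tuple here; on the PyTorch release it targets,
    # torch._scaled_mm returns (output, amax). Newer releases return only
    # the output tensor.
    output, _ = torch._scaled_mm(
        A_input,
        B.t(),
        out_dtype=out_dtype,
        scale_a=A_scale,
        scale_b=B_scale,
        bias=bias,
    )
    if need_reshape:
        # Restore the leading batch dimension (illustrative form; the
        # commit spells this reshape differently).
        output = output.reshape(batch_size, -1, output.shape[-1])
    return output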
