
Commit 2c70d7a

Update auto_fp8/quantize.py

1 parent 249902a

1 file changed (+1, -1)

auto_fp8/quantize.py (1 addition, 1 deletion)
```diff
@@ -81,7 +81,7 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
             bias=bias,
         )
         if need_reshape:
-            output = output.reshape((batch_size, *output.shape))
+            output = output.reshape((batch_size, output.shape[0] // batch_size, output.shape[1]))
     else:
         output = torch.nn.functional.linear(
             A.to(out_dtype) * A_scale,
```
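Why this one-line change matters: before the scaled GEMM, a 3-D activation is typically flattened to 2-D, so afterwards the flat output must be split back into its batch and sequence axes. The old expression prepended a new batch axis in front of the existing flat shape instead of splitting it, requesting a tensor with `batch_size` times as many elements. Below is a minimal sketch of the shape arithmetic using NumPy (whose `reshape` semantics match `torch.Tensor.reshape` here); the shapes and variable names are illustrative assumptions, not the actual AutoFP8 call site.

```python
import numpy as np

# Assumed setup: a (batch, seq, hidden) input was flattened to 2-D before the
# GEMM, so the GEMM output is (batch_size * seq_len, out_features).
batch_size, seq_len, out_features = 2, 3, 4
output = np.zeros((batch_size * seq_len, out_features))  # flat GEMM result

# Old code: output.reshape((batch_size, *output.shape)) asks for shape
# (2, 6, 4), i.e. 48 elements from a 24-element tensor -- reshape fails.
try:
    output.reshape((batch_size, *output.shape))
except ValueError:
    pass  # cannot reshape: element count does not match

# Fixed code: split the flattened rows back into (batch, seq, features).
restored = output.reshape((batch_size, output.shape[0] // batch_size, output.shape[1]))
assert restored.shape == (batch_size, seq_len, out_features)
```

The fix recovers the sequence dimension as `output.shape[0] // batch_size`, so the reshape is element-count preserving and the caller gets back a 3-D tensor matching the original input layout.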
