
Commit b4cb2d5

fix fp8 gemm bug

Signed-off-by: Ye Yu <[email protected]>

1 parent 760c583

File tree: 1 file changed (+3, -1 lines)


modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py

Lines changed: 3 additions & 1 deletion
@@ -175,7 +175,9 @@ def backward(ctx, grad_outputs):
         weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes)
         grad_input = grad_outputs @ weight
         if input_tensor is not None:
-            grad_weight = grad_outputs.transpose(-2, 1) @ input_tensor
+            grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
+                -1, input_tensor.shape[-1]
+            )
         if compute_bias_grad is not None:
             # Sum all dimensions except the last one
             grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
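
Why the reshape fixes the gradient: below is a minimal sketch (the shapes B, S, N, K are hypothetical, not from the commit) showing that flattening every leading dimension accumulates the per-position outer products into a single (out_features, in_features) weight gradient, whereas the previous transpose(-2, 1) form only behaves as a matrix transpose for 2-D activations.

```python
import torch

# Hypothetical shapes: batch B, sequence S, out_features N, in_features K.
B, S, N, K = 2, 4, 8, 16
grad_outputs = torch.randn(B, S, N)
input_tensor = torch.randn(B, S, K)

# Fixed expression from the commit: collapse all leading dims so every
# batch/sequence position contributes to one (N, K) weight gradient.
grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
    -1, input_tensor.shape[-1]
)
print(grad_weight.shape)  # torch.Size([8, 16])

# Equivalent to summing the per-batch matmuls, which is what the gradient of a
# weight shared across the batch requires.
reference = (grad_outputs.transpose(-2, -1) @ input_tensor).sum(dim=0)
assert torch.allclose(grad_weight, reference, atol=1e-5)

# The old transpose(-2, 1) swaps dim 1 with itself on a 3-D tensor (a no-op),
# so the subsequent matmul fails its shape check for inputs like these; it only
# behaved as a transpose when grad_outputs and input_tensor were 2-D.
```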
