
Commit 5495e4c

Merge pull request #9 from comaniac/patch-1

Fix numel()=0

2 parents: c4a9594 + c762239

File tree

1 file changed: +8 -2 lines changed
auto_fp8/quantize.py

Lines changed: 8 additions & 2 deletions
@@ -61,8 +61,12 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
 
 
 def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
+    if A.numel() == 0:
+        # Deal with empty tensors (triggered by empty MoE experts)
+        return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
+
     native_fp8_support = (
-        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
     )
     if native_fp8_support:
         need_reshape = A.dim() == 3
@@ -81,7 +85,9 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
             bias=bias,
         )
         if need_reshape:
-            output = output.reshape((batch_size, output.shape[0] // batch_size, output.shape[1]))
+            output = output.reshape(
+                batch_size, output.shape[0] // batch_size, output.shape[1]
+            )
     else:
         output = torch.nn.functional.linear(
             A.to(out_dtype) * A_scale,
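
Two notes on the change. The capability threshold drops from (9, 0) to (8, 9) so that native FP8 support is detected on Ada Lovelace GPUs (compute capability 8.9), not just Hopper (9.0). And the new guard matters for MoE models: an expert that is routed zero tokens produces an input A with zero elements, so the early return short-circuits the quantized GEMM path and yields a correctly shaped empty output. Below is a minimal sketch of that behavior, assuming the fp8_gemm signature from the diff and torch.nn.functional.linear weight conventions (B has shape (out_features, in_features)); fp8_gemm_empty_guard is a hypothetical name for illustration only.

import torch

def fp8_gemm_empty_guard(A, A_scale, B, B_scale, bias, out_dtype):
    # Sketch of only the guard added in this commit; the real fp8_gemm
    # continues into the quantized GEMM path when A is non-empty.
    if A.numel() == 0:
        # An expert routed zero tokens gives A zero rows; return an empty
        # output of the right width instead of invoking the GEMM kernel.
        return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
    # Illustrative fallback only (stands in for the quantized matmul):
    return torch.nn.functional.linear(
        A.to(out_dtype) * A_scale, B.to(out_dtype) * B_scale, bias
    )

# Usage: an MoE expert that received no tokens this step.
A = torch.empty(0, 16)             # zero tokens, hidden size 16
B = torch.randn(32, 16)            # weight: (out_features, in_features)
out = fp8_gemm_empty_guard(A, 1.0, B, 1.0, None, torch.float16)
print(out.shape)                   # torch.Size([0, 32])

The second hunk only reflows the reshape call across multiple lines; its behavior is unchanged.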
