Skip to content

Commit c4c32ba

Browse files
committed
minor
Signed-off-by: realAsma <[email protected]>
1 parent 97dc2ef commit c4c32ba

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

modelopt/torch/quantization/tensor_quant.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,11 @@ def scaled_e4m3_impl(
     Returns:
         Input tensors faked quantized to FP8.
     """
-    if inputs.is_cpu:
+    if inputs.is_cpu or amax is None or amax.squeeze().ndim > 1:
         return fp8_eager(inputs, amax)

     cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=False)
-    if cuda_ext_fp8 is None or amax is None:
+    if cuda_ext_fp8 is None:
         return fp8_eager(inputs, amax)

     with torch.cuda.device(
@@ -95,9 +95,7 @@ def scaled_e4m3_impl(
     elif amax.squeeze().ndim == 1:
         axis = amax.shape.index(amax.numel())
         outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
-    else:
-        outputs = fp8_eager(inputs, amax)
-    return outputs
+    return outputs


 def fake_quant_impl(

0 commit comments

Comments
 (0)