1 file changed: +3 −5 lines in modelopt/torch/quantization

@@ -80,11 +80,11 @@ def scaled_e4m3_impl(
     Returns:
         Input tensors faked quantized to FP8.
     """
-    if inputs.is_cpu:
+    if inputs.is_cpu or amax is None or amax.squeeze().ndim > 1:
         return fp8_eager(inputs, amax)
 
     cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=False)
-    if cuda_ext_fp8 is None or amax is None:
+    if cuda_ext_fp8 is None:
         return fp8_eager(inputs, amax)
 
     with torch.cuda.device(
@@ -95,9 +95,7 @@ def scaled_e4m3_impl(
         elif amax.squeeze().ndim == 1:
             axis = amax.shape.index(amax.numel())
             outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
-        else:
-            outputs = fp8_eager(inputs, amax)
-    return outputs
+    return outputs
 
 
 def fake_quant_impl(
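The change moves three cases to the eager fallback up front: CPU tensors, a missing amax, and a multi-dimensional (block-wise) amax, presumably because the CUDA extension only handles scalar or single-axis amax. The now-unreachable else branch in the kernel dispatch is then removed. For context, that eager fallback can be approximated in pure PyTorch. The sketch below is an assumption about how fp8_eager behaves, not the library's implementation: the E4M3 max value, the clamp, and the round-trip through torch.float8_e4m3fn are illustrative, and fp8_eager_sketch is a hypothetical name.

import torch

E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3

def fp8_eager_sketch(inputs: torch.Tensor, amax: torch.Tensor | None) -> torch.Tensor:
    """Hypothetical stand-in for fp8_eager: fake-quantize to E4M3 in pure PyTorch."""
    scale = None
    if amax is not None:
        # amax broadcasts against inputs, so scalar (per-tensor), 1-D
        # (per-channel), and multi-dimensional (block-wise) shapes all work,
        # which is why the eager path can absorb the new ndim > 1 case.
        scale = E4M3_MAX / amax.to(inputs.dtype)
        inputs = inputs * scale
    # Clamp to the representable range, then round-trip through the FP8
    # dtype to emulate E4M3 precision loss.
    outputs = inputs.clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn).to(inputs.dtype)
    return outputs if scale is None else outputs / scale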
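Putting the two hunks together, the post-change routing looks roughly like the sketch below. It reuses fp8_eager_sketch from above and stubs the compiled-extension lookup so it runs standalone; the per-tensor fake_e4m3fy branch is not visible in the hunk and is assumed here.

def get_cuda_ext_fp8_stub(raise_if_failed: bool = False):
    """Stub for the extension lookup; returns None as if the build is unavailable."""
    return None

def scaled_e4m3_sketch(inputs: torch.Tensor, amax: torch.Tensor | None) -> torch.Tensor:
    """Control flow after this change (illustrative, not the library's code)."""
    # New first gate: CPU tensors, missing amax, and block-wise (ndim > 1)
    # amax go straight to the eager path, before the extension is even loaded.
    if inputs.is_cpu or amax is None or amax.squeeze().ndim > 1:
        return fp8_eager_sketch(inputs, amax)

    cuda_ext_fp8 = get_cuda_ext_fp8_stub(raise_if_failed=False)
    if cuda_ext_fp8 is None:  # amax is known to be non-None by this point
        return fp8_eager_sketch(inputs, amax)

    if amax.numel() == 1:  # per-tensor kernel (branch assumed, not shown in the hunk)
        return cuda_ext_fp8.fake_e4m3fy(inputs, amax)
    # Per-channel kernel: locate the axis that amax spans.
    axis = amax.shape.index(amax.numel())
    return cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)

For example, a per-row amax such as x.abs().amax(dim=-1, keepdim=True) on a 2-D x squeezes to 1-D and takes the axis kernel, while a block-wise amax of shape (rows, blocks) stays multi-dimensional after squeeze and now takes the eager path immediately.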