Skip to content

Commit 1e3df5f

Browse files
committed
minor
Signed-off-by: realAsma <[email protected]>
1 parent 8415eb1 commit 1e3df5f

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

modelopt/torch/quantization/tensor_quant.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -379,8 +379,8 @@ def forward(
379379

380380
def legacy_quant_func():
381381
# The LegacyFakeTensorQuantFunction support cpu and amax with any shape that can be broadcasted to inputs.
382-
outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
383-
return outputs / scale.to(inputs.dtype)
382+
outputs = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
383+
return outputs
384384

385385
if not inputs.is_cuda:
386386
outputs = legacy_quant_func()
@@ -614,9 +614,10 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
614614
scale[zero_amax_mask] = (
615615
1.0 # Return 1 makes more sense for values quantized to 0 with amax=0
616616
)
617+
outputs = outputs / scale
617618

618619
outputs = outputs.to(input_dtype)
619-
return outputs, scale
620+
return outputs
620621

621622

622623
fake_tensor_quant = FakeTensorQuantFunction.apply

0 commit comments

Comments (0)