Commit d2c6e0f

Fixed test cases fail

Signed-off-by: Jingyu Xin <[email protected]>
1 parent: 831c32d

2 files changed (+3 −3 lines)

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -707,12 +707,12 @@ def _fake_quantize(self, inputs):
             outputs = scaled_e4m3(
                 inputs,
                 amax,
-                block_sizes_list,
                 self._get_bias(inputs),
                 E,
                 M,
                 self._trt_high_precision_dtype,
                 self._pass_through_bwd,
+                block_sizes_list,
             )

         else:
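The hunk above moves block_sizes_list from right after amax to the end of the scaled_e4m3 call, mirroring the reordered signature in tensor_quant.py below. A minimal sketch of why this matters (hypothetical stand-in names, not the real scaled_e4m3 signature): once a function's parameters are reordered, every positional call site has to pass its arguments in the new order, otherwise values bind to the wrong parameters and the tests exercising this path fail.

def scaled_e4m3_sketch(inputs, amax, bias, E, M,
                       trt_high_precision_dtype=None,
                       pass_through_bwd=False,
                       block_sizes=None):
    # Toy stand-in that just reports which parameter received which value.
    return {"amax": amax, "bias": bias, "block_sizes": block_sizes}

# With the old call order (block_sizes right after amax), [1, 16] would have
# been bound to the bias slot; the reordered call matches the new signature.
out = scaled_e4m3_sketch("x", 448.0, None, 4, 3, None, False, [1, 16])
assert out["block_sizes"] == [1, 16]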

modelopt/torch/quantization/tensor_quant.py

Lines changed: 2 additions & 2 deletions
@@ -417,12 +417,12 @@ def symbolic(
         g,
         inputs,
         amax=None,
-        block_sizes=None,
         bias=None,
         E=4,  # noqa: N803
         M=3,  # noqa: N803
         trt_high_precision_dtype=None,
         pass_through_bwd=False,
+        block_sizes=None,
     ):
         """ONNX symbolic function."""
         from .export_onnx import export_fp8
@@ -435,12 +435,12 @@ def forward(
         ctx,
         inputs,
         amax,
-        block_sizes,
         bias,
         E,  # noqa: N803
         M,  # noqa: N803
         trt_high_precision_dtype=None,
         pass_through_bwd=False,
+        block_sizes=None,
     ):
         """Forward method."""
         if E != 4 or M != 3:
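A likely reason block_sizes also picks up a default here (a minimal illustration with a hypothetical helper, not the project's API): in Python, a parameter without a default cannot follow parameters that have defaults, so moving block_sizes behind trt_high_precision_dtype=None and pass_through_bwd=False requires writing it as block_sizes=None. The default also keeps call sites that never passed block_sizes working unchanged.

def forward_sketch(ctx, inputs, amax, bias, E, M,
                   trt_high_precision_dtype=None,
                   pass_through_bwd=False,
                   block_sizes=None):
    # Existing callers that omit block_sizes keep working; new callers pass it
    # explicitly, preferably by keyword to stay robust to further reordering.
    return block_sizes

assert forward_sketch(None, "x", 448.0, None, 4, 3) is None
assert forward_sketch(None, "x", 448.0, None, 4, 3, block_sizes=[1, 16]) == [1, 16]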
