Skip to content

Commit 6abeb94

Browse files
chunnienccopybara-github
authored andcommitted
fix pt2e quant ops lowering
PiperOrigin-RevId: 706945163
1 parent fc9c986 commit 6abeb94

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

ai_edge_torch/odml_torch/export.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,13 @@ def exported_program_to_mlir(
304304
)
305305

306306
_convert_i64_to_i32(exported_program)
307+
307308
exported_program = _torch_future.safe_run_decompositions(
308309
exported_program, lowerings.decompositions()
309310
)
311+
312+
# Passes below mutate the exported program to a state not executable by torch.
313+
# Do not call run_decompositions after applying the passes.
310314
_convert_q_dq_per_channel_args_to_list(exported_program)
311315

312316
with export_utils.create_ir_context() as context, ir.Location.unknown():

ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,13 @@ def _uniform_quantized_type(
5252
assert isinstance(scale, (list, tuple))
5353
assert isinstance(zero_point, (list, tuple))
5454

55+
scale = list(scale)
56+
zero_point = list(zero_point)
57+
5558
if len(scale) == 1:
56-
scale *= channel_axis_size
59+
scale = scale * channel_axis_size
5760
if len(zero_point) == 1:
58-
zero_point *= channel_axis_size
61+
zero_point = zero_point * channel_axis_size
5962

6063
assert len(scale) == len(zero_point) == channel_axis_size
6164
scale_zp_strs = []

0 commit comments

Comments
 (0)