Skip to content

Commit aae0dba

Browse files
authored
Make the requant pass call the per_tensor overload

Differential Revision: D74216340
Pull Request resolved: #11789
1 parent 28905e7 commit aae0dba

File tree

4 files changed: 12 additions and 30 deletions.

backends/cadence/aot/fuse_ops.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -712,32 +712,14 @@ def _create_requantize_node(
712712
out_dtype: torch.dtype,
713713
graph: torch.fx.Graph,
714714
) -> torch.fx.Node:
715-
in_scale_tensor = graph.call_function(
716-
exir_ops.edge.aten.full.default, args=((1,), in_scale)
717-
)
718-
in_zero_point_tensor = graph.call_function(
719-
exir_ops.edge.aten.full.default,
720-
args=((1,), in_zero_point),
721-
kwargs={"dtype": torch.int32},
722-
)
723-
out_scale_tensor = graph.call_function(
724-
exir_ops.edge.aten.full.default, args=((1,), out_scale)
725-
)
726-
out_zero_point_tensor = graph.call_function(
727-
exir_ops.edge.aten.full.default,
728-
args=((1,), out_zero_point),
729-
kwargs={"dtype": torch.int32},
730-
)
731-
# cadence::requantize(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype) -> Tensor Y
732-
# TODO(hardiksharma): Add support for per-tensor requantize.
733715
return graph.call_function(
734-
exir_ops.edge.cadence.requantize.default,
716+
exir_ops.edge.cadence.requantize.per_tensor,
735717
args=(
736718
in_tensor,
737-
in_scale_tensor,
738-
in_zero_point_tensor,
739-
out_scale_tensor,
740-
out_zero_point_tensor,
719+
in_scale,
720+
in_zero_point,
721+
out_scale,
722+
out_zero_point,
741723
out_dtype,
742724
),
743725
)

backends/cadence/aot/remove_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def call_operator(
447447
kwargs: dict[str, Argument],
448448
meta: NodeMetadata,
449449
) -> ProxyValue:
450-
if op != exir_ops.edge.cadence.requantize.default:
450+
if op != exir_ops.edge.cadence.requantize.per_tensor:
451451
return super().call_operator(op, args, kwargs, meta)
452452

453453
# Parse the args

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def test_force_quant_dequant_fusion(self) -> None:
306306
# Verify that dequant/quant pair was replaced with requantize.
307307
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
308308
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
309-
exir_ops.edge.cadence.requantize.default: 1,
309+
exir_ops.edge.cadence.requantize.per_tensor: 1,
310310
},
311311
)
312312

@@ -336,7 +336,7 @@ def test_no_replace_quant_permute_dequant_with_requantize(self) -> None:
336336
# quantize -> permute -> dequantize should not be replaced with requantize.
337337
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
338338
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
339-
exir_ops.edge.cadence.requantize.default: 0,
339+
exir_ops.edge.cadence.requantize.per_tensor: 0,
340340
},
341341
)
342342

@@ -364,7 +364,7 @@ def test_replace_quant_view_dequant_with_requantize(self) -> None:
364364
# Verify that dequant/quant pair was replaced with requantize.
365365
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
366366
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
367-
exir_ops.edge.cadence.requantize.default: 1,
367+
exir_ops.edge.cadence.requantize.per_tensor: 1,
368368
},
369369
)
370370

@@ -390,7 +390,7 @@ def test_replace_dequant_quant_with_requantize(self) -> None:
390390
# Verify that dequant -> quant was replaced with requantize.
391391
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
392392
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
393-
exir_ops.edge.cadence.requantize.default: 1,
393+
exir_ops.edge.cadence.requantize.per_tensor: 1,
394394
},
395395
)
396396

@@ -420,7 +420,7 @@ def test_replace_dequant_permute_quant_with_requantize(self) -> None:
420420
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
421421
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
422422
exir_ops.edge.aten.permute_copy.default: 1,
423-
exir_ops.edge.cadence.requantize.default: 1,
423+
exir_ops.edge.cadence.requantize.per_tensor: 1,
424424
},
425425
)
426426

backends/cadence/aot/tests/test_reorder_ops_passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def test_advance_branched_quantize(self) -> None:
217217
self.assertEqual(
218218
count_node(
219219
graph_module,
220-
exir_ops.edge.cadence.requantize.default,
220+
exir_ops.edge.cadence.requantize.per_tensor,
221221
),
222222
1,
223223
)

0 commit comments

Comments
 (0)