Skip to content

Commit c6e540c

Browse files
committed
[quantization] Propagate qparam for expand
This PR propagates qparam forward for `torch.ops.aten.expand.default` . TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent 63d18ee commit c6e540c

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

test/quantization/pass/test_propagate_quant_param.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,21 @@ def test_s16_different_scale(self):
260260

261261
# The test will check cat's scale is 1.0, the larger one
262262
self.run_test()
263+
264+
265+
class ExpandModule(torch.nn.Module):
266+
def __init__(self):
267+
super().__init__()
268+
269+
def forward(self, x):
270+
return x.expand(5, 3)
271+
272+
def get_example_inputs(self):
273+
return (torch.randn(1, 3),), {}
274+
275+
276+
class ExpandTest(SingleOpPropagateQParamForwardTest):
277+
# TODO Support u8
278+
def test_s16(self):
279+
self.setup(ExpandModule(), torch.ops.aten.expand.default, dtype="int16")
280+
self.run_test()

tico/quantization/passes/propagate_qparam_forward.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from tico.utils.trace_decorators import trace_graph_diff_on_pass
2828
from tico.utils.validate_args_kwargs import (
2929
CatArgs,
30+
ExpandArgs,
3031
NegArgs,
3132
PermuteArgs,
3233
ReshapeArgs,
@@ -130,7 +131,9 @@ def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node):
130131

131132
assert max_scale_node is not None
132133
_propagate_qparam_if_possible(max_scale_node, node)
133-
134+
elif node.target == torch.ops.aten.expand.default:
135+
expand_args = ExpandArgs(*node.args, **node.kwargs)
136+
_propagate_qparam_if_possible(expand_args.input, node)
134137
# TODO Support more ops.
135138

136139
graph.eliminate_dead_code()

0 commit comments

Comments
 (0)