Skip to content

Commit dca006f

Browse files
committed
add test coverage of cast op
1 parent 02001bf commit dca006f

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

backends/qualcomm/_passes/fuse_consecutive_cast.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ def _canonicalize_cast(
7474
clone_cast_node = graph.create_node(
7575
"call_function",
7676
exir_ops.edge.aten._to_copy.default,
77-
(n.args[0]),
77+
n.args,
78+
kwargs=n.kwargs,
7879
)
7980
clone_cast_node.meta = n.meta
8081
users[i].replace_input_with(n, clone_cast_node)
@@ -98,6 +99,8 @@ def _traverse(self, node):
9899
def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
99100
for n in graph_module.graph.nodes:
100101
self._traverse(n)
102+
# TODO: how to handle the following scenario (won't happen for a quantized graph)
103+
# fp -> to(i32) -> to(fp)
101104
if len(self.nodes) > 1:
102105
input_node, output_node = self.nodes[0], self.nodes[-1]
103106
output_node.replace_input_with(output_node.args[0], input_node.args[0])

backends/qualcomm/tests/models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,16 @@ def forward(self, x):
166166
return x.type(torch.IntTensor)
167167

168168

169+
class CastMultiUsers(torch.nn.Module):
170+
def __init__(self):
171+
super().__init__()
172+
173+
def forward(self, x, y):
174+
index = x.to(torch.long)
175+
res = torch.gather(y, dim=1, index=index)
176+
return res + index.to(torch.int32)
177+
178+
169179
class Cat2(torch.nn.Module):
170180
def __init__(self):
171181
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,14 @@ def test_qnn_backend_bmm(self):
165165
self.lower_module_and_test_output(module, sample_input)
166166

167167
def test_qnn_backend_cast(self):
168-
module = Cast() # noqa: F405
169-
sample_input = (10 * torch.rand((9, 4, 5, 3)),)
170-
self.lower_module_and_test_output(module, sample_input)
168+
modules = [Cast(), CastMultiUsers()] # noqa: F405
169+
sample_inputs = [
170+
(10 * torch.rand((9, 4, 5, 3)),),
171+
(torch.randint(0, 3, size=(3, 3)), torch.randn(3, 3)),
172+
]
173+
for i, (module, sample_input) in enumerate(zip(modules, sample_inputs)):
174+
with self.subTest(i=i):
175+
self.lower_module_and_test_output(module, sample_input)
171176

172177
def test_qnn_backend_cat(self):
173178
modules = [Cat2(), Cat3(), Cat4()] # noqa: F405
@@ -1234,6 +1239,12 @@ def test_qnn_backend_bmm(self):
12341239
module = self.get_qdq_module(module, sample_input)
12351240
self.lower_module_and_test_output(module, sample_input)
12361241

1242+
def test_qnn_backend_cast(self):
1243+
module = CastMultiUsers() # noqa: F405
1244+
sample_input = (torch.randint(0, 3, size=(3, 3)), torch.randn(3, 3))
1245+
module = self.get_qdq_module(module, sample_input)
1246+
self.lower_module_and_test_output(module, sample_input)
1247+
12371248
def test_qnn_backend_cat(self):
12381249
modules = [Cat2(), Cat3(), Cat4()] # noqa: F405
12391250
sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2))

0 commit comments

Comments
 (0)