Skip to content

Commit facf35d

Browse files
authored
Qualcomm AI Engine Direct - Cat Fix (#14325)
### Summary Fix op cat to retrieve the right node. ### Test plan CI pass
1 parent bc18834 commit facf35d

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

backends/qualcomm/builders/op_cat.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,23 @@ def define_node(
2929
node: torch.fx.Node,
3030
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
3131
) -> PyQnnWrapper.PyQnnOpWrapper:
32-
list_of_tensors = cast(List[torch.fx.Node], node.args[0])
33-
list_of_tensor_wrappers = []
32+
input_nodes = cast(List[torch.fx.Node], node.args[0])
33+
input_tensor_wrappers = []
3434

35-
for tensor_input in list_of_tensors:
36-
input_tensor = self.get_tensor(self.get_node(tensor_input), node)
37-
list_of_tensor_wrappers.append(
35+
for input_node in input_nodes:
36+
source_input_node = self.get_node(input_node)
37+
input_tensor = self.get_tensor(source_input_node, node)
38+
input_tensor_wrappers.append(
3839
self.define_tensor(
39-
tensor_input,
40+
source_input_node,
4041
node,
4142
input_tensor,
4243
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
4344
nodes_to_wrappers,
4445
)
4546
)
4647

47-
if len(list_of_tensors) != len(list_of_tensor_wrappers):
48+
if len(input_nodes) != len(input_tensor_wrappers):
4849
warnings.warn(
4950
"[QNN Delegate Op Builder]: The number or input tensors is not equal to the number of input tensor wrappers.",
5051
stacklevel=1,
@@ -76,7 +77,7 @@ def define_node(
7677
QNN_OP_PACKAGE_NAME_QTI_AISW,
7778
OpConcat.op_name,
7879
)
79-
concat_op.AddInputTensors(list_of_tensor_wrappers)
80+
concat_op.AddInputTensors(input_tensor_wrappers)
8081
concat_op.AddOutputTensors([output_tensor_wrapper])
8182

8283
concat_op.AddScalarParam(

backends/qualcomm/tests/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,15 @@ def forward(self, x, y):
274274
return torch.cat((y, y, x, x), axis=2)
275275

276276

277+
class Cat5(torch.nn.Module):
278+
def __init__(self):
279+
super().__init__()
280+
self.const_tensor = torch.randn(1, 1, 2, 2)
281+
282+
def forward(self, x, y):
283+
return torch.cat((x, y, self.const_tensor), axis=2)
284+
285+
277286
class CausalMask(torch.nn.Module):
278287
def __init__(self):
279288
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def test_qnn_backend_cast(self):
232232
self.lower_module_and_test_output(module, sample_input)
233233

234234
def test_qnn_backend_cat(self):
235-
modules = [Cat2(), Cat3(), Cat4()] # noqa: F405
235+
modules = [Cat2(), Cat3(), Cat4(), Cat5()] # noqa: F405
236236
sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2))
237237
for i, module in enumerate(modules):
238238
with self.subTest(i=i):
@@ -1699,7 +1699,7 @@ def test_qnn_backend_cast(self):
16991699
self.lower_module_and_test_output(module, sample_input)
17001700

17011701
def test_qnn_backend_cat(self):
1702-
modules = [Cat2(), Cat3(), Cat4()] # noqa: F405
1702+
modules = [Cat2(), Cat3(), Cat4(), Cat5()] # noqa: F405
17031703
sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2))
17041704
for i, module in enumerate(modules):
17051705
with self.subTest(i=i):

0 commit comments

Comments
 (0)