diff --git a/backends/qualcomm/_passes/convert_bmm_to_matmul.py b/backends/qualcomm/_passes/convert_bmm_to_matmul.py index 480a985a90a..84e1ff26aa1 100644 --- a/backends/qualcomm/_passes/convert_bmm_to_matmul.py +++ b/backends/qualcomm/_passes/convert_bmm_to_matmul.py @@ -46,7 +46,8 @@ def _get_ordered_inputs( def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph partitions = get_source_partitions( - graph, [operator.matmul, torch.matmul, torch.bmm] + graph, + [operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default], ) for _, src_partitions in partitions.items(): for src_partition in src_partitions: diff --git a/backends/qualcomm/_passes/convert_to_linear.py b/backends/qualcomm/_passes/convert_to_linear.py index 87b9f8a74b8..484b1399d38 100644 --- a/backends/qualcomm/_passes/convert_to_linear.py +++ b/backends/qualcomm/_passes/convert_to_linear.py @@ -190,7 +190,9 @@ def _extract_bmm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.No return ret def _convert(self, graph_module: torch.fx.GraphModule): - partitions = get_source_partitions(graph_module.graph, [torch.nn.Linear]) + partitions = get_source_partitions( + graph_module.graph, [torch.nn.Linear, torch.ops.aten.linear.default] + ) for _, src_partitions in partitions.items(): for src_partition in src_partitions: op_cnt = Counter(