Skip to content

Commit 644393a

Browse files
committed
fix ci error
1 parent 6e2b755 commit 644393a

File tree

1 file changed: 11 additions, 2 deletions

backends/qualcomm/_passes/recompose_rms_norm.py

Lines changed: 11 additions & 2 deletions
(Columns: original file line number | new file line number | change)
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 from executorch.backends.qualcomm._passes.utils import find_patterns
-
+from executorch.backends.qualcomm.builders.node_visitor import dq_ops
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

@@ -29,10 +29,16 @@ def _is_get_attr(node):
 def _is_add(node):
     return _is_call(node) and node.target in [
         exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.add.Scalar,
         torch.ops.aten.add.Tensor,
+        torch.ops.aten.add.Scalar,
     ]
 
 
+def _is_dq(node):
+    return _is_call(node) and node.target in dq_ops
+
+
 def _is_mean(node):
     return _is_call(node) and node.target in [
         exir_ops.edge.aten.mean.dim,
@@ -50,6 +56,7 @@ def _is_mul(node):
 def _is_pow(node):
     return _is_call(node) and node.target in [
         exir_ops.edge.aten.pow.Tensor_Tensor,
+        exir_ops.edge.aten.pow.Tensor_Scalar,
         torch.ops.aten.pow.Tensor_Scalar,
     ]

@@ -72,6 +79,7 @@ def __init__(self, quantization_capture=False):
         self.skip_targets = [
             exir_ops.edge.aten.to.dtype,
         ]
+        self.quantization_capture = quantization_capture
         if quantization_capture:
             self.rms_norm_target = torch.ops.aten.rms_norm.default
             self.skip_targets = [
@@ -112,8 +120,9 @@ def call(self, graph_module: torch.fx.GraphModule):
             gamma_node = None
             # weight should be a constant
             for arg in last_mul_node.args:
-                if _is_get_attr(arg) or _is_placeholder(arg):
+                if _is_get_attr(arg) or _is_placeholder(arg) or _is_dq(arg):
                     gamma_node = arg
+
             if gamma_node is None:
                 continue
 
(NOTE: indentation within the `call` method body is inferred from context and may not match the original file exactly.)
0 commit comments

Comments (0)