55# LICENSE file in the root directory of this source tree.
66import torch
77from executorch .backends .qualcomm ._passes .utils import find_patterns
8-
8+ from executorch . backends . qualcomm . builders . node_visitor import dq_ops
99from executorch .exir .dialects ._ops import ops as exir_ops
1010from executorch .exir .pass_base import ExportPass , PassResult
1111
@@ -29,10 +29,16 @@ def _is_get_attr(node):
2929def _is_add (node ):
3030 return _is_call (node ) and node .target in [
3131 exir_ops .edge .aten .add .Tensor ,
32+ exir_ops .edge .aten .add .Scalar ,
3233 torch .ops .aten .add .Tensor ,
34+ torch .ops .aten .add .Scalar ,
3335 ]
3436
3537
38+ def _is_dq (node ):
39+ return _is_call (node ) and node .target in dq_ops
40+
41+
3642def _is_mean (node ):
3743 return _is_call (node ) and node .target in [
3844 exir_ops .edge .aten .mean .dim ,
@@ -50,6 +56,7 @@ def _is_mul(node):
5056def _is_pow (node ):
5157 return _is_call (node ) and node .target in [
5258 exir_ops .edge .aten .pow .Tensor_Tensor ,
59+ exir_ops .edge .aten .pow .Tensor_Scalar ,
5360 torch .ops .aten .pow .Tensor_Scalar ,
5461 ]
5562
@@ -72,6 +79,7 @@ def __init__(self, quantization_capture=False):
7279 self .skip_targets = [
7380 exir_ops .edge .aten .to .dtype ,
7481 ]
82+ self .quantization_capture = quantization_capture
7583 if quantization_capture :
7684 self .rms_norm_target = torch .ops .aten .rms_norm .default
7785 self .skip_targets = [
@@ -112,8 +120,9 @@ def call(self, graph_module: torch.fx.GraphModule):
112120 gamma_node = None
113121 # weight should be a constant
114122 for arg in last_mul_node .args :
115- if _is_get_attr (arg ) or _is_placeholder (arg ):
123+ if _is_get_attr (arg ) or _is_placeholder (arg ) or _is_dq ( arg ) :
116124 gamma_node = arg
125+
117126 if gamma_node is None :
118127 continue
119128
0 commit comments