|
14 | 14 | register_node_visitor, |
15 | 15 | ) |
16 | 16 | from executorch.backends.arm.tosa_mapping import TosaArg |
17 | | -from executorch.backends.arm.tosa_quant_utils import ( |
18 | | - build_rescale, |
19 | | - search_quant_arg_downstream, |
20 | | - search_quant_arg_upstream, |
21 | | -) |
| 17 | +from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args |
22 | 18 |
|
23 | 19 | from executorch.backends.arm.tosa_utils import build_reshape |
| 20 | +from executorch.exir.dialects._ops import ops as exir_ops |
24 | 21 | from serializer.tosa_serializer import TosaOp |
25 | 22 |
|
26 | 23 |
|
@@ -70,7 +67,12 @@ def define_node( |
70 | 67 | input_zp = 0 |
71 | 68 | if is_quant_node: |
72 | 69 | input_node = node.all_input_nodes[1] |
73 | | - input_zp = search_quant_arg_upstream(input_node).zp |
| 70 | + # rank > 2 linear layer |
| 71 | + if input_node.target == exir_ops.edge.aten.view_copy.default: |
| 72 | + quant_node = input_node.all_input_nodes[0] |
| 73 | + else: |
| 74 | + quant_node = input_node |
| 75 | + input_zp = get_quant_node_args(quant_node).zp |
74 | 76 | attr.ConvAttribute( |
75 | 77 | pad=pad_attr, |
76 | 78 | stride=stride_attr, |
@@ -105,16 +107,24 @@ def define_node( |
105 | 107 | # Read inputs' parent nodes |
106 | 108 | _, input_node, weight_node = node.all_input_nodes |
107 | 109 |
|
108 | | - qargs = search_quant_arg_upstream(input_node) |
109 | | - input_scale = qargs.scale |
110 | | - consumer_node = list(node.users)[0] |
111 | | - quant_args = search_quant_arg_downstream(consumer_node) |
112 | | - |
113 | | - consumer_node_scale = quant_args.scale |
114 | | - consumer_node_node_zp = quant_args.zp |
| 110 | + # rank > 2 linear layer |
| 111 | + if input_node.target == exir_ops.edge.aten.view_copy.default: |
| 112 | + quant_node = input_node.all_input_nodes[0] |
| 113 | + input_scale = get_quant_node_args(quant_node).scale |
| 114 | + consumer_node = list(node.users)[0] |
| 115 | + consumer_consumer_node = list(consumer_node.users)[0] |
| 116 | + quant_args = get_quant_node_args(consumer_consumer_node) |
| 117 | + consumer_node_scale = quant_args.scale |
| 118 | + consumer_node_node_zp = quant_args.zp |
| 119 | + else: |
| 120 | + input_scale = get_quant_node_args(input_node).scale |
| 121 | + consumer_node = list(node.users)[0] |
| 122 | + quant_args = get_quant_node_args(consumer_node) |
| 123 | + consumer_node_scale = quant_args.scale |
| 124 | + consumer_node_node_zp = quant_args.zp |
115 | 125 |
|
116 | 126 | weight_node_q_node = weight_node.all_input_nodes[0] |
117 | | - weight_scale = search_quant_arg_upstream(weight_node_q_node).scale |
| 127 | + weight_scale = get_quant_node_args(weight_node_q_node).scale |
118 | 128 |
|
119 | 129 | output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale |
120 | 130 |
|
|
0 commit comments