|
16 | 16 | from executorch.backends.arm.tosa_mapping import TosaArg |
17 | 17 | from executorch.backends.arm.tosa_quant_utils import ( |
18 | 18 | build_rescale, |
19 | | - search_quant_arg_downstream, |
20 | | - search_quant_arg_upstream, |
| 19 | + get_quant_arg_downstream, |
| 20 | + get_quant_arg_upstream, |
21 | 21 | ) |
22 | 22 | from executorch.backends.arm.tosa_utils import ( |
23 | 23 | build_reshape, |
@@ -58,8 +58,8 @@ def define_node( |
58 | 58 | # For INT8, we need to get the zero point, otherwise it is 0 |
59 | 59 | input0_zp, input1_zp = 0, 0 |
60 | 60 | if is_quant_node: |
61 | | - input0_zp = search_quant_arg_upstream(input0).zp |
62 | | - input1_zp = search_quant_arg_upstream(input1).zp |
| 61 | + input0_zp = get_quant_arg_upstream(input0).zp |
| 62 | + input1_zp = get_quant_arg_upstream(input1).zp |
63 | 63 |
|
64 | 64 | mat_mul_result = tosa_graph.addIntermediate( |
65 | 65 | output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype |
@@ -90,9 +90,9 @@ def define_node( |
90 | 90 |
|
91 | 91 | # As INT8 accumulates into INT32, we need to rescale it back to INT8 |
92 | 92 | if is_quant_node: |
93 | | - input0_q_params = search_quant_arg_upstream(input0) |
94 | | - input1_q_params = search_quant_arg_upstream(input1) |
95 | | - output_q_params = search_quant_arg_downstream(list(node.users)[0]) |
| 93 | + input0_q_params = get_quant_arg_upstream(input0) |
| 94 | + input1_q_params = get_quant_arg_upstream(input1) |
| 95 | + output_q_params = get_quant_arg_downstream(list(node.users)[0]) |
96 | 96 |
|
97 | 97 | final_output_scale = ( |
98 | 98 | input0_q_params.scale * input1_q_params.scale |
|
0 commit comments