-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import serializer.tosa_serializer as ts
 import torch
 import torch.fx
-
-# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
-from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
-    get_input_qparams,
-)
 from executorch.backends.arm.operators.node_visitor import NodeVisitor
 from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
 from executorch.backends.arm.tosa_quant_utils import (
     is_node_quantized,
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
-from executorch.backends.arm.tosa_utils import (
-    getNodeArgs,
-    is_bias_node_for_quantized_conv,
-    tosa_shape,
-)
+from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
 from torch.export.exported_program import ExportedProgram

@@ -99,41 +90,6 @@ def process_inputs(
     tosa_graph.addInputTensor(tensor)

-def process_quantized_bias(
-    node: torch.fx.Node,
-    tosa_graph: ts.TosaSerializer,
-    parameter_values,
-):
-    """
-    Serialize bias node that needs to be quantized.
-    """
-    consumer_node = list(node.users)[0]
-    (
-        input_node,
-        weight_node,
-        _,
-    ) = consumer_node.all_input_nodes
-
-    input_qargs = get_input_qparams(  # pyre-ignore[16]: Module `executorch.backends.arm` has no attribute `_passes`.
-        consumer_node
-    )
-
-    input_node_scale = input_qargs[0].scale
-    weight_node_scale = input_qargs[1].scale
-    bias_values_quantized = (
-        (parameter_values / (input_node_scale * weight_node_scale))
-        .round()
-        .astype(np.int32)
-    )
-
-    tosa_graph.addConst(
-        bias_values_quantized.shape,
-        ts.DType.INT32,
-        bias_values_quantized,
-        name=node.name,
-    )
-
-
 def process_inputs_to_parameters(
     node: torch.fx.Node,
     tosa_graph: ts.TosaSerializer,
@@ -148,20 +104,14 @@ def process_inputs_to_parameters(
     assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
     parameter_values = parameter_data.detach().numpy()

-    if is_bias_node_for_quantized_conv(node):
-        # BI bias
-        assert tosa_spec.support_integer(), f"{tosa_spec} doesnt't support integer"
-        process_quantized_bias(node, tosa_graph, parameter_values)
-    else:
-        # MI weights or bias
-        if inputs[0].dtype == torch.float32:
-            assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
+    if inputs[0].dtype == torch.float32:
+        assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

-        parameter_values = np.transpose(parameter_values, inputs[0].dim_order)
+    parameter_values = np.transpose(parameter_values, inputs[0].dim_order)

-        tosa_graph.addConst(
-            parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
-        )
+    tosa_graph.addConst(
+        parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
+    )


 def process_inputs_to_buffers(
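
For reference, the deleted process_quantized_bias helper folded the convolution's input and weight quantization scales into the bias: it computed round(bias / (input_scale * weight_scale)) and emitted the result as an INT32 constant in the TOSA graph. Below is a minimal standalone sketch of that arithmetic; the function name and signature are illustrative only and not part of the ExecuTorch API.

import numpy as np

def requantize_bias(bias: np.ndarray, input_scale: float, weight_scale: float) -> np.ndarray:
    # Same arithmetic as the removed helper: the bias is re-expressed in units
    # of input_scale * weight_scale (the effective scale of an int8 convolution
    # accumulator) and rounded to int32.
    return np.round(bias / (input_scale * weight_scale)).astype(np.int32)

# Example: a bias value of 0.05 with input_scale=0.02 and weight_scale=0.01
# quantizes to round(0.05 / 0.0002) = 250.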