44# LICENSE file in the root directory of this source tree.
55
66# pyre-unsafe
7- from typing import Any , List
7+ from typing import List
88
99import torch
1010
11+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
12+
1113from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
1214 get_input_qparams ,
1315 get_output_qparams ,
@@ -34,16 +36,14 @@ def __init__(self, *args):
3436 def _build_generic_avgpool2d (
3537 self ,
3638 node : torch .fx .Node ,
37- tosa_graph : Any ,
39+ tosa_graph : ts . TosaSerializer ,
3840 inputs : List [TosaArg ],
3941 output : TosaArg ,
4042 input_zp : int ,
4143 output_zp : int ,
42- accumulator_type : Any ,
44+ accumulator_type : ts . DType ,
4345 ) -> None :
4446
45- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
46-
4747 input_tensor = inputs [0 ]
4848 kernel_size_list = inputs [1 ].special
4949 stride_size_list = inputs [2 ].special
@@ -79,12 +79,10 @@ def _build_generic_avgpool2d(
7979 def define_node (
8080 self ,
8181 node : torch .fx .Node ,
82- tosa_graph : Any ,
82+ tosa_graph : ts . TosaSerializer ,
8383 inputs : List [TosaArg ],
8484 output : TosaArg ,
8585 ) -> None :
86- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
87-
8886 input_tensor = inputs [0 ]
8987 assert input_tensor .dtype == ts .DType .INT8
9088
@@ -112,135 +110,10 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
112110 def define_node (
113111 self ,
114112 node : torch .fx .Node ,
115- tosa_graph : Any ,
113+ tosa_graph : ts . TosaSerializer ,
116114 inputs : List [TosaArg ],
117115 output : TosaArg ,
118116 ) -> None :
119- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
120-
121- assert (
122- inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .FP32
123- ), "Only FP32 and INT8 supported"
124-
125- if inputs [0 ].dtype == ts .DType .INT8 :
126- super ().define_node (node , tosa_graph , inputs , output )
127-
128- if inputs [0 ].dtype == ts .DType .FP32 :
129- accumulator_type = ts .DType .FP32
130- # Initilize zero point to zero.
131- input_zp = 0
132- output_zp = 0
133-
134- self ._build_generic_avgpool2d (
135- node , tosa_graph , inputs , output , input_zp , output_zp , accumulator_type
136- )
137-
138-
139- @register_node_visitor
140- class AvgPool2dVisitor (NodeVisitor ):
141- target = "aten.avg_pool2d.default"
142-
143- tosa_specs = [
144- TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
145- ]
146-
147- def __init__ (self , * args ):
148- super ().__init__ (* args )
149-
150- def _build_generic_avgpool2d (
151- self ,
152- node : torch .fx .Node ,
153- tosa_graph : Any ,
154- inputs : List [TosaArg ],
155- output : TosaArg ,
156- input_zp : int ,
157- output_zp : int ,
158- accumulator_type : Any ,
159- ) -> None :
160-
161- import serializer .tosa_serializer as ts # type: ignore
162-
163- input_tensor = inputs [0 ]
164- kernel_size_list = inputs [1 ].special
165- stride_size_list = inputs [2 ].special
166-
167- try :
168- pad_size_list = inputs [3 ].special
169- pad_size_list = [
170- pad_size_list [0 ],
171- pad_size_list [0 ],
172- pad_size_list [1 ],
173- pad_size_list [1 ],
174- ]
175- except IndexError :
176- pad_size_list = [0 , 0 , 0 , 0 ]
177-
178- attr = ts .TosaSerializerAttribute ()
179- attr .AvgPool2dAttribute (
180- kernel = kernel_size_list ,
181- stride = stride_size_list ,
182- pad = pad_size_list ,
183- acc_type = accumulator_type ,
184- )
185- input_zp_tensor = tosa_graph .addConst (
186- shape = [1 ], dtype = output .dtype , vals = [input_zp ]
187- )
188- output_zp_tensor = tosa_graph .addConst (
189- shape = [1 ], dtype = output .dtype , vals = [output_zp ]
190- )
191-
192- tosa_graph .addOperator (
193- ts .TosaOp .Op ().AVG_POOL2D ,
194- [input_tensor .name , input_zp_tensor .name , output_zp_tensor .name ],
195- [output .name ],
196- attr ,
197- )
198-
199- def define_node (
200- self ,
201- node : torch .fx .Node ,
202- tosa_graph : Any ,
203- inputs : List [TosaArg ],
204- output : TosaArg ,
205- ) -> None :
206- import serializer .tosa_serializer as ts # type: ignore
207-
208- input_tensor = inputs [0 ]
209- assert input_tensor .dtype == ts .DType .INT8
210-
211- accumulator_type = ts .DType .INT32
212-
213- input_qargs = get_input_qparams (node )
214- input_zp = input_qargs [0 ].zp
215-
216- output_qargs = get_output_qparams (node )
217- output_zp = output_qargs [0 ].zp
218-
219- self ._build_generic_avgpool2d (
220- node , tosa_graph , inputs , output , input_zp , output_zp , accumulator_type
221- )
222-
223-
224- @register_node_visitor
225- class AvgPool2dVisitor_FP (AvgPool2dVisitor ):
226- target = "aten.avg_pool2d.default"
227-
228- tosa_specs = [
229- TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
230- ]
231-
232- def __init__ (self , * args ):
233- super ().__init__ (* args )
234-
235- def define_node (
236- self ,
237- node : torch .fx .Node ,
238- tosa_graph : Any ,
239- inputs : List [TosaArg ],
240- output : TosaArg ,
241- ) -> None :
242- import serializer .tosa_serializer as ts # type: ignore
243-
244117 assert (
245118 inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .FP32
246119 ), "Only FP32 and INT8 supported"
0 commit comments