44# LICENSE file in the root directory of this source tree.
55
66# pyre-unsafe
7- from typing import List
7+ from typing import Any , List
88
99import torch
1010
11- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
12-
1311from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
1412 get_input_qparams ,
1513 get_output_qparams ,
@@ -36,14 +34,16 @@ def __init__(self, *args):
3634 def _build_generic_avgpool2d (
3735 self ,
3836 node : torch .fx .Node ,
39- tosa_graph : ts . TosaSerializer ,
37+ tosa_graph : Any ,
4038 inputs : List [TosaArg ],
4139 output : TosaArg ,
4240 input_zp : int ,
4341 output_zp : int ,
44- accumulator_type : ts . DType ,
42+ accumulator_type : Any ,
4543 ) -> None :
4644
45+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
46+
4747 input_tensor = inputs [0 ]
4848 kernel_size_list = inputs [1 ].special
4949 stride_size_list = inputs [2 ].special
@@ -79,10 +79,12 @@ def _build_generic_avgpool2d(
7979 def define_node (
8080 self ,
8181 node : torch .fx .Node ,
82- tosa_graph : ts . TosaSerializer ,
82+ tosa_graph : Any ,
8383 inputs : List [TosaArg ],
8484 output : TosaArg ,
8585 ) -> None :
86+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
87+
8688 input_tensor = inputs [0 ]
8789 assert input_tensor .dtype == ts .DType .INT8
8890
@@ -110,10 +112,135 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
110112 def define_node (
111113 self ,
112114 node : torch .fx .Node ,
113- tosa_graph : ts . TosaSerializer ,
115+ tosa_graph : Any ,
114116 inputs : List [TosaArg ],
115117 output : TosaArg ,
116118 ) -> None :
119+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
120+
121+ assert (
122+ inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .FP32
123+ ), "Only FP32 and INT8 supported"
124+
125+ if inputs [0 ].dtype == ts .DType .INT8 :
126+ super ().define_node (node , tosa_graph , inputs , output )
127+
128+ if inputs [0 ].dtype == ts .DType .FP32 :
129+ accumulator_type = ts .DType .FP32
130+ # Initilize zero point to zero.
131+ input_zp = 0
132+ output_zp = 0
133+
134+ self ._build_generic_avgpool2d (
135+ node , tosa_graph , inputs , output , input_zp , output_zp , accumulator_type
136+ )
137+
138+
139+ @register_node_visitor
140+ class AvgPool2dVisitor (NodeVisitor ):
141+ target = "aten.avg_pool2d.default"
142+
143+ tosa_specs = [
144+ TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
145+ ]
146+
147+ def __init__ (self , * args ):
148+ super ().__init__ (* args )
149+
150+ def _build_generic_avgpool2d (
151+ self ,
152+ node : torch .fx .Node ,
153+ tosa_graph : Any ,
154+ inputs : List [TosaArg ],
155+ output : TosaArg ,
156+ input_zp : int ,
157+ output_zp : int ,
158+ accumulator_type : Any ,
159+ ) -> None :
160+
161+ import serializer .tosa_serializer as ts # type: ignore
162+
163+ input_tensor = inputs [0 ]
164+ kernel_size_list = inputs [1 ].special
165+ stride_size_list = inputs [2 ].special
166+
167+ try :
168+ pad_size_list = inputs [3 ].special
169+ pad_size_list = [
170+ pad_size_list [0 ],
171+ pad_size_list [0 ],
172+ pad_size_list [1 ],
173+ pad_size_list [1 ],
174+ ]
175+ except IndexError :
176+ pad_size_list = [0 , 0 , 0 , 0 ]
177+
178+ attr = ts .TosaSerializerAttribute ()
179+ attr .AvgPool2dAttribute (
180+ kernel = kernel_size_list ,
181+ stride = stride_size_list ,
182+ pad = pad_size_list ,
183+ acc_type = accumulator_type ,
184+ )
185+ input_zp_tensor = tosa_graph .addConst (
186+ shape = [1 ], dtype = output .dtype , vals = [input_zp ]
187+ )
188+ output_zp_tensor = tosa_graph .addConst (
189+ shape = [1 ], dtype = output .dtype , vals = [output_zp ]
190+ )
191+
192+ tosa_graph .addOperator (
193+ ts .TosaOp .Op ().AVG_POOL2D ,
194+ [input_tensor .name , input_zp_tensor .name , output_zp_tensor .name ],
195+ [output .name ],
196+ attr ,
197+ )
198+
199+ def define_node (
200+ self ,
201+ node : torch .fx .Node ,
202+ tosa_graph : Any ,
203+ inputs : List [TosaArg ],
204+ output : TosaArg ,
205+ ) -> None :
206+ import serializer .tosa_serializer as ts # type: ignore
207+
208+ input_tensor = inputs [0 ]
209+ assert input_tensor .dtype == ts .DType .INT8
210+
211+ accumulator_type = ts .DType .INT32
212+
213+ input_qargs = get_input_qparams (node )
214+ input_zp = input_qargs [0 ].zp
215+
216+ output_qargs = get_output_qparams (node )
217+ output_zp = output_qargs [0 ].zp
218+
219+ self ._build_generic_avgpool2d (
220+ node , tosa_graph , inputs , output , input_zp , output_zp , accumulator_type
221+ )
222+
223+
224+ @register_node_visitor
225+ class AvgPool2dVisitor_FP (AvgPool2dVisitor ):
226+ target = "aten.avg_pool2d.default"
227+
228+ tosa_specs = [
229+ TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
230+ ]
231+
232+ def __init__ (self , * args ):
233+ super ().__init__ (* args )
234+
235+ def define_node (
236+ self ,
237+ node : torch .fx .Node ,
238+ tosa_graph : Any ,
239+ inputs : List [TosaArg ],
240+ output : TosaArg ,
241+ ) -> None :
242+ import serializer .tosa_serializer as ts # type: ignore
243+
117244 assert (
118245 inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .FP32
119246 ), "Only FP32 and INT8 supported"
0 commit comments