@@ -114,6 +114,34 @@ def _adjust_attributes_of_avg_pool(
114114 return (kernel_shape , strides , pads )
115115
116116
117+ def _aten_avg_pool_onnx (
118+ self : TFloat ,
119+ kernel_shape : Sequence [int ],
120+ strides : Sequence [int ],
121+ pads : Sequence [int ],
122+ ceil_mode : bool ,
123+ count_include_pad : bool ,
124+ unbatched_rank : int ,
125+ ) -> TFloat :
126+ self_rank_is_unbatched_rank = len (self .shape ) == unbatched_rank
127+ if self_rank_is_unbatched_rank : # C,H,W -> N,C,H,W and N=1
128+ self = op .Unsqueeze (self , [0 ])
129+
130+ result = op .AveragePool (
131+ self ,
132+ ceil_mode = ceil_mode ,
133+ count_include_pad = count_include_pad ,
134+ kernel_shape = kernel_shape ,
135+ pads = pads ,
136+ strides = strides ,
137+ )
138+
139+ if self_rank_is_unbatched_rank :
140+ result = op .Squeeze (result , [0 ])
141+
142+ return result
143+
144+
117145@torch_op ("aten::avg_pool1d" , trace_only = True )
118146def aten_avg_pool1d (
119147 self : TFloat ,
@@ -134,16 +162,7 @@ def aten_avg_pool1d(
134162 expand_size , kernel_size , stride , padding
135163 )
136164
137- result = op .AveragePool (
138- self ,
139- ceil_mode = ceil_mode ,
140- count_include_pad = count_include_pad ,
141- kernel_shape = kernel_shape ,
142- pads = pads ,
143- strides = strides ,
144- )
145-
146- return result
165+ return _aten_avg_pool_onnx (self , kernel_shape , strides , pads , ceil_mode , count_include_pad , 2 )
147166
148167
149168@torch_op ("aten::avg_pool2d" , trace_only = True )
@@ -167,15 +186,6 @@ def aten_avg_pool2d(
167186 expand_size , kernel_size , stride , padding
168187 )
169188
170- result = op .AveragePool (
171- self ,
172- ceil_mode = ceil_mode ,
173- count_include_pad = count_include_pad ,
174- kernel_shape = kernel_shape ,
175- pads = pads ,
176- strides = strides ,
177- )
178-
179189 # TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
180190 # mask = [
181191 # 1, 2, 3, S,..3, 2, 1
@@ -189,7 +199,7 @@ def aten_avg_pool2d(
189199 # S is stride size, in this case S=4,
190200 # S may dup lot of times according to the image size
191201
192- return result
202+ return _aten_avg_pool_onnx ( self , kernel_shape , strides , pads , ceil_mode , count_include_pad , 3 )
193203
194204
195205def aten_avg_pool2d_backward (
@@ -228,15 +238,6 @@ def aten_avg_pool3d(
228238 expand_size , kernel_size , stride , padding
229239 )
230240
231- result = op .AveragePool (
232- self ,
233- kernel_shape = kernel_shape ,
234- strides = strides ,
235- pads = pads ,
236- count_include_pad = count_include_pad ,
237- ceil_mode = ceil_mode ,
238- )
239-
240241 # TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
241242 # mask = [
242243 # 1, 2, 3, S,..3, 2, 1
@@ -250,7 +251,7 @@ def aten_avg_pool3d(
250251 # S is stride size, in this case S=4,
251252 # S may dup lot of times according to the image size
252253
253- return result
254+ return _aten_avg_pool_onnx ( self , kernel_shape , strides , pads , ceil_mode , count_include_pad , 4 )
254255
255256
256257def aten_avg_pool3d_backward (
0 commit comments