@@ -114,6 +114,33 @@ def _adjust_attributes_of_avg_pool(
114114 return (kernel_shape, strides, pads)
115115
116116
117+ def _aten_avg_pool_onnx(
118+ self: TFloat,
119+ kernel_shape: Sequence[int],
120+ strides: Sequence[int],
121+ pads: Sequence[int],
122+ ceil_mode: bool,
123+ count_include_pad: bool,
124+ ) -> TFloat:
125+ self_rank_is_unbatched_rank = len(self.shape) == len(kernel_shape) + 1
126+ if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1
127+ self = op.Unsqueeze(self, [0])
128+
129+ result = op.AveragePool(
130+ self,
131+ ceil_mode=ceil_mode,
132+ count_include_pad=count_include_pad,
133+ kernel_shape=kernel_shape,
134+ pads=pads,
135+ strides=strides,
136+ )
137+
138+ if self_rank_is_unbatched_rank:
139+ result = op.Squeeze(result, [0])
140+
141+ return result
142+
143+
117144@torch_op("aten::avg_pool1d", trace_only=True)
118145def aten_avg_pool1d(
119146 self: TFloat,
@@ -134,16 +161,7 @@ def aten_avg_pool1d(
134161 expand_size, kernel_size, stride, padding
135162 )
136163
137- result = op.AveragePool(
138- self,
139- ceil_mode=ceil_mode,
140- count_include_pad=count_include_pad,
141- kernel_shape=kernel_shape,
142- pads=pads,
143- strides=strides,
144- )
145-
146- return result
164+ return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
147165
148166
149167@torch_op("aten::avg_pool2d", trace_only=True)
@@ -167,15 +185,6 @@ def aten_avg_pool2d(
167185 expand_size, kernel_size, stride, padding
168186 )
169187
170- result = op.AveragePool(
171- self,
172- ceil_mode=ceil_mode,
173- count_include_pad=count_include_pad,
174- kernel_shape=kernel_shape,
175- pads=pads,
176- strides=strides,
177- )
178-
179188 # TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
180189 # mask = [
181190 # 1, 2, 3, S,..3, 2, 1
@@ -189,7 +198,7 @@ def aten_avg_pool2d(
189198 # S is stride size, in this case S=4,
190199 # S may dup lot of times according to the image size
191200
192- return result
201+ return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
193202
194203
195204def aten_avg_pool2d_backward(
@@ -228,15 +237,6 @@ def aten_avg_pool3d(
228237 expand_size, kernel_size, stride, padding
229238 )
230239
231- result = op.AveragePool(
232- self,
233- kernel_shape=kernel_shape,
234- strides=strides,
235- pads=pads,
236- count_include_pad=count_include_pad,
237- ceil_mode=ceil_mode,
238- )
239-
240240 # TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
241241 # mask = [
242242 # 1, 2, 3, S,..3, 2, 1
@@ -250,7 +250,7 @@ def aten_avg_pool3d(
250250 # S is stride size, in this case S=4,
251251 # S may dup lot of times according to the image size
252252
253- return result
253+ return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
254254
255255
256256def aten_avg_pool3d_backward(
0 commit comments