Skip to content

Commit 1341a8e

Browse files
authored
Unsqueeze unbatched input of avg_pool
1 parent 8a94ad6 commit 1341a8e

File tree

1 file changed

+31
-30
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+31
-30
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,34 @@ def _adjust_attributes_of_avg_pool(
114114
return (kernel_shape, strides, pads)
115115

116116

117+
def _aten_avg_pool_onnx(
118+
self: TFloat,
119+
kernel_shape: Sequence[int],
120+
strides: Sequence[int],
121+
pads: Sequence[int],
122+
ceil_mode: bool,
123+
count_include_pad: bool,
124+
unbatched_rank: int,
125+
) -> TFloat:
126+
self_rank_is_unbatched_rank = len(self.shape) == unbatched_rank
127+
if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1
128+
self = op.Unsqueeze(self, [0])
129+
130+
result = op.AveragePool(
131+
self,
132+
ceil_mode=ceil_mode,
133+
count_include_pad=count_include_pad,
134+
kernel_shape=kernel_shape,
135+
pads=pads,
136+
strides=strides,
137+
)
138+
139+
if self_rank_is_unbatched_rank:
140+
result = op.Squeeze(result, [0])
141+
142+
return result
143+
144+
117145
@torch_op("aten::avg_pool1d", trace_only=True)
118146
def aten_avg_pool1d(
119147
self: TFloat,
@@ -134,16 +162,7 @@ def aten_avg_pool1d(
134162
expand_size, kernel_size, stride, padding
135163
)
136164

137-
result = op.AveragePool(
138-
self,
139-
ceil_mode=ceil_mode,
140-
count_include_pad=count_include_pad,
141-
kernel_shape=kernel_shape,
142-
pads=pads,
143-
strides=strides,
144-
)
145-
146-
return result
165+
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad, 2)
147166

148167

149168
@torch_op("aten::avg_pool2d", trace_only=True)
@@ -167,15 +186,6 @@ def aten_avg_pool2d(
167186
expand_size, kernel_size, stride, padding
168187
)
169188

170-
result = op.AveragePool(
171-
self,
172-
ceil_mode=ceil_mode,
173-
count_include_pad=count_include_pad,
174-
kernel_shape=kernel_shape,
175-
pads=pads,
176-
strides=strides,
177-
)
178-
179189
# TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
180190
# mask = [
181191
# 1, 2, 3, S,..3, 2, 1
@@ -189,7 +199,7 @@ def aten_avg_pool2d(
189199
# S is stride size, in this case S=4,
190200
# S may dup lot of times according to the image size
191201

192-
return result
202+
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad, 3)
193203

194204

195205
def aten_avg_pool2d_backward(
@@ -228,15 +238,6 @@ def aten_avg_pool3d(
228238
expand_size, kernel_size, stride, padding
229239
)
230240

231-
result = op.AveragePool(
232-
self,
233-
kernel_shape=kernel_shape,
234-
strides=strides,
235-
pads=pads,
236-
count_include_pad=count_include_pad,
237-
ceil_mode=ceil_mode,
238-
)
239-
240241
# TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
241242
# mask = [
242243
# 1, 2, 3, S,..3, 2, 1
@@ -250,7 +251,7 @@ def aten_avg_pool3d(
250251
# S is stride size, in this case S=4,
251252
# S may dup lot of times according to the image size
252253

253-
return result
254+
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad, 4)
254255

255256

256257
def aten_avg_pool3d_backward(

0 commit comments

Comments
 (0)