We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1341a8e commit 2ec7678Copy full SHA for 2ec7678
tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -238,6 +238,30 @@ def forward(self, x):
238
)
239
_testing.assert_onnx_program(onnx_program)
240
241
+ def test_avg_pool(self):
242
+ class Model(torch.nn.Module):
243
+ def forward(self, x2d, x3d, x4d, x5d):
244
+ return (
245
+ torch.nn.functional.avg_pool1d(x2d, 2),
246
+ torch.nn.functional.avg_pool1d(x3d, 2),
247
+ torch.nn.functional.avg_pool2d(x3d, 2),
248
+ torch.nn.functional.avg_pool2d(x4d, 2),
249
+ torch.nn.functional.avg_pool3d(x4d, 2),
250
+ torch.nn.functional.avg_pool3d(x5d, 2),
251
+ )
252
+
253
+ x2d = torch.randn(10, 10)
254
+ x3d = torch.randn(10, 10, 10)
255
+ x4d = torch.randn(10, 10, 10, 10)
256
+ x5d = torch.randn(10, 10, 10, 10, 10)
257
+ onnx_program = torch.onnx.export(
258
+ Model(),
259
+ (x2d, x3d, x4d, x5d),
260
+ dynamo=True,
261
+ verbose=False,
262
263
+ _testing.assert_onnx_program(onnx_program)
264
265
266
if __name__ == "__main__":
267
unittest.main()
0 commit comments