Skip to content

Commit 2ec7678

Browse files
authored
Add test for avg_pool
1 parent 1341a8e commit 2ec7678

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,30 @@ def forward(self, x):
238238
)
239239
_testing.assert_onnx_program(onnx_program)
240240

241+
def test_avg_pool(self):
242+
class Model(torch.nn.Module):
243+
def forward(self, x2d, x3d, x4d, x5d):
244+
return (
245+
torch.nn.functional.avg_pool1d(x2d, 2),
246+
torch.nn.functional.avg_pool1d(x3d, 2),
247+
torch.nn.functional.avg_pool2d(x3d, 2),
248+
torch.nn.functional.avg_pool2d(x4d, 2),
249+
torch.nn.functional.avg_pool3d(x4d, 2),
250+
torch.nn.functional.avg_pool3d(x5d, 2),
251+
)
252+
253+
x2d = torch.randn(10, 10)
254+
x3d = torch.randn(10, 10, 10)
255+
x4d = torch.randn(10, 10, 10, 10)
256+
x5d = torch.randn(10, 10, 10, 10, 10)
257+
onnx_program = torch.onnx.export(
258+
Model(),
259+
(x2d, x3d, x4d, x5d),
260+
dynamo=True,
261+
verbose=False,
262+
)
263+
_testing.assert_onnx_program(onnx_program)
264+
241265

242266
if __name__ == "__main__":
243267
unittest.main()

0 commit comments

Comments
 (0)