Skip to content

Commit 94ad9d2

Browse files
committed
Add lit test.
1 parent 61e4dd8 commit 94ad9d2

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,10 +1865,9 @@ class AvgPool3dSingleIntTupleParamsModule(torch.nn.Module):
18651865
def __init__(self):
18661866
super().__init__()
18671867
self.apd = torch.nn.AvgPool3d(
1868-
kernel_size=(6, 6),
1868+
kernel_size=(6, 6, 6),
18691869
stride=(2,),
1870-
padding=(1, 1),
1871-
count_include_pad=False,
1870+
padding=(1, 1, 1),
18721871
)
18731872

18741873
@export

test/Dialect/Torch/canonicalize.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3389,3 +3389,30 @@ func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1
33893389
torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
33903390
return %3 : !torch.vtensor<[?],f32>
33913391
}
3392+
3393+
// -----
3394+
// CHECK-LABEL: func.func @torch.aten.avg_pool2d.single_int_tuple(
3395+
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,20,20],f32>) -> !torch.vtensor<[2,4,9,9],f32> {
3396+
// CHECK: %[[NONE:.*]] = torch.constant.none
3397+
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
3398+
// CHECK: %[[C_6:.*]] = torch.constant.int 6
3399+
// CHECK: %[[C_1:.*]] = torch.constant.int 1
3400+
// CHECK: %[[C_2:.*]] = torch.constant.int 2
3401+
// CHECK: %[[KERNEL:.*]] = torch.prim.ListConstruct %[[C_6]], %[[C_6]] : (!torch.int, !torch.int) -> !torch.list<int>
3402+
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C_1]], %[[C_1]] : (!torch.int, !torch.int) -> !torch.list<int>
3403+
// CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[C_2]], %[[C_2]] : (!torch.int, !torch.int) -> !torch.list<int>
3404+
// CHECK: %[[POOL:.*]] = torch.aten.avg_pool2d %[[ARG0]], %[[KERNEL]], %[[PAD]], %[[STRIDE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[2,4,20,20],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,4,9,9],f32>
3405+
// CHECK: return %[[POOL]]
3406+
func.func @torch.aten.avg_pool2d.single_int_tuple(%arg0: !torch.vtensor<[2,4,20,20],f32>) -> !torch.vtensor<[2,4,9,9],f32> {
3407+
%int6 = torch.constant.int 6
3408+
%0 = torch.prim.ListConstruct %int6 : (!torch.int) -> !torch.list<int>
3409+
%int2 = torch.constant.int 2
3410+
%1 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
3411+
%int1 = torch.constant.int 1
3412+
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
3413+
%false = torch.constant.bool false
3414+
%false_0 = torch.constant.bool false
3415+
%none = torch.constant.none
3416+
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %false_0, %none : !torch.vtensor<[2,4,20,20],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,4,9,9],f32>
3417+
return %3 : !torch.vtensor<[2,4,9,9],f32>
3418+
}

0 commit comments

Comments
 (0)