Skip to content

Commit 06aad56

Browse files
committed
Add lit test.
1 parent e9756c7 commit 06aad56

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,10 +1865,9 @@ class AvgPool3dSingleIntTupleParamsModule(torch.nn.Module):
18651865
def __init__(self):
18661866
super().__init__()
18671867
self.apd = torch.nn.AvgPool3d(
1868-
kernel_size=(6, 6),
1868+
kernel_size=(6, 6, 6),
18691869
stride=(2,),
1870-
padding=(1, 1),
1871-
count_include_pad=False,
1870+
padding=(1, 1, 1),
18721871
)
18731872

18741873
@export

test/Dialect/Torch/canonicalize.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3400,3 +3400,30 @@ func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1
34003400
torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
34013401
return %3 : !torch.vtensor<[?],f32>
34023402
}
3403+
3404+
// -----
3405+
// CHECK-LABEL: func.func @torch.aten.avg_pool2d.single_int_tuple(
3406+
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,20,20],f32>) -> !torch.vtensor<[2,4,9,9],f32> {
3407+
// CHECK: %[[NONE:.*]] = torch.constant.none
3408+
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
3409+
// CHECK: %[[C_6:.*]] = torch.constant.int 6
3410+
// CHECK: %[[C_1:.*]] = torch.constant.int 1
3411+
// CHECK: %[[C_2:.*]] = torch.constant.int 2
3412+
// CHECK: %[[KERNEL:.*]] = torch.prim.ListConstruct %[[C_6]], %[[C_6]] : (!torch.int, !torch.int) -> !torch.list<int>
3413+
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C_1]], %[[C_1]] : (!torch.int, !torch.int) -> !torch.list<int>
3414+
// CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[C_2]], %[[C_2]] : (!torch.int, !torch.int) -> !torch.list<int>
3415+
// CHECK: %[[POOL:.*]] = torch.aten.avg_pool2d %[[ARG0]], %[[KERNEL]], %[[PAD]], %[[STRIDE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[2,4,20,20],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,4,9,9],f32>
3416+
// CHECK: return %[[POOL]]
3417+
func.func @torch.aten.avg_pool2d.single_int_tuple(%arg0: !torch.vtensor<[2,4,20,20],f32>) -> !torch.vtensor<[2,4,9,9],f32> {
3418+
%int6 = torch.constant.int 6
3419+
%0 = torch.prim.ListConstruct %int6 : (!torch.int) -> !torch.list<int>
3420+
%int2 = torch.constant.int 2
3421+
%1 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
3422+
%int1 = torch.constant.int 1
3423+
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
3424+
%false = torch.constant.bool false
3425+
%false_0 = torch.constant.bool false
3426+
%none = torch.constant.none
3427+
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %false_0, %none : !torch.vtensor<[2,4,20,20],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,4,9,9],f32>
3428+
return %3 : !torch.vtensor<[2,4,9,9],f32>
3429+
}

0 commit comments

Comments
 (0)