@@ -3400,3 +3400,30 @@ func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1
3400
3400
torch.bind_symbolic_shape %3 , [%0 ], affine_map <()[s0 ] -> (s0 )> : !torch.vtensor <[?],f32 >
3401
3401
return %3 : !torch.vtensor <[?],f32 >
3402
3402
}
3403
+
3404
+ // -----
3405
+ // CHECK-LABEL: func.func @torch.aten.avg_pool2d.single_int_tuple(
3406
+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,20,20],f32>) -> !torch.vtensor<[2,4,9,9],f32> {
3407
+ // CHECK: %[[NONE:.*]] = torch.constant.none
3408
+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
3409
+ // CHECK: %[[C_6:.*]] = torch.constant.int 6
3410
+ // CHECK: %[[C_1:.*]] = torch.constant.int 1
3411
+ // CHECK: %[[C_2:.*]] = torch.constant.int 2
3412
+ // CHECK: %[[KERNEL:.*]] = torch.prim.ListConstruct %[[C_6]], %[[C_6]] : (!torch.int, !torch.int) -> !torch.list<int>
3413
+ // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C_1]], %[[C_1]] : (!torch.int, !torch.int) -> !torch.list<int>
3414
+ // CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[C_2]], %[[C_2]] : (!torch.int, !torch.int) -> !torch.list<int>
3415
+ // CHECK: %[[POOL:.*]] = torch.aten.avg_pool2d %[[ARG0]], %[[KERNEL]], %[[PAD]], %[[STRIDE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[2,4,20,20],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,4,9,9],f32>
3416
+ // CHECK: return %[[POOL]]
3417
+ func.func @torch.aten.avg_pool2d.single_int_tuple (%arg0: !torch.vtensor <[2 ,4 ,20 ,20 ],f32 >) -> !torch.vtensor <[2 ,4 ,9 ,9 ],f32 > {
3418
+ %int6 = torch.constant.int 6
3419
+ %0 = torch.prim.ListConstruct %int6 : (!torch.int ) -> !torch.list <int >
3420
+ %int2 = torch.constant.int 2
3421
+ %1 = torch.prim.ListConstruct %int2 : (!torch.int ) -> !torch.list <int >
3422
+ %int1 = torch.constant.int 1
3423
+ %2 = torch.prim.ListConstruct %int1 : (!torch.int ) -> !torch.list <int >
3424
+ %false = torch.constant.bool false
3425
+ %false_0 = torch.constant.bool false
3426
+ %none = torch.constant.none
3427
+ %3 = torch.aten.avg_pool2d %arg0 , %0 , %1 , %2 , %false , %false_0 , %none : !torch.vtensor <[2 ,4 ,20 ,20 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.bool , !torch.none -> !torch.vtensor <[2 ,4 ,9 ,9 ],f32 >
3428
+ return %3 : !torch.vtensor <[2 ,4 ,9 ,9 ],f32 >
3429
+ }
0 commit comments