@@ -730,6 +730,86 @@ func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch
730730 return %0 : !torch.vtensor <[1 ,64 ,56 ,56 ],f32 >
731731}
732732
// -----

// CHECK-LABEL: func.func @test_maxpool_2d_same_lower
// SAME_LOWER auto_pad with an even kernel pads asymmetrically (extra pad at the
// start of each spatial dim), so the lowering materializes an explicit
// constant_pad_nd with -inf (float lowest) before the unpadded max_pool2d.
func.func @test_maxpool_2d_same_lower(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
  // CHECK: %[[int1:.*]] = torch.constant.int 1
  // CHECK: %[[int0:.*]] = torch.constant.int 0
  // CHECK: %[[int1_0:.*]] = torch.constant.int 1
  // CHECK: %[[int0_1:.*]] = torch.constant.int 0
  // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int1]], %[[int0]], %[[int1_0]], %[[int0_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308
  // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[1,3,33,33],f32>
  // CHECK: %[[int2:.*]] = torch.constant.int 2
  // CHECK: %[[int2_2:.*]] = torch.constant.int 2
  // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int0_3:.*]] = torch.constant.int 0
  // CHECK: %[[int0_4:.*]] = torch.constant.int 0
  // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int1_5:.*]] = torch.constant.int 1
  // CHECK: %[[int1_6:.*]] = torch.constant.int 1
  // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int1_7:.*]] = torch.constant.int 1
  // CHECK: %[[int1_8:.*]] = torch.constant.int 1
  // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
  // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,32,32],f32>
  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_LOWER", torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32>
  return %0 : !torch.vtensor<[1,3,32,32],f32>
}
761+
// -----

// CHECK-LABEL: func.func @test_maxpool_2d_same_upper
// SAME_UPPER auto_pad with an even kernel pads asymmetrically (extra pad at the
// end of each spatial dim), so the lowering materializes an explicit
// constant_pad_nd with -inf (float lowest) before the unpadded max_pool2d.
func.func @test_maxpool_2d_same_upper(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[int0:.*]] = torch.constant.int 0
  // CHECK: %[[int1:.*]] = torch.constant.int 1
  // CHECK: %[[int0_0:.*]] = torch.constant.int 0
  // CHECK: %[[int1_1:.*]] = torch.constant.int 1
  // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int0]], %[[int1]], %[[int0_0]], %[[int1_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308
  // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[1,3,33,33],f32>
  // CHECK: %[[int2:.*]] = torch.constant.int 2
  // CHECK: %[[int2_2:.*]] = torch.constant.int 2
  // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int0_3:.*]] = torch.constant.int 0
  // CHECK: %[[int0_4:.*]] = torch.constant.int 0
  // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int1_5:.*]] = torch.constant.int 1
  // CHECK: %[[int1_6:.*]] = torch.constant.int 1
  // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int1_7:.*]] = torch.constant.int 1
  // CHECK: %[[int1_8:.*]] = torch.constant.int 1
  // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
  // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,32,32],f32>
  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32>
  return %0 : !torch.vtensor<[1,3,32,32],f32>
}
790+
// -----

// CHECK-LABEL: func.func @test_maxpool_2d_precomputed_same_upper
// With a 3x3 kernel, stride 2, and 5x5 input, SAME_UPPER padding is symmetric
// (1 on each side), so no explicit constant_pad_nd is needed: the padding is
// passed directly to max_pool2d.
func.func @test_maxpool_2d_precomputed_same_upper(%arg0: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
  // CHECK: %[[int3:.*]] = torch.constant.int 3
  // CHECK: %[[int3_0:.*]] = torch.constant.int 3
  // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int3]], %[[int3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int1:.*]] = torch.constant.int 1
  // CHECK: %[[int1_1:.*]] = torch.constant.int 1
  // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int1]], %[[int1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int2:.*]] = torch.constant.int 2
  // CHECK: %[[int2_2:.*]] = torch.constant.int 2
  // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int1_3:.*]] = torch.constant.int 1
  // CHECK: %[[int1_4:.*]] = torch.constant.int 1
  // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_3]], %[[int1_4]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
  // CHECK: %[[FUNC4:.*]] = torch.aten.max_pool2d %arg0, %[[list0]], %[[list2]], %[[list1]], %[[list3]], %[[FALSE]] : !torch.vtensor<[1,1,5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,3,3],f32>
  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,3,3],f32>
  return %0 : !torch.vtensor<[1,1,3,3],f32>
}
812+
733813
734814// -----
735815
0 commit comments