@@ -730,6 +730,86 @@ func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch
730
730
return %0 : !torch.vtensor <[1 ,64 ,56 ,56 ],f32 >
731
731
}
732
732
733
+ // -----
734
+
735
+ // CHECK-LABEL: func.func @test_maxpool_2d_same_lower
736
+ func.func @test_maxpool_2d_same_lower (%arg0: !torch.vtensor <[1 ,3 ,32 ,32 ],f32 >) -> !torch.vtensor <[1 ,3 ,32 ,32 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 } {
737
+ // CHECK: %[[int1:.*]] = torch.constant.int 1
738
+ // CHECK: %[[int0:.*]] = torch.constant.int 0
739
+ // CHECK: %[[int1_0:.*]] = torch.constant.int 1
740
+ // CHECK: %[[int0_1:.*]] = torch.constant.int 0
741
+ // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int1]], %[[int0]], %[[int1_0]], %[[int0_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
742
+ // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308
743
+ // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[1,3,33,33],f32>
744
+ // CHECK: %[[int2:.*]] = torch.constant.int 2
745
+ // CHECK: %[[int2_2:.*]] = torch.constant.int 2
746
+ // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list<int>
747
+ // CHECK: %[[int0_3:.*]] = torch.constant.int 0
748
+ // CHECK: %[[int0_4:.*]] = torch.constant.int 0
749
+ // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list<int>
750
+ // CHECK: %[[int1_5:.*]] = torch.constant.int 1
751
+ // CHECK: %[[int1_6:.*]] = torch.constant.int 1
752
+ // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list<int>
753
+ // CHECK: %[[int1_7:.*]] = torch.constant.int 1
754
+ // CHECK: %[[int1_8:.*]] = torch.constant.int 1
755
+ // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list<int>
756
+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
757
+ // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,32,32],f32>
758
+ %0 = torch.operator " onnx.MaxPool" (%arg0 ) {torch.onnx.auto_pad = " SAME_LOWER" , torch.onnx.kernel_shape = [2 : si64 , 2 : si64 ]} : (!torch.vtensor <[1 ,3 ,32 ,32 ],f32 >) -> !torch.vtensor <[1 ,3 ,32 ,32 ],f32 >
759
+ return %0 : !torch.vtensor <[1 ,3 ,32 ,32 ],f32 >
760
+ }
761
+
762
+ // -----
763
+
764
+ // CHECK-LABEL: func.func @test_maxpool_2d_same_upper
765
+ func.func @test_maxpool_2d_same_upper (%arg0: !torch.vtensor <[1 ,3 ,32 ,32 ],f32 >) -> !torch.vtensor <[1 ,3 ,32 ,32 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
766
+ // CHECK: %[[int0:.*]] = torch.constant.int 0
767
+ // CHECK: %[[int1:.*]] = torch.constant.int 1
768
+ // CHECK: %[[int0_0:.*]] = torch.constant.int 0
769
+ // CHECK: %[[int1_1:.*]] = torch.constant.int 1
770
+ // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int0]], %[[int1]], %[[int0_0]], %[[int1_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
771
+ // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308
772
+ // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[1,3,33,33],f32>
773
+ // CHECK: %[[int2:.*]] = torch.constant.int 2
774
+ // CHECK: %[[int2_2:.*]] = torch.constant.int 2
775
+ // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list<int>
776
+ // CHECK: %[[int0_3:.*]] = torch.constant.int 0
777
+ // CHECK: %[[int0_4:.*]] = torch.constant.int 0
778
+ // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list<int>
779
+ // CHECK: %[[int1_5:.*]] = torch.constant.int 1
780
+ // CHECK: %[[int1_6:.*]] = torch.constant.int 1
781
+ // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list<int>
782
+ // CHECK: %[[int1_7:.*]] = torch.constant.int 1
783
+ // CHECK: %[[int1_8:.*]] = torch.constant.int 1
784
+ // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list<int>
785
+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
786
+ // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,32,32],f32>
787
+ %0 = torch.operator " onnx.MaxPool" (%arg0 ) {torch.onnx.auto_pad = " SAME_UPPER" , torch.onnx.kernel_shape = [2 : si64 , 2 : si64 ]} : (!torch.vtensor <[1 ,3 ,32 ,32 ],f32 >) -> !torch.vtensor <[1 ,3 ,32 ,32 ],f32 >
788
+ return %0 : !torch.vtensor <[1 ,3 ,32 ,32 ],f32 >
789
+ }
790
+
791
+ // -----
792
+
793
+ // CHECK-LABEL: func.func @test_maxpool_2d_precomputed_same_upper
794
+ func.func @test_maxpool_2d_precomputed_same_upper (%arg0: !torch.vtensor <[1 ,1 ,5 ,5 ],f32 >) -> !torch.vtensor <[1 ,1 ,3 ,3 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 }{
795
+ // CHECK: %[[int3:.*]] = torch.constant.int 3
796
+ // CHECK: %[[int3_0:.*]] = torch.constant.int 3
797
+ // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int3]], %[[int3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
798
+ // CHECK: %[[int1:.*]] = torch.constant.int 1
799
+ // CHECK: %[[int1_1:.*]] = torch.constant.int 1
800
+ // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int1]], %[[int1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
801
+ // CHECK: %[[int2:.*]] = torch.constant.int 2
802
+ // CHECK: %[[int2_2:.*]] = torch.constant.int 2
803
+ // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list<int>
804
+ // CHECK: %[[int1_3:.*]] = torch.constant.int 1
805
+ // CHECK: %[[int1_4:.*]] = torch.constant.int 1
806
+ // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_3]], %[[int1_4]] : (!torch.int, !torch.int) -> !torch.list<int>
807
+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
808
+ // CHECK: %[[FUNC4:.*]] = torch.aten.max_pool2d %arg0, %[[list0]], %[[list2]], %[[list1]], %[[list3]], %[[FALSE]] : !torch.vtensor<[1,1,5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,3,3],f32>
809
+ %0 = torch.operator " onnx.MaxPool" (%arg0 ) {torch.onnx.auto_pad = " SAME_UPPER" , torch.onnx.kernel_shape = [3 : si64 , 3 : si64 ], torch.onnx.strides = [2 : si64 , 2 : si64 ]} : (!torch.vtensor <[1 ,1 ,5 ,5 ],f32 >) -> !torch.vtensor <[1 ,1 ,3 ,3 ],f32 >
810
+ return %0 : !torch.vtensor <[1 ,1 ,3 ,3 ],f32 >
811
+ }
812
+
733
813
734
814
// -----
735
815
0 commit comments