@@ -103,7 +103,7 @@ func.func @test_pass_dims_through_concat(%arg0: tensor<?x256xi64>) -> (tensor<4x
103103
104104// -----
105105
106- func.func @test_pass_dims_through_cast_2 (%arg0: tensor <?x?x200 xf32 >) -> tensor <2 xi64 > {
106+ func.func @test_pass_dims_through_gather (%arg0: tensor <?x?x200 xf32 >) -> tensor <2 xi64 > {
107107 %0 = onnx.Constant dense <[0 , 1 ]> : tensor <2 xi64 >
108108 %1 = " onnx.Dim" (%arg0 ) {axis = 0 : si64 } : (tensor <?x?x200 xf32 >) -> tensor <1 xi64 >
109109 %2 = " onnx.Dim" (%arg0 ) {axis = 1 : si64 } : (tensor <?x?x200 xf32 >) -> tensor <1 xi64 >
@@ -113,7 +113,28 @@ func.func @test_pass_dims_through_cast_2(%arg0: tensor<?x?x200xf32>) -> tensor<2
113113 onnx.Return %5 : tensor <2 xi64 >
114114
115115// mlir2FileCheck.py
116- // CHECK-LABEL: func.func @test_pass_dims_through_cast_2
116+ // CHECK-LABEL: func.func @test_pass_dims_through_gather
117+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x200xf32>) -> tensor<2xi64> {
118+ // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
119+ // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
120+ // CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
121+ // CHECK: onnx.Return [[VAR_2_]] : tensor<2xi64>
122+ // CHECK: }
123+ }
124+
125+ // -----
126+
127+ func.func @test_pass_dims_through_gather_2 (%arg0: tensor <?x?x200 xf32 >) -> tensor <2 xi64 > {
128+ %0 = onnx.Constant dense <[-3 , -2 ]> : tensor <2 xi64 >
129+ %1 = " onnx.Dim" (%arg0 ) {axis = 0 : si64 } : (tensor <?x?x200 xf32 >) -> tensor <1 xi64 >
130+ %2 = " onnx.Dim" (%arg0 ) {axis = 1 : si64 } : (tensor <?x?x200 xf32 >) -> tensor <1 xi64 >
131+ %3 = onnx.Constant dense <200 > : tensor <1 xi64 >
132+ %4 = " onnx.Concat" (%1 , %2 , %3 ) {axis = 0 : si64 } : (tensor <1 xi64 >, tensor <1 xi64 >, tensor <1 xi64 >) -> tensor <3 xi64 >
133+ %5 = " onnx.Gather" (%4 , %0 ) {axis = 0 : si64 } : (tensor <3 xi64 >, tensor <2 xi64 >) -> tensor <2 xi64 >
134+ onnx.Return %5 : tensor <2 xi64 >
135+
136+ // mlir2FileCheck.py
137+ // CHECK-LABEL: func.func @test_pass_dims_through_gather_2
117138// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x200xf32>) -> tensor<2xi64> {
118139// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
119140// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
0 commit comments