11// RUN: mlir-opt %s -test-compose-subview -split-input-file | FileCheck %s
22
33// CHECK-LABEL: func.func @subview_strided(
4- // CHECK-SAME: %[[VAL_0 :.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
4+ // CHECK-SAME: %[[input :.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
55func.func @subview_strided (%input: memref <4 x1024 xf32 >) -> memref <1 x128 xf32 , strided <[1024 , 1 ], offset : 3456 >> {
6- // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0 ]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
6+ // CHECK: {{.*}} = memref.subview %[[input ]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
77 %0 = memref.subview %input [2 , 256 ] [2 , 256 ] [1 , 1 ] : memref <4 x1024 xf32 > to memref <2 x256 xf32 , strided <[1024 , 1 ], offset : 2304 >>
88 %1 = memref.subview %0 [1 , 128 ] [1 , 128 ] [1 , 1 ] : memref <2 x256 xf32 , strided <[1024 , 1 ], offset : 2304 >> to memref <1 x128 xf32 , strided <[1024 , 1 ], offset : 3456 >>
99 return %1 : memref <1 x128 xf32 , strided <[1024 , 1 ], offset : 3456 >>
@@ -12,9 +12,9 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri
1212// -----
1313
1414// CHECK-LABEL: func.func @subview_strided(
15- // CHECK-SAME: %[[VAL_0 :.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
15+ // CHECK-SAME: %[[input :.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
1616func.func @subview_strided (%input: memref <4 x1024 xf32 >) -> memref <1 x10 xf32 , strided <[1024 , 1 ], offset : 3745 >> {
17- // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0 ]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
17+ // CHECK: {{.*}} = memref.subview %[[input ]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
1818 %0 = memref.subview %input [1 , 512 ] [3 , 256 ] [1 , 1 ] : memref <4 x1024 xf32 > to memref <3 x256 xf32 , strided <[1024 , 1 ], offset : 1536 >>
1919 %1 = memref.subview %0 [1 , 128 ] [2 , 128 ] [1 , 1 ] : memref <3 x256 xf32 , strided <[1024 , 1 ], offset : 1536 >> to memref <2 x128 xf32 , strided <[1024 , 1 ], offset : 2688 >>
2020 %2 = memref.subview %1 [1 , 33 ] [1 , 10 ] [1 , 1 ] : memref <2 x128 xf32 , strided <[1024 , 1 ], offset : 2688 >> to memref <1 x10 xf32 , strided <[1024 , 1 ], offset : 3745 >>
@@ -24,12 +24,12 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strid
2424// -----
2525
2626// CHECK-LABEL: func.func @subview_strided(
27- // CHECK-SAME: %[[VAL_0 :.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
27+ // CHECK-SAME: %[[input :.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
2828func.func @subview_strided (%input: memref <4 x1024 xf32 >) -> memref <1 x128 xf32 , strided <[1024 , 1 ], offset : ?>> {
29- // CHECK: %[[VAL_1 :.*]] = arith.constant 3 : index
29+ // CHECK: %[[C3 :.*]] = arith.constant 3 : index
3030 %cst_1 = arith.constant 1 : index
3131 %cst_2 = arith.constant 2 : index
32- // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0 ]]{{\[}}%[[VAL_1 ]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
32+ // CHECK: {{.*}} = memref.subview %[[input ]]{{\[}}%[[C3 ]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
3333 %0 = memref.subview %input [%cst_2 , 256 ] [2 , 256 ] [1 , 1 ] : memref <4 x1024 xf32 > to memref <2 x256 xf32 , strided <[1024 , 1 ], offset : ?>>
3434 %1 = memref.subview %0 [%cst_1 , 128 ] [1 , 128 ] [1 , 1 ] : memref <2 x256 xf32 , strided <[1024 , 1 ], offset : ?>> to memref <1 x128 xf32 , strided <[1024 , 1 ], offset : ?>>
3535 return %1 : memref <1 x128 xf32 , strided <[1024 , 1 ], offset : ?>>
@@ -38,13 +38,13 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri
3838// -----
3939
4040// CHECK-LABEL: func.func @subview_strided(
41- // CHECK-SAME: %[[VAL_0 :.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
41+ // CHECK-SAME: %[[input :.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
4242func.func @subview_strided (%input: memref <4 x1024 xf32 >) -> memref <1 x128 xf32 , strided <[1024 , 1 ], offset : ?>> {
43- // CHECK: %[[VAL_1 :.*]] = arith.constant 3 : index
43+ // CHECK: %[[C3 :.*]] = arith.constant 3 : index
4444 %cst_2 = arith.constant 2 : index
45- // CHECK: %[[VAL_2 :.*]] = arith.constant 384 : index
45+ // CHECK: %[[C384 :.*]] = arith.constant 384 : index
4646 %cst_128 = arith.constant 128 : index
47- // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0 ]]{{\[}}%[[VAL_1 ]], %[[VAL_2 ]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
47+ // CHECK: {{.*}} = memref.subview %[[input ]]{{\[}}%[[C3 ]], %[[C384 ]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
4848 %0 = memref.subview %input [%cst_2 , 256 ] [2 , 256 ] [1 , 1 ] : memref <4 x1024 xf32 > to memref <2 x256 xf32 , strided <[1024 , 1 ], offset : ?>>
4949 %1 = memref.subview %0 [1 , %cst_128 ] [1 , 128 ] [1 , 1 ] : memref <2 x256 xf32 , strided <[1024 , 1 ], offset : ?>> to memref <1 x128 xf32 , strided <[1024 , 1 ], offset : ?>>
5050 return %1 : memref <1 x128 xf32 , strided <[1024 , 1 ], offset : ?>>
@@ -53,9 +53,9 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri
5353// -----
5454
5555// CHECK-LABEL: func.func @subview_strided(
56- // CHECK-SAME: %[[VAL_0 :.*]]: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
56+ // CHECK-SAME: %[[input :.*]]: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
5757func.func @subview_strided (%input: memref <8 x1024 xf32 >) -> memref <1 x64 xf32 , strided <[4096 , 4 ], offset : 4480 >> {
58- // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0 ]][4, 384] [1, 64] [4, 4] : memref<8x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
58+ // CHECK: {{.*}} = memref.subview %[[input ]][4, 384] [1, 64] [4, 4] : memref<8x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
5959 %0 = memref.subview %input [2 , 256 ] [2 , 256 ] [2 , 2 ] : memref <8 x1024 xf32 > to memref <2 x256 xf32 , strided <[2048 , 2 ], offset : 2304 >>
6060 %1 = memref.subview %0 [1 , 64 ] [1 , 64 ] [2 , 2 ] : memref <2 x256 xf32 , strided <[2048 , 2 ], offset : 2304 >> to memref <1 x64 xf32 , strided <[4096 , 4 ], offset : 4480 >>
6161 return %1 : memref <1 x64 xf32 , strided <[4096 , 4 ], offset : 4480 >>
@@ -64,9 +64,9 @@ func.func @subview_strided(%input: memref<8x1024xf32>) -> memref<1x64xf32, strid
6464// -----
6565
6666// CHECK-LABEL: func.func @subview_strided(
67- // CHECK-SAME: %[[VAL_0 :.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
67+ // CHECK-SAME: %[[input :.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
6868func.func @subview_strided (%input: memref <30 x30 xf32 >) -> memref <2 x2 xf32 , strided <[240 , 8 ], offset : 217 >> {
69- // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0 ]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
69+ // CHECK: {{.*}} = memref.subview %[[input ]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
7070 %0 = memref.subview %input [1 , 1 ] [12 , 12 ] [2 , 2 ] : memref <30 x30 xf32 > to memref <12 x12 xf32 , strided <[60 , 2 ], offset : 31 >>
7171 %1 = memref.subview %0 [1 , 1 ] [5 , 5 ] [2 , 2 ] : memref <12 x12 xf32 , strided <[60 , 2 ], offset : 31 >> to memref <5 x5 xf32 , strided <[120 , 4 ], offset : 93 >>
7272 %2 = memref.subview %1 [1 , 1 ] [2 , 2 ] [2 , 2 ] : memref <5 x5 xf32 , strided <[120 , 4 ], offset : 93 >> to memref <2 x2 xf32 , strided <[240 , 8 ], offset : 217 >>
@@ -76,13 +76,13 @@ func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided
7676// -----
7777
7878// CHECK-LABEL: func.func @subview_strided(
79- // CHECK-SAME: %[[VAL_0 :.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
79+ // CHECK-SAME: %[[input :.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
8080func.func @subview_strided (%input: memref <4 x1024 xf32 >) -> memref <1 x64 xf32 , strided <[4096 , 4 ], offset : ?>> {
81- // CHECK: %[[VAL_1 :.*]] = arith.constant 4 : index
81+ // CHECK: %[[C4 :.*]] = arith.constant 4 : index
8282 %cst_2 = arith.constant 2 : index
83- // CHECK: %[[VAL_2 :.*]] = arith.constant 384 : index
83+ // CHECK: %[[C384 :.*]] = arith.constant 384 : index
8484 %cst_64 = arith.constant 64 : index
85- // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0 ]]{{\[}}%[[VAL_1 ]], %[[VAL_2 ]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
85+ // CHECK: {{.*}} = memref.subview %[[input ]]{{\[}}%[[C4 ]], %[[C384 ]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
8686 %0 = memref.subview %input [%cst_2 , 256 ] [2 , 256 ] [2 , 2 ] : memref <4 x1024 xf32 > to memref <2 x256 xf32 , strided <[2048 , 2 ], offset : ?>>
8787 %1 = memref.subview %0 [1 , %cst_64 ] [1 , 64 ] [2 , 2 ] : memref <2 x256 xf32 , strided <[2048 , 2 ], offset : ?>> to memref <1 x64 xf32 , strided <[4096 , 4 ], offset : ?>>
8888 return %1 : memref <1 x64 xf32 , strided <[4096 , 4 ], offset : ?>>
@@ -91,13 +91,39 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strid
9191// -----
9292
9393// CHECK-LABEL: func.func @subview_strided(
94- // CHECK-SAME: %[[VAL_0 :.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
94+ // CHECK-SAME: %[[input :.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
9595func.func @subview_strided (%input: memref <4 x1024 xf32 >) -> memref <1 x64 xf32 , strided <[4096 , 4 ], offset : ?>> {
96- // CHECK: %[[VAL_1 :.*]] = arith.constant 4 : index
96+ // CHECK: %[[C4 :.*]] = arith.constant 4 : index
9797 %cst_1 = arith.constant 1 : index
9898 %cst_2 = arith.constant 2 : index
99- // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0 ]]{{\[}}%[[VAL_1 ]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
99+ // CHECK: {{.*}} = memref.subview %[[input ]]{{\[}}%[[C4 ]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
100100 %0 = memref.subview %input [%cst_2 , 256 ] [2 , 256 ] [2 , 2 ] : memref <4 x1024 xf32 > to memref <2 x256 xf32 , strided <[2048 , 2 ], offset : ?>>
101101 %1 = memref.subview %0 [%cst_1 , 64 ] [1 , 64 ] [2 , 2 ] : memref <2 x256 xf32 , strided <[2048 , 2 ], offset : ?>> to memref <1 x64 xf32 , strided <[4096 , 4 ], offset : ?>>
102102 return %1 : memref <1 x64 xf32 , strided <[4096 , 4 ], offset : ?>>
103103}
104+
105+ // -----
106+
107+ // CHECK-LABEL: func.func @single_dynamic_size_subview(
108+ // CHECK-SAME: %[[input:.*]]: memref<256x?xf32>,
109+ // CHECK-SAME: %{{.*}}: index,
110+ // CHECK-SAME: %[[SIZE_1:.*]]: index) -> memref<8x?xf32> {
111+ func.func @single_dynamic_size_subview (%input: memref <256 x?xf32 >, %size0 : index , %size1 : index ) -> memref <8 x?xf32 >{
112+ %subview = memref.subview %input [0 , 0 ][8 , %size0 ][1 , 1 ] : memref <256 x?xf32 > to memref <8 x?xf32 >
113+ %subview_1 = memref.subview %subview [0 , 0 ][8 , %size1 ][1 , 1 ] : memref <8 x?xf32 > to memref <8 x?xf32 >
114+ // CHECK: %{{.*}} = memref.subview %[[input]][0, 0] [8, %[[SIZE_1]]] [1, 1] : memref<256x?xf32> to memref<8x?xf32>
115+ return %subview_1 : memref <8 x?xf32 >
116+ }
117+
118+ // -----
119+
120+ // CHECK-LABEL: func.func @all_dynamic_size_subview(
121+ // CHECK-SAME: %[[input:.*]]: memref<256x?xf32>,
122+ // CHECK-SAME: %{{.*}}: index,
123+ // CHECK-SAME: %[[SIZE1:.*]]: index) -> memref<?x?xf32> {
124+ func.func @all_dynamic_size_subview (%input: memref <256 x?xf32 >, %size0 : index , %size1 : index ) -> memref <?x?xf32 >{
125+ %subview = memref.subview %input [0 , 0 ][%size0 , %size0 ][1 , 1 ] : memref <256 x?xf32 > to memref <?x?xf32 >
126+ %subview_1 = memref.subview %subview [0 , 0 ][%size1 , %size1 ][1 , 1 ] : memref <?x?xf32 > to memref <?x?xf32 >
127+ // CHECK: {{.*}} = memref.subview %[[input]][0, 0] {{\[}}%[[SIZE1]], %[[SIZE1]]] [1, 1] : memref<256x?xf32> to memref<?x?xf32>
128+ return %subview_1 : memref <?x?xf32 >
129+ }
0 commit comments