Skip to content

Commit 5e07a46

Browse files
committed
[mlir][Vector] Support 0-d vectors natively in VectorStoreToMemrefStoreLowering.
1 parent 7f9d348 commit 5e07a46

File tree

4 files changed

+26
-35
lines changed

4 files changed

+26
-35
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -530,15 +530,9 @@ struct VectorStoreToMemrefStoreLowering
530530
return rewriter.notifyMatchFailure(storeOp, "not single element vector");
531531

532532
Value extracted;
533-
if (vecType.getRank() == 0) {
534-
// TODO: Unifiy once ExtractOp supports 0-d vectors.
535-
extracted = rewriter.create<vector::ExtractElementOp>(
536-
storeOp.getLoc(), storeOp.getValueToStore());
537-
} else {
538-
SmallVector<int64_t> indices(vecType.getRank(), 0);
539-
extracted = rewriter.create<vector::ExtractOp>(
540-
storeOp.getLoc(), storeOp.getValueToStore(), indices);
541-
}
533+
SmallVector<int64_t> indices(vecType.getRank(), 0);
534+
extracted = rewriter.create<vector::ExtractOp>(
535+
storeOp.getLoc(), storeOp.getValueToStore(), indices);
542536

543537
rewriter.replaceOpWithNewOp<memref::StoreOp>(
544538
storeOp, extracted, storeOp.getBase(), storeOp.getIndices());

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2971,9 +2971,8 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in
29712971
// CHECK-LABEL: func @vector_store_op_0d
29722972
// CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
29732973
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
2974-
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
2975-
// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
2976-
// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
2974+
// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<1xf32> to f32
2975+
// CHECK: memref.store %[[cast2]], %{{.*}}[%{{.*}}, %{{.*}}]
29772976

29782977
// -----
29792978

mlir/test/Dialect/SPIRV/IR/availability.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func.func @module_physical_storage_buffer64_vulkan() {
5858
func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
5959
// CHECK: min version: v1.0
6060
// CHECK: max version: v1.6
61-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
61+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
6262
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
6363
%r = spirv.SDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
6464
return %r: i32
@@ -68,7 +68,7 @@ func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
6868
func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
6969
// CHECK: min version: v1.0
7070
// CHECK: max version: v1.6
71-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
71+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
7272
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
7373
%r = spirv.SDot %a, %a: vector<4xi8> -> i64
7474
return %r: i64
@@ -78,7 +78,7 @@ func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
7878
func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
7979
// CHECK: min version: v1.0
8080
// CHECK: max version: v1.6
81-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
81+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
8282
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
8383
%r = spirv.SDot %a, %a: vector<4xi16> -> i64
8484
return %r: i64
@@ -88,7 +88,7 @@ func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
8888
func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
8989
// CHECK: min version: v1.0
9090
// CHECK: max version: v1.6
91-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
91+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
9292
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
9393
%r = spirv.SUDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
9494
return %r: i32
@@ -98,7 +98,7 @@ func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
9898
func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
9999
// CHECK: min version: v1.0
100100
// CHECK: max version: v1.6
101-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
101+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
102102
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
103103
%r = spirv.SUDot %a, %a: vector<4xi8> -> i64
104104
return %r: i64
@@ -108,7 +108,7 @@ func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
108108
func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
109109
// CHECK: min version: v1.0
110110
// CHECK: max version: v1.6
111-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
111+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
112112
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
113113
%r = spirv.SUDot %a, %a: vector<4xi16> -> i64
114114
return %r: i64
@@ -118,7 +118,7 @@ func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
118118
func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
119119
// CHECK: min version: v1.0
120120
// CHECK: max version: v1.6
121-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
121+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
122122
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
123123
%r = spirv.UDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
124124
return %r: i32
@@ -128,7 +128,7 @@ func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
128128
func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
129129
// CHECK: min version: v1.0
130130
// CHECK: max version: v1.6
131-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
131+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
132132
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
133133
%r = spirv.UDot %a, %a: vector<4xi8> -> i64
134134
return %r: i64
@@ -138,7 +138,7 @@ func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
138138
func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
139139
// CHECK: min version: v1.0
140140
// CHECK: max version: v1.6
141-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
141+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
142142
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
143143
%r = spirv.UDot %a, %a: vector<4xi16> -> i64
144144
return %r: i64
@@ -148,7 +148,7 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
148148
func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
149149
// CHECK: min version: v1.0
150150
// CHECK: max version: v1.6
151-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
151+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
152152
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
153153
%r = spirv.SDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
154154
return %r: i32
@@ -158,7 +158,7 @@ func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
158158
func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
159159
// CHECK: min version: v1.0
160160
// CHECK: max version: v1.6
161-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
161+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
162162
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
163163
%r = spirv.SDotAccSat %a, %a, %acc: vector<4xi8> -> i64
164164
return %r: i64
@@ -168,7 +168,7 @@ func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
168168
func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
169169
// CHECK: min version: v1.0
170170
// CHECK: max version: v1.6
171-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
171+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
172172
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
173173
%r = spirv.SDotAccSat %a, %a, %acc: vector<4xi16> -> i64
174174
return %r: i64
@@ -178,7 +178,7 @@ func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
178178
func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
179179
// CHECK: min version: v1.0
180180
// CHECK: max version: v1.6
181-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
181+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
182182
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
183183
%r = spirv.SUDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
184184
return %r: i32
@@ -188,7 +188,7 @@ func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
188188
func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
189189
// CHECK: min version: v1.0
190190
// CHECK: max version: v1.6
191-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
191+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
192192
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
193193
%r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi8> -> i64
194194
return %r: i64
@@ -198,7 +198,7 @@ func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
198198
func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
199199
// CHECK: min version: v1.0
200200
// CHECK: max version: v1.6
201-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
201+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
202202
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
203203
%r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi16> -> i64
204204
return %r: i64
@@ -208,7 +208,7 @@ func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
208208
func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
209209
// CHECK: min version: v1.0
210210
// CHECK: max version: v1.6
211-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
211+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
212212
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
213213
%r = spirv.UDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
214214
return %r: i32
@@ -218,7 +218,7 @@ func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
218218
func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
219219
// CHECK: min version: v1.0
220220
// CHECK: max version: v1.6
221-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
221+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
222222
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
223223
%r = spirv.UDotAccSat %a, %a, %acc: vector<4xi8> -> i64
224224
return %r: i64
@@ -228,7 +228,7 @@ func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
228228
func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
229229
// CHECK: min version: v1.0
230230
// CHECK: max version: v1.6
231-
// CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
231+
// CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
232232
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
233233
%r = spirv.UDotAccSat %a, %a, %acc: vector<4xi16> -> i64
234234
return %r: i64

mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,13 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
77
%f0 = arith.constant 0.0 : f32
88

99
// CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
10-
// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
1110
%0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
1211

13-
// CHECK-NEXT: %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
14-
// CHECK-NEXT: memref.store %[[SS]], %[[MEM]][] : memref<f32>
12+
// CHECK-NEXT: memref.store %[[S]], %[[MEM]][] : memref<f32>
1513
vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
1614

17-
// CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
18-
// CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref<f32>
15+
// CHECK-NEXT: %[[V:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
16+
// CHECK-NEXT: memref.store %[[V]], %[[MEM]][] : memref<f32>
1917
vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
2018

2119
return

0 commit comments

Comments
 (0)