[mlir][Vector] Remove 0-d corner case condition. #112937
Conversation
Force-pushed from 0b96a31 to 5e07a46
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Harrison Hao (harrisonGPU)

Changes: Support 0-d vectors natively in VectorStoreToMemrefStoreLowering.

Closes #112913

Full diff: https://github.com/llvm/llvm-project/pull/112937.diff

4 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index f9428a4ce28640..1cb3baaef82baf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -530,15 +530,9 @@ struct VectorStoreToMemrefStoreLowering
return rewriter.notifyMatchFailure(storeOp, "not single element vector");
Value extracted;
- if (vecType.getRank() == 0) {
- // TODO: Unifiy once ExtractOp supports 0-d vectors.
- extracted = rewriter.create<vector::ExtractElementOp>(
- storeOp.getLoc(), storeOp.getValueToStore());
- } else {
- SmallVector<int64_t> indices(vecType.getRank(), 0);
- extracted = rewriter.create<vector::ExtractOp>(
- storeOp.getLoc(), storeOp.getValueToStore(), indices);
- }
+ SmallVector<int64_t> indices(vecType.getRank(), 0);
+ extracted = rewriter.create<vector::ExtractOp>(
+ storeOp.getLoc(), storeOp.getValueToStore(), indices);
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c1de24fd0403ce..abbdbe02ce6c1e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2971,9 +2971,8 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in
// CHECK-LABEL: func @vector_store_op_0d
// CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
-// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
-// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<1xf32> to f32
+// CHECK: memref.store %[[cast2]], %{{.*}}[%{{.*}}, %{{.*}}]
// -----
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index c583a48eba2704..ceebeeffcf2677 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -58,7 +58,7 @@ func.func @module_physical_storage_buffer64_vulkan() {
func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
%r = spirv.SDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
return %r: i32
@@ -68,7 +68,7 @@ func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
%r = spirv.SDot %a, %a: vector<4xi8> -> i64
return %r: i64
@@ -78,7 +78,7 @@ func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
%r = spirv.SDot %a, %a: vector<4xi16> -> i64
return %r: i64
@@ -88,7 +88,7 @@ func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
%r = spirv.SUDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
return %r: i32
@@ -98,7 +98,7 @@ func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
%r = spirv.SUDot %a, %a: vector<4xi8> -> i64
return %r: i64
@@ -108,7 +108,7 @@ func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
%r = spirv.SUDot %a, %a: vector<4xi16> -> i64
return %r: i64
@@ -118,7 +118,7 @@ func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
%r = spirv.UDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
return %r: i32
@@ -128,7 +128,7 @@ func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
%r = spirv.UDot %a, %a: vector<4xi8> -> i64
return %r: i64
@@ -138,7 +138,7 @@ func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
%r = spirv.UDot %a, %a: vector<4xi16> -> i64
return %r: i64
@@ -148,7 +148,7 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
%r = spirv.SDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
return %r: i32
@@ -158,7 +158,7 @@ func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
%r = spirv.SDotAccSat %a, %a, %acc: vector<4xi8> -> i64
return %r: i64
@@ -168,7 +168,7 @@ func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
%r = spirv.SDotAccSat %a, %a, %acc: vector<4xi16> -> i64
return %r: i64
@@ -178,7 +178,7 @@ func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
%r = spirv.SUDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
return %r: i32
@@ -188,7 +188,7 @@ func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
%r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi8> -> i64
return %r: i64
@@ -198,7 +198,7 @@ func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
%r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi16> -> i64
return %r: i64
@@ -208,7 +208,7 @@ func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
%r = spirv.UDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
return %r: i32
@@ -218,7 +218,7 @@ func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
%r = spirv.UDotAccSat %a, %a, %acc: vector<4xi8> -> i64
return %r: i64
@@ -228,7 +228,7 @@ func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
%r = spirv.UDotAccSat %a, %a, %acc: vector<4xi16> -> i64
return %r: i64
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index f90111b4c88618..f75f8f8489efc1 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -7,15 +7,13 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
%f0 = arith.constant 0.0 : f32
// CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
%0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
-// CHECK-NEXT: %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
-// CHECK-NEXT: memref.store %[[SS]], %[[MEM]][] : memref<f32>
+// CHECK-NEXT: memref.store %[[S]], %[[MEM]][] : memref<f32>
vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
-// CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
-// CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref<f32>
+// CHECK-NEXT: %[[V:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
+// CHECK-NEXT: memref.store %[[V]], %[[MEM]][] : memref<f32>
vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
return
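For reference, a minimal sketch of what this change means for the 0-d store lowering. The function below mirrors @vector_store_op_0d from the LLVM conversion test; the before/after IR in the comments is paraphrased from the test updates above, so the exact printed form may differ slightly.

func.func @store_0d(%mem: memref<200x100xf32>, %i: index, %j: index) {
  // 0-d vector store, as exercised by the existing test.
  %cst = arith.constant dense<1.100000e+01> : vector<f32>
  vector.store %cst, %mem[%i, %j] : memref<200x100xf32>, vector<f32>
  return
}

// Before this patch, VectorStoreToMemrefStoreLowering special-cased rank 0:
//   %e = vector.extractelement %cst[] : vector<f32>
//   memref.store %e, %mem[%i, %j] : memref<200x100xf32>
// After this patch, the same vector.extract form is used for every rank, with
// an empty position list in the 0-d case:
//   %e = vector.extract %cst[] : f32 from vector<f32>
//   memref.store %e, %mem[%i, %j] : memref<200x100xf32>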
banach-space
left a comment
LG, just a couple of comments. Thanks!
%f0 = arith.constant 0.0 : f32

// CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
Where is this broadcast generated? Perhaps that's something that could be deleted?
Because vector::ExtractElementOp requires a vector-typed input, we had to use vector.broadcast to turn the scalar back into a vector. Now that we use vector::ExtractOp directly, there is no longer any need for vector.broadcast, even for 0-d vectors. :-)
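A tiny sketch of the two 0-d forms being discussed (hypothetical function name, for illustration only; the op syntax follows the CHECK lines in the updated tests):

func.func @extract_0d(%v: vector<f32>) -> f32 {
  // Old 0-d special case: vector.extractelement with an empty position list.
  %old = vector.extractelement %v[] : vector<f32>
  // With this patch: vector.extract also accepts rank 0 with an empty position
  // list, so a single code path covers every rank.
  %new = vector.extract %v[] : f32 from vector<f32>
  return %new : f32
}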
This makes sense, but what code-path generated vector.broadcast to begin with? And how do we know it doesn't require updating?
extracted = rewriter.create<vector::ExtractOp>(
    storeOp.getLoc(), storeOp.getValueToStore(), indices);
}
SmallVector<int64_t> indices(vecType.getRank(), 0);
The comment above this pattern needs updating.
Okay, I have already updated it.
The comment still refers to vector.extractelement, but your patch removes the vector.extractelement path.
Force-pushed from 2b9e931 to 5f61d16
banach-space
left a comment
This is marked as "draft" again.
Please could you decide whether this is a draft or not? Without additional comments, this is very confusing for reviewers and not a good use of time.
Sorry Andrzej, I think I misunderstood this issue, so I need some time to think about it. :-)
I understand :) For clarity in the future, it would help us reviewers if you could leave a quick note or a status update when a PR moves back to "draft" after being initially ready for review. This will ensure that everyone is on the same page and can use their time efficiently. If there's anything specific you’re stuck on, feel free to share! Happy to help if possible.
dcaballe
left a comment
LG % ongoing comments. Thanks!
dcaballe
left a comment
Sorry, I didn't see the "Draft" state. What is missing?
Groverkss
left a comment
Thank you for the patch. I think there is a misunderstanding of the issue linked. The issue is about removing this pattern. You can see the comment on VectorLoadToMemrefLoad:
/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
Supporting non-0-d vectors in this pattern doesn't really get us anything. Instead, we want to verify that the TransferRead -> vector.load -> LLVM/SPIRV conversions are robust enough to handle 0-d vectors, and fix them if they are not.
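For concreteness, the 0-d rewrite performed by the VectorLoadToMemrefLoad pattern mentioned above looks roughly like the following sketch (hypothetical function name; the rewritten IR in the comments paraphrases the quoted doc comment):

func.func @load_0d(%mem: memref<f32>) -> vector<f32> {
  // 0-d load matched by the pattern:
  %v = vector.load %mem[] : memref<f32>, vector<f32>
  return %v : vector<f32>
}

// The pattern replaces the 0-d vector.load with a scalar load plus a broadcast:
//   %s = memref.load %mem[] : memref<f32>
//   %v = vector.broadcast %s : f32 to vector<f32>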
Hi Diego, I misunderstood this issue, so I will update this PR. Thank you for reviewing it. :)
Thanks Andrzej, I will update this PR again. :-)
Okay, I will consider and verify it. :-)
Nice, the patch is in the right direction. Can you:
Remove 0-d corner case condition.
Closes #112913