Conversation

@harrisonGPU
Contributor

@harrisonGPU harrisonGPU commented Oct 18, 2024

Remove 0-d corner case condition.

Closes #112913

@harrisonGPU harrisonGPU changed the title [MLIR] Merge AnyVector and AnyVectorOfAnyRank type constraints. [mlir][Vector] Support 0-d vectors natively in VectorStoreToMemrefStoreLowering. Nov 5, 2024
@harrisonGPU harrisonGPU marked this pull request as ready for review November 5, 2024 12:14
@harrisonGPU harrisonGPU marked this pull request as draft November 5, 2024 12:14
@llvmbot
Member

llvmbot commented Nov 5, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Harrison Hao (harrisonGPU)

Changes

Support 0-d vectors natively in VectorStoreToMemrefStoreLowering.

Closes #112913
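
To illustrate, here is a sketch of the lowering this patch unifies (the function name is illustrative, not taken from the patch):

```mlir
// A 0-d vector store:
func.func @store_0d(%mem: memref<f32>, %vec: vector<f32>) {
  vector.store %vec, %mem[] : memref<f32>, vector<f32>
  return
}
// Previously, rank 0 took a special vector.extractelement path; with this
// change all ranks go through vector.extract (with an empty index list for
// rank 0), roughly:
//   %s = vector.extract %vec[] : f32 from vector<f32>
//   memref.store %s, %mem[] : memref<f32>
```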


Full diff: https://github.com/llvm/llvm-project/pull/112937.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+3-9)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+2-3)
  • (modified) mlir/test/Dialect/SPIRV/IR/availability.mlir (+18-18)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+3-5)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index f9428a4ce28640..1cb3baaef82baf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -530,15 +530,9 @@ struct VectorStoreToMemrefStoreLowering
       return rewriter.notifyMatchFailure(storeOp, "not single element vector");
 
     Value extracted;
-    if (vecType.getRank() == 0) {
-      // TODO: Unifiy once ExtractOp supports 0-d vectors.
-      extracted = rewriter.create<vector::ExtractElementOp>(
-          storeOp.getLoc(), storeOp.getValueToStore());
-    } else {
-      SmallVector<int64_t> indices(vecType.getRank(), 0);
-      extracted = rewriter.create<vector::ExtractOp>(
-          storeOp.getLoc(), storeOp.getValueToStore(), indices);
-    }
+    SmallVector<int64_t> indices(vecType.getRank(), 0);
+    extracted = rewriter.create<vector::ExtractOp>(
+        storeOp.getLoc(), storeOp.getValueToStore(), indices);
 
     rewriter.replaceOpWithNewOp<memref::StoreOp>(
         storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c1de24fd0403ce..abbdbe02ce6c1e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2971,9 +2971,8 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in
 // CHECK-LABEL: func @vector_store_op_0d
 // CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
 // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
-// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
-// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<1xf32> to f32
+// CHECK: memref.store %[[cast2]], %{{.*}}[%{{.*}}, %{{.*}}]
 
 // -----
 
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index c583a48eba2704..ceebeeffcf2677 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -58,7 +58,7 @@ func.func @module_physical_storage_buffer64_vulkan() {
 func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -68,7 +68,7 @@ func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SDot %a, %a: vector<4xi8> -> i64
   return %r: i64
@@ -78,7 +78,7 @@ func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
 func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SDot %a, %a: vector<4xi16> -> i64
   return %r: i64
@@ -88,7 +88,7 @@ func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
 func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SUDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -98,7 +98,7 @@ func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SUDot %a, %a: vector<4xi8> -> i64
   return %r: i64
@@ -108,7 +108,7 @@ func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
 func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SUDot %a, %a: vector<4xi16> -> i64
   return %r: i64
@@ -118,7 +118,7 @@ func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
 func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.UDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -128,7 +128,7 @@ func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
 func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.UDot %a, %a: vector<4xi8> -> i64
   return %r: i64
@@ -138,7 +138,7 @@ func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
 func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.UDot %a, %a: vector<4xi16> -> i64
   return %r: i64
@@ -148,7 +148,7 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
 func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -158,7 +158,7 @@ func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
@@ -168,7 +168,7 @@ func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
 func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
@@ -178,7 +178,7 @@ func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
 func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SUDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -188,7 +188,7 @@ func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
@@ -198,7 +198,7 @@ func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
 func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
@@ -208,7 +208,7 @@ func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
 func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.UDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -218,7 +218,7 @@ func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
 func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
@@ -228,7 +228,7 @@ func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
 func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index f90111b4c88618..f75f8f8489efc1 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -7,15 +7,13 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
     %f0 = arith.constant 0.0 : f32
 
 //  CHECK-NEXT:   %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-//  CHECK-NEXT:   %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
     %0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
 
-//  CHECK-NEXT:   %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
-//  CHECK-NEXT:   memref.store %[[SS]], %[[MEM]][] : memref<f32>
+//  CHECK-NEXT:   memref.store %[[S]], %[[MEM]][] : memref<f32>
     vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
 
-//  CHECK-NEXT:   %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
-//  CHECK-NEXT:   memref.store %[[VV]], %[[MEM]][] : memref<f32>
+//  CHECK-NEXT:   %[[V:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
+//  CHECK-NEXT:   memref.store %[[V]], %[[MEM]][] : memref<f32>
     vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
 
     return

@llvmbot
Member

llvmbot commented Nov 5, 2024

@llvm/pr-subscribers-mlir-spirv

Author: Harrison Hao (harrisonGPU)

(Same changes and full diff as above.)

@harrisonGPU harrisonGPU marked this pull request as ready for review November 6, 2024 01:47
Contributor

@banach-space banach-space left a comment

LG, just a couple of comments. Thanks!

%f0 = arith.constant 0.0 : f32

// CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
Contributor

Where is this broadcast generated? Perhaps that's something that could deleted?

Contributor Author

Because vector::ExtractElementOp requires a vector-typed input, we previously needed vector.broadcast to convert scalars into vectors. Now that we use vector::ExtractOp directly, there is no longer a need for vector.broadcast, even for 0-dimensional vectors. :-)
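
For illustration, a minimal sketch of the two ops on a 0-d vector (assuming current vector dialect syntax; `%v` is a placeholder value):

```mlir
// vector.extractelement requires a vector-typed operand (with an empty
// position for rank 0); vector.extract expresses the same thing with a
// static index list that is simply empty:
%a = vector.extractelement %v[] : vector<f32>
%b = vector.extract %v[] : f32 from vector<f32>
```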

Contributor

This makes sense, but what code-path generated vector.broadcast to begin with? And how do we know it doesn't require updating?

extracted = rewriter.create<vector::ExtractOp>(
storeOp.getLoc(), storeOp.getValueToStore(), indices);
}
SmallVector<int64_t> indices(vecType.getRank(), 0);
Contributor

The comment above this pattern needs updating.

Contributor Author

Okay, I have already updated it.

Contributor

The comment still refers to vector.extractelement, but your patch removes the vector.extractelement path.

@harrisonGPU harrisonGPU marked this pull request as draft November 6, 2024 15:59
Contributor

@banach-space banach-space left a comment

This is marked as "draft" again.

Please, could you decide whether this is "draft" or not? Without additional comments, this is very confusing to the reviewers. And not a good use of time.

extracted = rewriter.create<vector::ExtractOp>(
storeOp.getLoc(), storeOp.getValueToStore(), indices);
}
SmallVector<int64_t> indices(vecType.getRank(), 0);
Contributor

The comment still refers to vector.extractelement, but your patch removes the vector.extractelement path.

%f0 = arith.constant 0.0 : f32

// CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
Contributor

This makes sense, but what code-path generated vector.broadcast to begin with? And how do we know it doesn't require updating?

@harrisonGPU
Contributor Author

harrisonGPU commented Nov 7, 2024

This is marked as "draft" again.

Please, could you decide whether this is "draft" or not? Without additional comments, this is very confusing to the reviewers. And not a good use of time.

Sorry Andrzej, I think I misunderstood this issue, so I need some time to think about it. Sorry about that. :-)

@banach-space
Contributor

This is marked as "draft" again.
Please, could you decide whether this is "draft" or not? Without additional comments, this is very confusing to the reviewers. And not a good use of time.

Sorry Andrzej, I think I misunderstood this issue, so I need some time to think about it. Sorry about that. :-)

I understand :) For clarity in the future, it would help us reviewers if you could leave a quick note or a status update when a PR moves back to "draft" after being initially ready for review. This will ensure that everyone is on the same page and can use their time efficiently.

If there's anything specific you’re stuck on, feel free to share! Happy to help if possible.

Contributor

@dcaballe dcaballe left a comment

LG % ongoing comments. Thanks!

Contributor

@dcaballe dcaballe left a comment

Sorry, I didn't see the "Draft" state. What is missing?

Member

@Groverkss Groverkss left a comment

Thank you for the patch. I think there is a misunderstanding of the issue linked. The issue is about removing this pattern. You can see the comment on VectorLoadToMemrefLoad:

/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).

Supporting 0-d vectors natively in this pattern doesn't really get us anything. We want to instead verify if the TransferRead -> vector.load -> LLVM/SPIRV conversions are robust enough to handle 0-d vectors and fix them.
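
For reference, a sketch of what the VectorLoadToMemrefLoad pattern quoted above produces for a 0-d load (value names are illustrative):

```mlir
// Input:
//   %v = vector.load %mem[] : memref<f32>, vector<f32>
// Rewritten form, crossing into the scalar domain and back:
%s = memref.load %mem[] : memref<f32>
%v2 = vector.broadcast %s : f32 to vector<f32>
```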

@harrisonGPU
Contributor Author

LG % ongoing comments. Thanks!

Hi Diego, I misunderstood this issue, so I will update this PR. Thank you for reviewing it. :)

@harrisonGPU
Contributor Author

This is marked as "draft" again.
Please, could you decide whether this is "draft" or not? Without additional comments, this is very confusing to the reviewers. And not a good use of time.

Sorry Andrzej, I think I misunderstood this issue, so I need some time to think about it. Sorry about that. :-)

I understand :) For clarity in the future, it would help us reviewers if you could leave a quick note or a status update when a PR moves back to "draft" after being initially ready for review. This will ensure that everyone is on the same page and can use their time efficiently.

If there's anything specific you’re stuck on, feel free to share! Happy to help if possible.

Thanks Andrzej, I will update this PR again. :-)

@harrisonGPU
Contributor Author

Thank you for the patch. I think there is a misunderstanding of the issue linked. The issue is about removing this pattern. You can see the comment on VectorLoadToMemrefLoad:

/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).

Supporting 0-d vectors natively in this pattern doesn't really get us anything. We want to instead verify if the TransferRead -> vector.load -> LLVM/SPIRV conversions are robust enough to handle 0-d vectors and fix them.

Okay, I will consider and verify it. :-)

@harrisonGPU harrisonGPU changed the title [mlir][Vector] Support 0-d vectors natively in VectorStoreToMemrefStoreLowering. [mlir][Vector] Remove 0-d corner case condition. Nov 10, 2024
@Groverkss
Member

Nice, the patch is in the right direction. Can you:

  • Since the 0-d case is supported now, can you remove the VectorLoadToMemrefLoad/VectorStoreToMemrefStore patterns? This will give you a more accurate view of what needs to be updated.
  • Can you check if the test coverage for 0-d transfer_read/transfer_write is enough? If it is enough, please mention it in the PR description and point it out. Otherwise, can you add more tests related to it?
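
A 0-d coverage test along the lines being requested might look like this (a sketch; the function name is hypothetical):

```mlir
// Round-trip a 0-d vector through transfer_read/transfer_write; the
// LLVM/SPIR-V conversions should handle this without the patterns above.
func.func @transfer_0d(%mem: memref<f32>) {
  %f0 = arith.constant 0.0 : f32
  %v = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
  vector.transfer_write %v, %mem[] : vector<f32>, memref<f32>
  return
}
```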

@harrisonGPU harrisonGPU deleted the harrison/mlir branch August 22, 2025 16:36


Development

Successfully merging this pull request may close these issues.

[mlir][Vector] Improve support for 0-d vectors in vector dialect lowerings

5 participants