-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][spirv] Propagate alignment requirements from vector to spirv #155278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Propagate alignment requirements from vector to spirv #155278
Conversation
9fd4b1e to
e4e62f1
Compare
e4e62f1 to
4d6453c
Compare
|
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Erick Ochoa Lopez (amd-eochoalo) ChangesPropagates the alignment attribute from Full diff: https://github.com/llvm/llvm-project/pull/155278.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index aad50175546a5..6253601a7c2b2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -220,7 +220,7 @@ def SPIRV_LoadOp : SPIRV_Op<"Load", []> {
let arguments = (ins
SPIRV_AnyPtr:$ptr,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
- OptionalAttr<I32Attr>:$alignment
+ OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
);
let results = (outs
@@ -345,7 +345,7 @@ def SPIRV_StoreOp : SPIRV_Op<"Store", []> {
SPIRV_AnyPtr:$ptr,
SPIRV_Type:$value,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
- OptionalAttr<I32Attr>:$alignment
+ OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
);
let results = (outs);
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a4be7d4bb5473..e6fdb800a017c 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -743,6 +743,23 @@ struct VectorLoadOpConverter final
auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
+ auto alignment = loadOp.getAlignment();
+ if (alignment.has_value() &&
+ alignment > std::numeric_limits<uint32_t>::max()) {
+ return rewriter.notifyMatchFailure(loadOp,
+ "invalid alignment requirement");
+ }
+
+ auto memoryAccess = spirv::MemoryAccess::None;
+ auto memoryAccessAttr = spirv::MemoryAccessAttr{};
+ IntegerAttr alignmentAttr = nullptr;
+ if (alignment.has_value()) {
+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+ memoryAccessAttr =
+ spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+ alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+ }
+
// For single element vectors, we don't need to bitcast the access chain to
// the original vector type. Both is going to be the same, a pointer
// to a scalar.
@@ -753,7 +770,8 @@ struct VectorLoadOpConverter final
accessChain);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
- castedAccessChain);
+ castedAccessChain,
+ memoryAccessAttr, alignmentAttr);
return success();
}
@@ -782,6 +800,12 @@ struct VectorStoreOpConverter final
return rewriter.notifyMatchFailure(
storeOp, "failed to get memref element pointer");
+ auto alignment = storeOp.getAlignment();
+ if (alignment && alignment > std::numeric_limits<uint32_t>::max()) {
+ return rewriter.notifyMatchFailure(storeOp,
+ "invalid alignment requirement");
+ }
+
spirv::StorageClass storageClass = attr.getValue();
auto vectorType = storeOp.getVectorType();
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
@@ -795,8 +819,19 @@ struct VectorStoreOpConverter final
: spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
accessChain);
- rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
- adaptor.getValueToStore());
+ auto memoryAccess = spirv::MemoryAccess::None;
+ auto memoryAccessAttr = spirv::MemoryAccessAttr{};
+ IntegerAttr alignmentAttr = nullptr;
+ if (alignment.has_value()) {
+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+ memoryAccessAttr =
+ spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+ alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+ }
+
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+ storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
+ alignmentAttr);
return success();
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8918f91ef9145..4b56897821dbb 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -953,6 +953,14 @@ func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<St
return %0: vector<1xf32>
}
+// CHECK-LABEL: @vector_load_aligned
+func.func @vector_load_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+ %idx = arith.constant 0 : index
+ // CHECK: spirv.Load
+ // CHECK-SAME: ["Aligned", 8]
+ %0 = vector.load %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return %0: vector<4xf32>
+}
// CHECK-LABEL: @vector_load_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
@@ -996,6 +1004,15 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
return
}
+// CHECK-LABEL: @vector_store_aligned
+func.func @vector_store_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+ %idx = arith.constant 0 : index
+ // CHECK: spirv.Store
+ // CHECK-SAME: ["Aligned", 8]
+ vector.store %arg1, %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return
+}
+
// CHECK-LABEL: @vector_store_single_elem
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
// CHECK-SAME: %[[ARG1:.*]]: vector<1xf32>
diff --git a/mlir/test/Dialect/SPIRV/IR/invalid.mlir b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
new file mode 100644
index 0000000000000..72eb9883a6538
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+//===----------------------------------------------------------------------===//
+// spirv.LoadOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_load_non_positive() -> () {
+ %0 = spirv.Variable : !spirv.ptr<f32, Function>
+ // expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %1 = spirv.Load "Function" %0 ["Aligned", 0] : f32
+ return
+}
+
+// -----
+
+func.func @aligned_load_non_power_of_two() -> () {
+ %0 = spirv.Variable : !spirv.ptr<f32, Function>
+ // expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %1 = spirv.Load "Function" %0 ["Aligned", 3] : f32
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.StoreOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_store_non_positive(%arg0 : f32) -> () {
+ %0 = spirv.Variable : !spirv.ptr<f32, Function>
+ // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ spirv.Store "Function" %0, %arg0 ["Aligned", 0] : f32
+ return
+}
+
+// -----
+
+func.func @aligned_store_non_power_of_two(%arg0 : f32) -> () {
+ %0 = spirv.Variable : !spirv.ptr<f32, Function>
+ // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ spirv.Store "Function" %0, %arg0 ["Aligned", 3] : f32
+ return
+}
|
kuhar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall
|
Thanks @kuhar! I'll remember the coding guidelines for next time. :-) |
kuhar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Co-authored-by: Jakub Kuderski <[email protected]>
Propagates the alignment attribute from
vector.{load,store}tospirv.{load,store}.