Skip to content

Conversation

@amd-eochoalo
Copy link
Contributor

@amd-eochoalo amd-eochoalo commented Aug 25, 2025

Propagates the alignment attribute from vector.{load,store} to spirv.{load,store}.

@amd-eochoalo amd-eochoalo self-assigned this Aug 27, 2025
@amd-eochoalo amd-eochoalo force-pushed the eochoa/2025-08-25/spirv-alignment branch from 9fd4b1e to e4e62f1 Compare August 27, 2025 18:41
@amd-eochoalo amd-eochoalo force-pushed the eochoa/2025-08-25/spirv-alignment branch from e4e62f1 to 4d6453c Compare August 27, 2025 18:45
@amd-eochoalo amd-eochoalo marked this pull request as ready for review August 27, 2025 19:04
@llvmbot
Copy link
Member

llvmbot commented Aug 27, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes

Propagates the alignment attribute from vector.{load,store} to spirv.{load,store}.


Full diff: https://github.com/llvm/llvm-project/pull/155278.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td (+2-2)
  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+38-3)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+17)
  • (added) mlir/test/Dialect/SPIRV/IR/invalid.mlir (+43)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index aad50175546a5..6253601a7c2b2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -220,7 +220,7 @@ def SPIRV_LoadOp : SPIRV_Op<"Load", []> {
   let arguments = (ins
     SPIRV_AnyPtr:$ptr,
     OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
-    OptionalAttr<I32Attr>:$alignment
+    OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
   );
 
   let results = (outs
@@ -345,7 +345,7 @@ def SPIRV_StoreOp : SPIRV_Op<"Store", []> {
     SPIRV_AnyPtr:$ptr,
     SPIRV_Type:$value,
     OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
-    OptionalAttr<I32Attr>:$alignment
+    OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
   );
 
   let results = (outs);
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a4be7d4bb5473..e6fdb800a017c 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -743,6 +743,23 @@ struct VectorLoadOpConverter final
 
     auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
 
+    auto alignment = loadOp.getAlignment();
+    if (alignment.has_value() &&
+        alignment > std::numeric_limits<uint32_t>::max()) {
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "invalid alignment requirement");
+    }
+
+    auto memoryAccess = spirv::MemoryAccess::None;
+    auto memoryAccessAttr = spirv::MemoryAccessAttr{};
+    IntegerAttr alignmentAttr = nullptr;
+    if (alignment.has_value()) {
+      memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+      memoryAccessAttr =
+          spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+      alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+    }
+
     // For single element vectors, we don't need to bitcast the access chain to
     // the original vector type. Both is going to be the same, a pointer
     // to a scalar.
@@ -753,7 +770,8 @@ struct VectorLoadOpConverter final
                                        accessChain);
 
     rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
-                                               castedAccessChain);
+                                               castedAccessChain,
+                                               memoryAccessAttr, alignmentAttr);
 
     return success();
   }
@@ -782,6 +800,12 @@ struct VectorStoreOpConverter final
       return rewriter.notifyMatchFailure(
           storeOp, "failed to get memref element pointer");
 
+    auto alignment = storeOp.getAlignment();
+    if (alignment && alignment > std::numeric_limits<uint32_t>::max()) {
+      return rewriter.notifyMatchFailure(storeOp,
+                                         "invalid alignment requirement");
+    }
+
     spirv::StorageClass storageClass = attr.getValue();
     auto vectorType = storeOp.getVectorType();
     auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
@@ -795,8 +819,19 @@ struct VectorStoreOpConverter final
             : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
                                        accessChain);
 
-    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
-                                                adaptor.getValueToStore());
+    auto memoryAccess = spirv::MemoryAccess::None;
+    auto memoryAccessAttr = spirv::MemoryAccessAttr{};
+    IntegerAttr alignmentAttr = nullptr;
+    if (alignment.has_value()) {
+      memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+      memoryAccessAttr =
+          spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+      alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+    }
+
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+        storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
+        alignmentAttr);
 
     return success();
   }
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8918f91ef9145..4b56897821dbb 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -953,6 +953,14 @@ func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<St
   return %0: vector<1xf32>
 }
 
+// CHECK-LABEL: @vector_load_aligned
+func.func @vector_load_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx = arith.constant 0 : index
+  // CHECK: spirv.Load
+  // CHECK-SAME: ["Aligned", 8]
+  %0 = vector.load %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return %0: vector<4xf32>
+}
 
 // CHECK-LABEL: @vector_load_2d
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
@@ -996,6 +1004,15 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
   return
 }
 
+// CHECK-LABEL: @vector_store_aligned
+func.func @vector_store_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+  %idx = arith.constant 0 : index
+  // CHECK: spirv.Store
+  // CHECK-SAME: ["Aligned", 8]
+  vector.store %arg1, %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return
+}
+
 // CHECK-LABEL: @vector_store_single_elem
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
 //  CHECK-SAME:  %[[ARG1:.*]]: vector<1xf32>
diff --git a/mlir/test/Dialect/SPIRV/IR/invalid.mlir b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
new file mode 100644
index 0000000000000..72eb9883a6538
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s 
+
+//===----------------------------------------------------------------------===//
+// spirv.LoadOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_load_non_positive() -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %1 = spirv.Load "Function" %0 ["Aligned", 0] : f32
+  return
+}
+
+// -----
+
+func.func @aligned_load_non_power_of_two() -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %1 = spirv.Load "Function" %0 ["Aligned", 3] : f32
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.StoreOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_store_non_positive(%arg0 : f32) -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  spirv.Store  "Function" %0, %arg0 ["Aligned", 0] : f32
+  return
+}
+
+// -----
+
+func.func @aligned_store_non_power_of_two(%arg0 : f32) -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  spirv.Store  "Function" %0, %arg0 ["Aligned", 3] : f32
+  return
+}

@kuhar kuhar requested review from Hardcode84 and krzysz00 August 28, 2025 00:36
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall

@amd-eochoalo amd-eochoalo requested a review from kuhar August 28, 2025 13:10
@amd-eochoalo
Copy link
Contributor Author

Thanks @kuhar! I'll remember the coding guidelines for next time. :-)

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@amd-eochoalo amd-eochoalo merged commit ffbe9cf into llvm:main Aug 28, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants