From f4185e67f2178e953a5bc0ba1656d828eaeb63fa Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 25 Aug 2025 11:25:59 -0700 Subject: [PATCH 1/4] [mlir][spirv] Constraint alignment attribute --- .../mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td | 4 +- mlir/test/Dialect/SPIRV/IR/invalid.mlir | 43 +++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) create mode 100644 mlir/test/Dialect/SPIRV/IR/invalid.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td index aad50175546a5..6253601a7c2b2 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td @@ -220,7 +220,7 @@ def SPIRV_LoadOp : SPIRV_Op<"Load", []> { let arguments = (ins SPIRV_AnyPtr:$ptr, OptionalAttr:$memory_access, - OptionalAttr:$alignment + OptionalAttr>:$alignment ); let results = (outs @@ -345,7 +345,7 @@ def SPIRV_StoreOp : SPIRV_Op<"Store", []> { SPIRV_AnyPtr:$ptr, SPIRV_Type:$value, OptionalAttr:$memory_access, - OptionalAttr:$alignment + OptionalAttr>:$alignment ); let results = (outs); diff --git a/mlir/test/Dialect/SPIRV/IR/invalid.mlir b/mlir/test/Dialect/SPIRV/IR/invalid.mlir new file mode 100644 index 0000000000000..72eb9883a6538 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/IR/invalid.mlir @@ -0,0 +1,43 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s + +//===----------------------------------------------------------------------===// +// spirv.LoadOp +//===----------------------------------------------------------------------===// + +func.func @aligned_load_non_positive() -> () { + %0 = spirv.Variable : !spirv.ptr + // expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + %1 = spirv.Load "Function" %0 ["Aligned", 0] : f32 + return +} + +// ----- + +func.func @aligned_load_non_power_of_two() -> () { + %0 = spirv.Variable : !spirv.ptr + // expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + %1 = spirv.Load "Function" %0 ["Aligned", 3] : f32 + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.StoreOp +//===----------------------------------------------------------------------===// + +func.func @aligned_store_non_positive(%arg0 : f32) -> () { + %0 = spirv.Variable : !spirv.ptr + // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + spirv.Store "Function" %0, %arg0 ["Aligned", 0] : f32 + return +} + +// ----- + +func.func @aligned_store_non_power_of_two(%arg0 : f32) -> () { + %0 = spirv.Variable : !spirv.ptr + // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + spirv.Store "Function" %0, %arg0 ["Aligned", 3] : f32 + return +} From 4d6453ca5de033597898a96d1d3b9d2a9fe4a8bc Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 25 Aug 2025 11:10:19 -0700 Subject: [PATCH 2/4] [mlir] Propagate alignment attribute in VectorToSPIRV. --- .../VectorToSPIRV/VectorToSPIRV.cpp | 41 +++++++++++++++++-- .../VectorToSPIRV/vector-to-spirv.mlir | 17 ++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index a4be7d4bb5473..e6fdb800a017c 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -743,6 +743,23 @@ struct VectorLoadOpConverter final auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass); + auto alignment = loadOp.getAlignment(); + if (alignment.has_value() && + alignment > std::numeric_limits::max()) { + return rewriter.notifyMatchFailure(loadOp, + "invalid alignment requirement"); + } + + auto memoryAccess = spirv::MemoryAccess::None; + auto memoryAccessAttr = spirv::MemoryAccessAttr{}; + IntegerAttr alignmentAttr = nullptr; + if (alignment.has_value()) { + memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; + memoryAccessAttr = + spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess); + alignmentAttr = rewriter.getI32IntegerAttr(alignment.value()); + } + // For single element vectors, we don't need to bitcast the access chain to // the original vector type. Both is going to be the same, a pointer // to a scalar. @@ -753,7 +770,8 @@ struct VectorLoadOpConverter final accessChain); rewriter.replaceOpWithNewOp(loadOp, spirvVectorType, - castedAccessChain); + castedAccessChain, + memoryAccessAttr, alignmentAttr); return success(); } @@ -782,6 +800,12 @@ struct VectorStoreOpConverter final return rewriter.notifyMatchFailure( storeOp, "failed to get memref element pointer"); + auto alignment = storeOp.getAlignment(); + if (alignment && alignment > std::numeric_limits::max()) { + return rewriter.notifyMatchFailure(storeOp, + "invalid alignment requirement"); + } + spirv::StorageClass storageClass = attr.getValue(); auto vectorType = storeOp.getVectorType(); auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass); @@ -795,8 +819,19 @@ struct VectorStoreOpConverter final : spirv::BitcastOp::create(rewriter, loc, vectorPtrType, accessChain); - rewriter.replaceOpWithNewOp(storeOp, castedAccessChain, - adaptor.getValueToStore()); + auto memoryAccess = spirv::MemoryAccess::None; + auto memoryAccessAttr = spirv::MemoryAccessAttr{}; + IntegerAttr alignmentAttr = nullptr; + if (alignment.has_value()) { + memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; + memoryAccessAttr = + spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess); + alignmentAttr = rewriter.getI32IntegerAttr(alignment.value()); + } + + rewriter.replaceOpWithNewOp( + storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr, + alignmentAttr); return success(); } diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 8918f91ef9145..4b56897821dbb 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -953,6 +953,14 @@ func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class } +// CHECK-LABEL: @vector_load_aligned +func.func @vector_load_aligned(%arg0 : memref<4xf32, #spirv.storage_class>) -> vector<4xf32> { + %idx = arith.constant 0 : index + // CHECK: spirv.Load + // CHECK-SAME: ["Aligned", 8] + %0 = vector.load %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class>, vector<4xf32> + return %0: vector<4xf32> +} // CHECK-LABEL: @vector_load_2d // CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class>) -> vector<4xf32> { @@ -996,6 +1004,15 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class>, %arg1 : vector<4xf32>) { + %idx = arith.constant 0 : index + // CHECK: spirv.Store + // CHECK-SAME: ["Aligned", 8] + vector.store %arg1, %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class>, vector<4xf32> + return +} + // CHECK-LABEL: @vector_store_single_elem // CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class> // CHECK-SAME: %[[ARG1:.*]]: vector<1xf32> From d61481d418b725b69328704856a6e8c4e8b16a36 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Thu, 28 Aug 2025 05:34:54 -0700 Subject: [PATCH 3/4] Address review comments --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 13 ++++++------- mlir/test/Dialect/SPIRV/IR/invalid.mlir | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index e6fdb800a017c..3eb07f2549ed2 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -743,15 +743,14 @@ struct VectorLoadOpConverter final auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass); - auto alignment = loadOp.getAlignment(); - if (alignment.has_value() && - alignment > std::numeric_limits::max()) { + std::optional alignment = loadOp.getAlignment(); + if (alignment > std::numeric_limits::max()) { return rewriter.notifyMatchFailure(loadOp, "invalid alignment requirement"); } auto memoryAccess = spirv::MemoryAccess::None; - auto memoryAccessAttr = spirv::MemoryAccessAttr{}; + spirv::MemoryAccessAttr memoryAccessAttr; IntegerAttr alignmentAttr = nullptr; if (alignment.has_value()) { memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; @@ -800,8 +799,8 @@ struct VectorStoreOpConverter final return rewriter.notifyMatchFailure( storeOp, "failed to get memref element pointer"); - auto alignment = storeOp.getAlignment(); - if (alignment && alignment > std::numeric_limits::max()) { + std::optional alignment = storeOp.getAlignment(); + if (alignment > std::numeric_limits::max()) { return rewriter.notifyMatchFailure(storeOp, "invalid alignment requirement"); } @@ -820,7 +819,7 @@ struct VectorStoreOpConverter final accessChain); auto memoryAccess = spirv::MemoryAccess::None; - auto memoryAccessAttr = spirv::MemoryAccessAttr{}; + spirv::MemoryAccessAttr memoryAccessAttr; IntegerAttr alignmentAttr = nullptr; if (alignment.has_value()) { memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; diff --git a/mlir/test/Dialect/SPIRV/IR/invalid.mlir b/mlir/test/Dialect/SPIRV/IR/invalid.mlir index 72eb9883a6538..122b92cca2e9b 100644 --- a/mlir/test/Dialect/SPIRV/IR/invalid.mlir +++ b/mlir/test/Dialect/SPIRV/IR/invalid.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -verify-diagnostics %s +// RUN: mlir-opt --split-input-file --verify-diagnostics %s //===----------------------------------------------------------------------===// // spirv.LoadOp From 0174c8628729fbf4b205ab4e9c6709bd0ffa9475 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Thu, 28 Aug 2025 11:26:42 -0400 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: Jakub Kuderski --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 4 ++-- mlir/test/Dialect/SPIRV/IR/invalid.mlir | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 3eb07f2549ed2..036cbad0bcfe8 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -751,7 +751,7 @@ struct VectorLoadOpConverter final auto memoryAccess = spirv::MemoryAccess::None; spirv::MemoryAccessAttr memoryAccessAttr; - IntegerAttr alignmentAttr = nullptr; + IntegerAttr alignmentAttr; if (alignment.has_value()) { memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; memoryAccessAttr = @@ -820,7 +820,7 @@ struct VectorStoreOpConverter final auto memoryAccess = spirv::MemoryAccess::None; spirv::MemoryAccessAttr memoryAccessAttr; - IntegerAttr alignmentAttr = nullptr; + IntegerAttr alignmentAttr; if (alignment.has_value()) { memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; memoryAccessAttr = diff --git a/mlir/test/Dialect/SPIRV/IR/invalid.mlir b/mlir/test/Dialect/SPIRV/IR/invalid.mlir index 122b92cca2e9b..e0100748a0d68 100644 --- a/mlir/test/Dialect/SPIRV/IR/invalid.mlir +++ b/mlir/test/Dialect/SPIRV/IR/invalid.mlir @@ -29,7 +29,7 @@ func.func @aligned_load_non_power_of_two() -> () { func.func @aligned_store_non_positive(%arg0 : f32) -> () { %0 = spirv.Variable : !spirv.ptr // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} - spirv.Store "Function" %0, %arg0 ["Aligned", 0] : f32 + spirv.Store "Function" %0, %arg0 ["Aligned", 0] : f32 return } @@ -38,6 +38,6 @@ func.func @aligned_store_non_positive(%arg0 : f32) -> () { func.func @aligned_store_non_power_of_two(%arg0 : f32) -> () { %0 = spirv.Variable : !spirv.ptr // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} - spirv.Store "Function" %0, %arg0 ["Aligned", 3] : f32 + spirv.Store "Function" %0, %arg0 ["Aligned", 3] : f32 return }