Skip to content

Commit ffbe9cf

Browse files
amd-eochoalokuhar
andauthored
[mlir][spirv] Propagate alignment requirements from vector to spirv (llvm#155278)
Propagates the alignment attribute from `vector.{load,store}` to `spirv.{load,store}`. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent 30002f2 commit ffbe9cf

File tree

4 files changed

+99
-5
lines changed

4 files changed

+99
-5
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def SPIRV_LoadOp : SPIRV_Op<"Load", []> {
220220
let arguments = (ins
221221
SPIRV_AnyPtr:$ptr,
222222
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
223-
OptionalAttr<I32Attr>:$alignment
223+
OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
224224
);
225225

226226
let results = (outs
@@ -345,7 +345,7 @@ def SPIRV_StoreOp : SPIRV_Op<"Store", []> {
345345
SPIRV_AnyPtr:$ptr,
346346
SPIRV_Type:$value,
347347
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
348-
OptionalAttr<I32Attr>:$alignment
348+
OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
349349
);
350350

351351
let results = (outs);

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,22 @@ struct VectorLoadOpConverter final
743743

744744
auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
745745

746+
std::optional<uint64_t> alignment = loadOp.getAlignment();
747+
if (alignment > std::numeric_limits<uint32_t>::max()) {
748+
return rewriter.notifyMatchFailure(loadOp,
749+
"invalid alignment requirement");
750+
}
751+
752+
auto memoryAccess = spirv::MemoryAccess::None;
753+
spirv::MemoryAccessAttr memoryAccessAttr;
754+
IntegerAttr alignmentAttr;
755+
if (alignment.has_value()) {
756+
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
757+
memoryAccessAttr =
758+
spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
759+
alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
760+
}
761+
746762
// For single element vectors, we don't need to bitcast the access chain to
747763
// the original vector type. Both is going to be the same, a pointer
748764
// to a scalar.
@@ -753,7 +769,8 @@ struct VectorLoadOpConverter final
753769
accessChain);
754770

755771
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
756-
castedAccessChain);
772+
castedAccessChain,
773+
memoryAccessAttr, alignmentAttr);
757774

758775
return success();
759776
}
@@ -782,6 +799,12 @@ struct VectorStoreOpConverter final
782799
return rewriter.notifyMatchFailure(
783800
storeOp, "failed to get memref element pointer");
784801

802+
std::optional<uint64_t> alignment = storeOp.getAlignment();
803+
if (alignment > std::numeric_limits<uint32_t>::max()) {
804+
return rewriter.notifyMatchFailure(storeOp,
805+
"invalid alignment requirement");
806+
}
807+
785808
spirv::StorageClass storageClass = attr.getValue();
786809
auto vectorType = storeOp.getVectorType();
787810
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
@@ -795,8 +818,19 @@ struct VectorStoreOpConverter final
795818
: spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
796819
accessChain);
797820

798-
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
799-
adaptor.getValueToStore());
821+
auto memoryAccess = spirv::MemoryAccess::None;
822+
spirv::MemoryAccessAttr memoryAccessAttr;
823+
IntegerAttr alignmentAttr;
824+
if (alignment.has_value()) {
825+
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
826+
memoryAccessAttr =
827+
spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
828+
alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
829+
}
830+
831+
rewriter.replaceOpWithNewOp<spirv::StoreOp>(
832+
storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
833+
alignmentAttr);
800834

801835
return success();
802836
}

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,14 @@ func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<St
953953
return %0: vector<1xf32>
954954
}
955955

956+
// CHECK-LABEL: @vector_load_aligned
957+
func.func @vector_load_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
958+
%idx = arith.constant 0 : index
959+
// CHECK: spirv.Load
960+
// CHECK-SAME: ["Aligned", 8]
961+
%0 = vector.load %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
962+
return %0: vector<4xf32>
963+
}
956964

957965
// CHECK-LABEL: @vector_load_2d
958966
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
@@ -996,6 +1004,15 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
9961004
return
9971005
}
9981006

1007+
// CHECK-LABEL: @vector_store_aligned
1008+
func.func @vector_store_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
1009+
%idx = arith.constant 0 : index
1010+
// CHECK: spirv.Store
1011+
// CHECK-SAME: ["Aligned", 8]
1012+
vector.store %arg1, %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
1013+
return
1014+
}
1015+
9991016
// CHECK-LABEL: @vector_store_single_elem
10001017
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
10011018
// CHECK-SAME: %[[ARG1:.*]]: vector<1xf32>
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: mlir-opt --split-input-file --verify-diagnostics %s
2+
3+
//===----------------------------------------------------------------------===//
4+
// spirv.LoadOp
5+
//===----------------------------------------------------------------------===//
6+
7+
func.func @aligned_load_non_positive() -> () {
8+
%0 = spirv.Variable : !spirv.ptr<f32, Function>
9+
// expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
10+
%1 = spirv.Load "Function" %0 ["Aligned", 0] : f32
11+
return
12+
}
13+
14+
// -----
15+
16+
func.func @aligned_load_non_power_of_two() -> () {
17+
%0 = spirv.Variable : !spirv.ptr<f32, Function>
18+
// expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
19+
%1 = spirv.Load "Function" %0 ["Aligned", 3] : f32
20+
return
21+
}
22+
23+
// -----
24+
25+
//===----------------------------------------------------------------------===//
26+
// spirv.StoreOp
27+
//===----------------------------------------------------------------------===//
28+
29+
func.func @aligned_store_non_positive(%arg0 : f32) -> () {
30+
%0 = spirv.Variable : !spirv.ptr<f32, Function>
31+
// expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
32+
spirv.Store "Function" %0, %arg0 ["Aligned", 0] : f32
33+
return
34+
}
35+
36+
// -----
37+
38+
func.func @aligned_store_non_power_of_two(%arg0 : f32) -> () {
39+
%0 = spirv.Variable : !spirv.ptr<f32, Function>
40+
// expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
41+
spirv.Store "Function" %0, %arg0 ["Aligned", 3] : f32
42+
return
43+
}

0 commit comments

Comments
 (0)