Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def SPIRV_LoadOp : SPIRV_Op<"Load", []> {
let arguments = (ins
SPIRV_AnyPtr:$ptr,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
OptionalAttr<I32Attr>:$alignment
OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
);

let results = (outs
Expand Down Expand Up @@ -345,7 +345,7 @@ def SPIRV_StoreOp : SPIRV_Op<"Store", []> {
SPIRV_AnyPtr:$ptr,
SPIRV_Type:$value,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
OptionalAttr<I32Attr>:$alignment
OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
);

let results = (outs);
Expand Down
40 changes: 37 additions & 3 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,22 @@ struct VectorLoadOpConverter final

auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);

std::optional<uint64_t> alignment = loadOp.getAlignment();
if (alignment > std::numeric_limits<uint32_t>::max()) {
return rewriter.notifyMatchFailure(loadOp,
"invalid alignment requirement");
}

auto memoryAccess = spirv::MemoryAccess::None;
spirv::MemoryAccessAttr memoryAccessAttr;
IntegerAttr alignmentAttr = nullptr;
if (alignment.has_value()) {
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
memoryAccessAttr =
spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
}

// For single element vectors, we don't need to bitcast the access chain to
// the original vector type. Both is going to be the same, a pointer
// to a scalar.
Expand All @@ -753,7 +769,8 @@ struct VectorLoadOpConverter final
accessChain);

rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
castedAccessChain);
castedAccessChain,
memoryAccessAttr, alignmentAttr);

return success();
}
Expand Down Expand Up @@ -782,6 +799,12 @@ struct VectorStoreOpConverter final
return rewriter.notifyMatchFailure(
storeOp, "failed to get memref element pointer");

std::optional<uint64_t> alignment = storeOp.getAlignment();
if (alignment > std::numeric_limits<uint32_t>::max()) {
return rewriter.notifyMatchFailure(storeOp,
"invalid alignment requirement");
}

spirv::StorageClass storageClass = attr.getValue();
auto vectorType = storeOp.getVectorType();
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
Expand All @@ -795,8 +818,19 @@ struct VectorStoreOpConverter final
: spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
accessChain);

rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
adaptor.getValueToStore());
auto memoryAccess = spirv::MemoryAccess::None;
spirv::MemoryAccessAttr memoryAccessAttr;
IntegerAttr alignmentAttr = nullptr;
if (alignment.has_value()) {
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
memoryAccessAttr =
spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
}

rewriter.replaceOpWithNewOp<spirv::StoreOp>(
storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
alignmentAttr);

return success();
}
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,14 @@ func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<St
return %0: vector<1xf32>
}

// CHECK-LABEL: @vector_load_aligned
func.func @vector_load_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
%idx = arith.constant 0 : index
// CHECK: spirv.Load
// CHECK-SAME: ["Aligned", 8]
%0 = vector.load %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
return %0: vector<4xf32>
}

// CHECK-LABEL: @vector_load_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
Expand Down Expand Up @@ -996,6 +1004,15 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
return
}

// CHECK-LABEL: @vector_store_aligned
func.func @vector_store_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
%idx = arith.constant 0 : index
// CHECK: spirv.Store
// CHECK-SAME: ["Aligned", 8]
vector.store %arg1, %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
return
}

// CHECK-LABEL: @vector_store_single_elem
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
// CHECK-SAME: %[[ARG1:.*]]: vector<1xf32>
Expand Down
43 changes: 43 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: mlir-opt --split-input-file --verify-diagnostics %s

//===----------------------------------------------------------------------===//
// spirv.LoadOp
//===----------------------------------------------------------------------===//

func.func @aligned_load_non_positive() -> () {
%0 = spirv.Variable : !spirv.ptr<f32, Function>
// expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
%1 = spirv.Load "Function" %0 ["Aligned", 0] : f32
return
}

// -----

func.func @aligned_load_non_power_of_two() -> () {
%0 = spirv.Variable : !spirv.ptr<f32, Function>
// expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
%1 = spirv.Load "Function" %0 ["Aligned", 3] : f32
return
}

// -----

//===----------------------------------------------------------------------===//
// spirv.StoreOp
//===----------------------------------------------------------------------===//

func.func @aligned_store_non_positive(%arg0 : f32) -> () {
%0 = spirv.Variable : !spirv.ptr<f32, Function>
// expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
spirv.Store "Function" %0, %arg0 ["Aligned", 0] : f32
return
}

// -----

func.func @aligned_store_non_power_of_two(%arg0 : f32) -> () {
%0 = spirv.Variable : !spirv.ptr<f32, Function>
// expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
spirv.Store "Function" %0, %arg0 ["Aligned", 3] : f32
return
}