Skip to content

Commit 4d6453c

Browse files
committed
[mlir] Propagate alignment attribute in VectorToSPIRV.
1 parent f4185e6 commit 4d6453c

File tree

2 files changed

+55
-3
lines changed

2 files changed

+55
-3
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,23 @@ struct VectorLoadOpConverter final
743743

744744
auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
745745

746+
auto alignment = loadOp.getAlignment();
747+
if (alignment.has_value() &&
748+
alignment > std::numeric_limits<uint32_t>::max()) {
749+
return rewriter.notifyMatchFailure(loadOp,
750+
"invalid alignment requirement");
751+
}
752+
753+
auto memoryAccess = spirv::MemoryAccess::None;
754+
auto memoryAccessAttr = spirv::MemoryAccessAttr{};
755+
IntegerAttr alignmentAttr = nullptr;
756+
if (alignment.has_value()) {
757+
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
758+
memoryAccessAttr =
759+
spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
760+
alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
761+
}
762+
746763
// For single element vectors, we don't need to bitcast the access chain to
747764
// the original vector type. Both is going to be the same, a pointer
748765
// to a scalar.
@@ -753,7 +770,8 @@ struct VectorLoadOpConverter final
753770
accessChain);
754771

755772
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
756-
castedAccessChain);
773+
castedAccessChain,
774+
memoryAccessAttr, alignmentAttr);
757775

758776
return success();
759777
}
@@ -782,6 +800,12 @@ struct VectorStoreOpConverter final
782800
return rewriter.notifyMatchFailure(
783801
storeOp, "failed to get memref element pointer");
784802

803+
auto alignment = storeOp.getAlignment();
804+
if (alignment && alignment > std::numeric_limits<uint32_t>::max()) {
805+
return rewriter.notifyMatchFailure(storeOp,
806+
"invalid alignment requirement");
807+
}
808+
785809
spirv::StorageClass storageClass = attr.getValue();
786810
auto vectorType = storeOp.getVectorType();
787811
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
@@ -795,8 +819,19 @@ struct VectorStoreOpConverter final
795819
: spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
796820
accessChain);
797821

798-
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
799-
adaptor.getValueToStore());
822+
auto memoryAccess = spirv::MemoryAccess::None;
823+
auto memoryAccessAttr = spirv::MemoryAccessAttr{};
824+
IntegerAttr alignmentAttr = nullptr;
825+
if (alignment.has_value()) {
826+
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
827+
memoryAccessAttr =
828+
spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
829+
alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
830+
}
831+
832+
rewriter.replaceOpWithNewOp<spirv::StoreOp>(
833+
storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
834+
alignmentAttr);
800835

801836
return success();
802837
}

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,14 @@ func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<St
953953
return %0: vector<1xf32>
954954
}
955955

956+
// CHECK-LABEL: @vector_load_aligned
957+
func.func @vector_load_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
958+
%idx = arith.constant 0 : index
959+
// CHECK: spirv.Load
960+
// CHECK-SAME: ["Aligned", 8]
961+
%0 = vector.load %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
962+
return %0: vector<4xf32>
963+
}
956964

957965
// CHECK-LABEL: @vector_load_2d
958966
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
@@ -996,6 +1004,15 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
9961004
return
9971005
}
9981006

1007+
// CHECK-LABEL: @vector_store_aligned
1008+
func.func @vector_store_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
1009+
%idx = arith.constant 0 : index
1010+
// CHECK: spirv.Store
1011+
// CHECK-SAME: ["Aligned", 8]
1012+
vector.store %arg1, %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
1013+
return
1014+
}
1015+
9991016
// CHECK-LABEL: @vector_store_single_elem
10001017
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
10011018
// CHECK-SAME: %[[ARG1:.*]]: vector<1xf32>

0 commit comments

Comments
 (0)