Skip to content

Commit 6d231fb

Browse files
authored
[mlir] MemRefToSPIRV propagate alignment attributes from MemRef ops. (#151723)
This patchset: * propagates alignment attributes from memref operations into the SPIR-V dialect, * fixes an error in the logic which previously propagated alignment attributes but did not add other MemoryAccess attributes. * adds a failure condition in the case where the alignment attribute from the memref dialect (64-bit wide) does not fit in SPIR-V's alignment attribute (specified to be 32-bit wide).
1 parent 44aedac commit 6d231fb

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/MLIRContext.h"
2323
#include "mlir/IR/Visitors.h"
2424
#include <cassert>
25+
#include <limits>
2526
#include <optional>
2627

2728
#define DEBUG_TYPE "memref-to-spirv-pattern"
@@ -475,7 +476,12 @@ struct MemoryRequirements {
475476
/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
476477
/// any.
477478
static FailureOr<MemoryRequirements>
478-
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
479+
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
480+
uint64_t preferredAlignment) {
481+
if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
482+
return failure();
483+
}
484+
479485
MLIRContext *ctx = accessedPtr.getContext();
480486

481487
auto memoryAccess = spirv::MemoryAccess::None;
@@ -484,7 +490,10 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
484490
}
485491

486492
auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
487-
if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
493+
bool mayOmitAlignment =
494+
!preferredAlignment &&
495+
ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
496+
if (mayOmitAlignment) {
488497
if (memoryAccess == spirv::MemoryAccess::None) {
489498
return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
490499
}
@@ -493,6 +502,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
493502
}
494503

495504
// PhysicalStorageBuffers require the `Aligned` attribute.
505+
// Other storage types may show an `Aligned` attribute.
496506
auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
497507
if (!pointeeType)
498508
return failure();
@@ -504,7 +514,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
504514

505515
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
506516
auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
507-
auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
517+
auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
518+
auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
508519
return MemoryRequirements{memAccessAttr, alignment};
509520
}
510521

@@ -518,16 +529,9 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
518529
llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
519530
"Must be called on either memref::LoadOp or memref::StoreOp");
520531

521-
Operation *memrefAccessOp = loadOrStoreOp.getOperation();
522-
auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
523-
spirv::attributeName<spirv::MemoryAccess>());
524-
auto memrefAlignment =
525-
memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
526-
if (memrefMemAccess && memrefAlignment)
527-
return MemoryRequirements{memrefMemAccess, memrefAlignment};
528-
529532
return calculateMemoryRequirements(accessedPtr,
530-
loadOrStoreOp.getNontemporal());
533+
loadOrStoreOp.getNontemporal(),
534+
loadOrStoreOp.getAlignment().value_or(0));
531535
}
532536

533537
LogicalResult

mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,28 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
8585
return %0: i1
8686
}
8787

88+
// CHECK-LABEL: func @load_aligned
89+
// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %[[IDX:.+]]: index)
90+
func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
91+
// CHECK: spirv.Load "StorageBuffer" {{.*}} ["Aligned", 32] : i8
92+
%0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<StorageBuffer>>
93+
return %0: i1
94+
}
95+
96+
// CHECK-LABEL: func @load_aligned_nontemporal
97+
func.func @load_aligned_nontemporal(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
98+
// CHECK: spirv.Load "StorageBuffer" {{.*}} ["Aligned|Nontemporal", 32] : i8
99+
%0 = memref.load %src[%i] { alignment = 32, nontemporal = true } : memref<4xi1, #spirv.storage_class<StorageBuffer>>
100+
return %0: i1
101+
}
102+
103+
// CHECK-LABEL: func @load_aligned_psb
104+
func.func @load_aligned_psb(%src: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %i : index) -> i1 {
105+
// CHECK: %[[VAL:.+]] = spirv.Load "PhysicalStorageBuffer" {{.*}} ["Aligned", 32] : i8
106+
%0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>
107+
return %0: i1
108+
}
109+
88110
// CHECK-LABEL: func @store_i1
89111
// CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>,
90112
// CHECK-SAME: %[[IDX:.+]]: index

0 commit comments

Comments
 (0)