Skip to content

Commit 526787f

Browse files
committed
[mlir] Fix calculateMemoryRequirements in MemRefToSPIRV.
There was an early return in calculateMemoryRequirements that looked explicitly for alignment and only set the alignment attribute. However, this was not correct for the following reasons: * Alignment was set only if both the alignment and the memory_access attributes were both present in the memref operation, without handling the case when only the alignment was exclusively present. * In the case alignment and memory_access attributes were both present, the memory_access attribute would not be updated to aligned if the memory_access attribute was not marked aligned. * In the case alignment and memory_access attributes were both present, other memory requirements (e.g., non_temporal) would not be added as attributes.
1 parent 1c3455c commit 526787f

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,10 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
475475
}
476476

477477
auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
478-
if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
478+
bool mayOmitAlignment =
479+
!preferredAlignment &&
480+
ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
481+
if (mayOmitAlignment) {
479482
if (memoryAccess == spirv::MemoryAccess::None) {
480483
return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
481484
}
@@ -484,6 +487,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
484487
}
485488

486489
// PhysicalStorageBuffers require the `Aligned` attribute.
490+
// Other storage types may show an `Aligned` attribute.
487491
auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
488492
if (!pointeeType)
489493
return failure();
@@ -510,13 +514,6 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
510514
llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
511515
"Must be called on either memref::LoadOp or memref::StoreOp");
512516

513-
Operation *memrefAccessOp = loadOrStoreOp.getOperation();
514-
auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
515-
spirv::attributeName<spirv::MemoryAccess>());
516-
auto memrefAlignment = loadOrStoreOp.getAlignmentAttr();
517-
if (memrefMemAccess && memrefAlignment)
518-
return MemoryRequirements{memrefMemAccess, memrefAlignment};
519-
520517
return calculateMemoryRequirements(accessedPtr,
521518
loadOrStoreOp.getNontemporal(),
522519
loadOrStoreOp.getAlignment().value_or(0));

mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,23 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
8686
}
8787

8888
// CHECK-LABEL: func @load_aligned
89+
// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %[[IDX:.+]]: index)
90+
func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
91+
// CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer>
92+
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
93+
// CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
94+
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
95+
// CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] ["Aligned", 32] : i8
96+
// CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
97+
// CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
98+
%0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<StorageBuffer>>
99+
// CHECK: return %[[BOOL]]
100+
return %0: i1
101+
}
102+
103+
// CHECK-LABEL: func @load_aligned_psb
89104
// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %[[IDX:.+]]: index)
90-
func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %i : index) -> i1 {
105+
func.func @load_aligned_psb(%src: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %i : index) -> i1 {
91106
// CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, PhysicalStorageBuffer>
92107
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
93108
// CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32

0 commit comments

Comments
 (0)