Skip to content

Commit 1c3455c

Browse files
committed
[mlir] MemRefToSPIRV propagate alignment attribute.
1 parent f4664aa commit 1c3455c

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,8 @@ struct MemoryRequirements {
465465
/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
466466
/// any.
467467
static FailureOr<MemoryRequirements>
468-
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
468+
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
469+
uint64_t preferredAlignment) {
469470
MLIRContext *ctx = accessedPtr.getContext();
470471

471472
auto memoryAccess = spirv::MemoryAccess::None;
@@ -494,7 +495,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
494495

495496
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
496497
auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
497-
auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
498+
auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
499+
auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
498500
return MemoryRequirements{memAccessAttr, alignment};
499501
}
500502

@@ -516,7 +518,8 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
516518
return MemoryRequirements{memrefMemAccess, memrefAlignment};
517519

518520
return calculateMemoryRequirements(accessedPtr,
519-
loadOrStoreOp.getNontemporal());
521+
loadOrStoreOp.getNontemporal(),
522+
loadOrStoreOp.getAlignment().value_or(0));
520523
}
521524

522525
LogicalResult

mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,21 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
8585
return %0: i1
8686
}
8787

88+
// CHECK-LABEL: func @load_aligned
89+
// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %[[IDX:.+]]: index)
90+
func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>, %i : index) -> i1 {
91+
// CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, PhysicalStorageBuffer>
92+
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
93+
// CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
94+
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
95+
// CHECK: %[[VAL:.+]] = spirv.Load "PhysicalStorageBuffer" %[[ADDR]] ["Aligned", 32] : i8
96+
// CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
97+
// CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
98+
%0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class<PhysicalStorageBuffer>>
99+
// CHECK: return %[[BOOL]]
100+
return %0: i1
101+
}
102+
88103
// CHECK-LABEL: func @store_i1
89104
// CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class<StorageBuffer>>,
90105
// CHECK-SAME: %[[IDX:.+]]: index

0 commit comments

Comments
 (0)