diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 7a705336bf11c..483a8db719a59 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Visitors.h" #include +#include #include #define DEBUG_TYPE "memref-to-spirv-pattern" @@ -465,7 +466,12 @@ struct MemoryRequirements { /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if /// any. static FailureOr -calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { +calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, + uint64_t preferredAlignment) { + if (preferredAlignment >= std::numeric_limits::max()) { + return failure(); + } + MLIRContext *ctx = accessedPtr.getContext(); auto memoryAccess = spirv::MemoryAccess::None; @@ -474,7 +480,10 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { } auto ptrType = cast(accessedPtr.getType()); - if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) { + bool mayOmitAlignment = + !preferredAlignment && + ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer; + if (mayOmitAlignment) { if (memoryAccess == spirv::MemoryAccess::None) { return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}}; } @@ -483,6 +492,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { } // PhysicalStorageBuffers require the `Aligned` attribute. + // Other storage types may show an `Aligned` attribute. auto pointeeType = dyn_cast(ptrType.getPointeeType()); if (!pointeeType) return failure(); @@ -494,7 +504,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess); - auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes); + auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes; + auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue); return MemoryRequirements{memAccessAttr, alignment}; } @@ -508,16 +519,9 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) { llvm::is_one_of::value, "Must be called on either memref::LoadOp or memref::StoreOp"); - Operation *memrefAccessOp = loadOrStoreOp.getOperation(); - auto memrefMemAccess = memrefAccessOp->getAttrOfType( - spirv::attributeName()); - auto memrefAlignment = - memrefAccessOp->getAttrOfType("alignment"); - if (memrefMemAccess && memrefAlignment) - return MemoryRequirements{memrefMemAccess, memrefAlignment}; - return calculateMemoryRequirements(accessedPtr, - loadOrStoreOp.getNontemporal()); + loadOrStoreOp.getNontemporal(), + loadOrStoreOp.getAlignment().value_or(0)); } LogicalResult diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index d0ddac8cd801c..d708d456d5e91 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -85,6 +85,28 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class>, %i : return %0: i1 } +// CHECK-LABEL: func @load_aligned +// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class>, %[[IDX:.+]]: index) +func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 { + // CHECK: spirv.Load "StorageBuffer" {{.*}} ["Aligned", 32] : i8 + %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class> + return %0: i1 +} + +// CHECK-LABEL: func @load_aligned_nontemporal +func.func @load_aligned_nontemporal(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 { + // CHECK: spirv.Load "StorageBuffer" {{.*}} ["Aligned|Nontemporal", 32] : i8 + %0 = memref.load %src[%i] { alignment = 32, nontemporal = true } : memref<4xi1, #spirv.storage_class> + return %0: i1 +} + +// CHECK-LABEL: func @load_aligned_psb +func.func @load_aligned_psb(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 { + // CHECK: %[[VAL:.+]] = spirv.Load "PhysicalStorageBuffer" {{.*}} ["Aligned", 32] : i8 + %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class> + return %0: i1 +} + // CHECK-LABEL: func @store_i1 // CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class>, // CHECK-SAME: %[[IDX:.+]]: index