2222#include " mlir/IR/MLIRContext.h"
2323#include " mlir/IR/Visitors.h"
2424#include < cassert>
25+ #include < limits>
2526#include < optional>
2627
2728#define DEBUG_TYPE " memref-to-spirv-pattern"
@@ -475,7 +476,12 @@ struct MemoryRequirements {
475476// / Given an accessed SPIR-V pointer, calculates its alignment requirements, if
476477// / any.
477478static FailureOr<MemoryRequirements>
478- calculateMemoryRequirements (Value accessedPtr, bool isNontemporal) {
479+ calculateMemoryRequirements (Value accessedPtr, bool isNontemporal,
480+ uint64_t preferredAlignment) {
481+ if (preferredAlignment >= std::numeric_limits<uint32_t >::max ()) {
482+ return failure ();
483+ }
484+
479485 MLIRContext *ctx = accessedPtr.getContext ();
480486
481487 auto memoryAccess = spirv::MemoryAccess::None;
@@ -484,7 +490,10 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
484490 }
485491
486492 auto ptrType = cast<spirv::PointerType>(accessedPtr.getType ());
487- if (ptrType.getStorageClass () != spirv::StorageClass::PhysicalStorageBuffer) {
493+ bool mayOmitAlignment =
494+ !preferredAlignment &&
495+ ptrType.getStorageClass () != spirv::StorageClass::PhysicalStorageBuffer;
496+ if (mayOmitAlignment) {
488497 if (memoryAccess == spirv::MemoryAccess::None) {
489498 return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
490499 }
@@ -493,6 +502,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
493502 }
494503
495504 // PhysicalStorageBuffers require the `Aligned` attribute.
505+ // Other storage types may show an `Aligned` attribute.
496506 auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType ());
497507 if (!pointeeType)
498508 return failure ();
@@ -504,7 +514,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
504514
505515 memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
506516 auto memAccessAttr = spirv::MemoryAccessAttr::get (ctx, memoryAccess);
507- auto alignment = IntegerAttr::get (IntegerType::get (ctx, 32 ), *sizeInBytes);
517+ auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
518+ auto alignment = IntegerAttr::get (IntegerType::get (ctx, 32 ), alignmentValue);
508519 return MemoryRequirements{memAccessAttr, alignment};
509520}
510521
@@ -518,16 +529,9 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
518529 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
519530 " Must be called on either memref::LoadOp or memref::StoreOp" );
520531
521- Operation *memrefAccessOp = loadOrStoreOp.getOperation ();
522- auto memrefMemAccess = memrefAccessOp->getAttrOfType <spirv::MemoryAccessAttr>(
523- spirv::attributeName<spirv::MemoryAccess>());
524- auto memrefAlignment =
525- memrefAccessOp->getAttrOfType <IntegerAttr>(" alignment" );
526- if (memrefMemAccess && memrefAlignment)
527- return MemoryRequirements{memrefMemAccess, memrefAlignment};
528-
529532 return calculateMemoryRequirements (accessedPtr,
530- loadOrStoreOp.getNontemporal ());
533+ loadOrStoreOp.getNontemporal (),
534+ loadOrStoreOp.getAlignment ().value_or (0 ));
531535}
532536
533537LogicalResult
0 commit comments