22
22
#include " mlir/IR/MLIRContext.h"
23
23
#include " mlir/IR/Visitors.h"
24
24
#include < cassert>
25
+ #include < limits>
25
26
#include < optional>
26
27
27
28
#define DEBUG_TYPE " memref-to-spirv-pattern"
@@ -475,7 +476,12 @@ struct MemoryRequirements {
475
476
// / Given an accessed SPIR-V pointer, calculates its alignment requirements, if
476
477
// / any.
477
478
static FailureOr<MemoryRequirements>
478
- calculateMemoryRequirements (Value accessedPtr, bool isNontemporal) {
479
+ calculateMemoryRequirements (Value accessedPtr, bool isNontemporal,
480
+ uint64_t preferredAlignment) {
481
+ if (preferredAlignment >= std::numeric_limits<uint32_t >::max ()) {
482
+ return failure ();
483
+ }
484
+
479
485
MLIRContext *ctx = accessedPtr.getContext ();
480
486
481
487
auto memoryAccess = spirv::MemoryAccess::None;
@@ -484,7 +490,10 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
484
490
}
485
491
486
492
auto ptrType = cast<spirv::PointerType>(accessedPtr.getType ());
487
- if (ptrType.getStorageClass () != spirv::StorageClass::PhysicalStorageBuffer) {
493
+ bool mayOmitAlignment =
494
+ !preferredAlignment &&
495
+ ptrType.getStorageClass () != spirv::StorageClass::PhysicalStorageBuffer;
496
+ if (mayOmitAlignment) {
488
497
if (memoryAccess == spirv::MemoryAccess::None) {
489
498
return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
490
499
}
@@ -493,6 +502,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
493
502
}
494
503
495
504
// PhysicalStorageBuffers require the `Aligned` attribute.
505
+ // Other storage types may show an `Aligned` attribute.
496
506
auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType ());
497
507
if (!pointeeType)
498
508
return failure ();
@@ -504,7 +514,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
504
514
505
515
memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
506
516
auto memAccessAttr = spirv::MemoryAccessAttr::get (ctx, memoryAccess);
507
- auto alignment = IntegerAttr::get (IntegerType::get (ctx, 32 ), *sizeInBytes);
517
+ auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
518
+ auto alignment = IntegerAttr::get (IntegerType::get (ctx, 32 ), alignmentValue);
508
519
return MemoryRequirements{memAccessAttr, alignment};
509
520
}
510
521
@@ -518,16 +529,9 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
518
529
llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
519
530
" Must be called on either memref::LoadOp or memref::StoreOp" );
520
531
521
- Operation *memrefAccessOp = loadOrStoreOp.getOperation ();
522
- auto memrefMemAccess = memrefAccessOp->getAttrOfType <spirv::MemoryAccessAttr>(
523
- spirv::attributeName<spirv::MemoryAccess>());
524
- auto memrefAlignment =
525
- memrefAccessOp->getAttrOfType <IntegerAttr>(" alignment" );
526
- if (memrefMemAccess && memrefAlignment)
527
- return MemoryRequirements{memrefMemAccess, memrefAlignment};
528
-
529
532
return calculateMemoryRequirements (accessedPtr,
530
- loadOrStoreOp.getNontemporal ());
533
+ loadOrStoreOp.getNontemporal (),
534
+ loadOrStoreOp.getAlignment ().value_or (0 ));
531
535
}
532
536
533
537
LogicalResult
0 commit comments