@@ -465,7 +465,8 @@ struct MemoryRequirements {
465465// / Given an accessed SPIR-V pointer, calculates its alignment requirements, if
466466// / any.
467467static FailureOr<MemoryRequirements>
468- calculateMemoryRequirements (Value accessedPtr, bool isNontemporal) {
468+ calculateMemoryRequirements (Value accessedPtr, bool isNontemporal,
469+ uint64_t preferredAlignment) {
469470 MLIRContext *ctx = accessedPtr.getContext ();
470471
471472 auto memoryAccess = spirv::MemoryAccess::None;
@@ -494,7 +495,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
494495
495496 memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
496497 auto memAccessAttr = spirv::MemoryAccessAttr::get (ctx, memoryAccess);
497- auto alignment = IntegerAttr::get (IntegerType::get (ctx, 32 ), *sizeInBytes);
498+ auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
499+ auto alignment = IntegerAttr::get (IntegerType::get (ctx, 32 ), alignmentValue);
498500 return MemoryRequirements{memAccessAttr, alignment};
499501}
500502
@@ -516,7 +518,8 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
516518 return MemoryRequirements{memrefMemAccess, memrefAlignment};
517519
518520 return calculateMemoryRequirements (accessedPtr,
519- loadOrStoreOp.getNontemporal ());
521+ loadOrStoreOp.getNontemporal (),
522+ loadOrStoreOp.getAlignment ().value_or (0 ));
520523}
521524
522525LogicalResult
0 commit comments