@@ -743,6 +743,23 @@ struct VectorLoadOpConverter final
743743
744744 auto vectorPtrType = spirv::PointerType::get (spirvVectorType, storageClass);
745745
746+ auto alignment = loadOp.getAlignment ();
747+ if (alignment.has_value () &&
748+ alignment > std::numeric_limits<uint32_t >::max ()) {
749+ return rewriter.notifyMatchFailure (loadOp,
750+ " invalid alignment requirement" );
751+ }
752+
753+ auto memoryAccess = spirv::MemoryAccess::None;
754+ auto memoryAccessAttr = spirv::MemoryAccessAttr{};
755+ IntegerAttr alignmentAttr = nullptr ;
756+ if (alignment.has_value ()) {
757+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
758+ memoryAccessAttr =
759+ spirv::MemoryAccessAttr::get (rewriter.getContext (), memoryAccess);
760+ alignmentAttr = rewriter.getI32IntegerAttr (alignment.value ());
761+ }
762+
746763 // For single element vectors, we don't need to bitcast the access chain to
747764 // the original vector type. Both is going to be the same, a pointer
748765 // to a scalar.
@@ -753,7 +770,8 @@ struct VectorLoadOpConverter final
753770 accessChain);
754771
755772 rewriter.replaceOpWithNewOp <spirv::LoadOp>(loadOp, spirvVectorType,
756- castedAccessChain);
773+ castedAccessChain,
774+ memoryAccessAttr, alignmentAttr);
757775
758776 return success ();
759777 }
@@ -782,6 +800,12 @@ struct VectorStoreOpConverter final
782800 return rewriter.notifyMatchFailure (
783801 storeOp, " failed to get memref element pointer" );
784802
803+ auto alignment = storeOp.getAlignment ();
804+ if (alignment && alignment > std::numeric_limits<uint32_t >::max ()) {
805+ return rewriter.notifyMatchFailure (storeOp,
806+ " invalid alignment requirement" );
807+ }
808+
785809 spirv::StorageClass storageClass = attr.getValue ();
786810 auto vectorType = storeOp.getVectorType ();
787811 auto vectorPtrType = spirv::PointerType::get (vectorType, storageClass);
@@ -795,8 +819,19 @@ struct VectorStoreOpConverter final
795819 : spirv::BitcastOp::create (rewriter, loc, vectorPtrType,
796820 accessChain);
797821
798- rewriter.replaceOpWithNewOp <spirv::StoreOp>(storeOp, castedAccessChain,
799- adaptor.getValueToStore ());
822+ auto memoryAccess = spirv::MemoryAccess::None;
823+ auto memoryAccessAttr = spirv::MemoryAccessAttr{};
824+ IntegerAttr alignmentAttr = nullptr ;
825+ if (alignment.has_value ()) {
826+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
827+ memoryAccessAttr =
828+ spirv::MemoryAccessAttr::get (rewriter.getContext (), memoryAccess);
829+ alignmentAttr = rewriter.getI32IntegerAttr (alignment.value ());
830+ }
831+
832+ rewriter.replaceOpWithNewOp <spirv::StoreOp>(
833+ storeOp, castedAccessChain, adaptor.getValueToStore (), memoryAccessAttr,
834+ alignmentAttr);
800835
801836 return success ();
802837 }
0 commit comments