@@ -244,6 +244,16 @@ class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
244244 ConversionPatternRewriter &rewriter) const override ;
245245};
246246
247+ // / Converts memref.load to spirv.Image + spirv.ImageFetch
248+ class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
249+ public:
250+ using OpConversionPattern::OpConversionPattern;
251+
252+ LogicalResult
253+ matchAndRewrite (memref::LoadOp loadOp, OpAdaptor adaptor,
254+ ConversionPatternRewriter &rewriter) const override ;
255+ };
256+
247257// / Converts memref.store to spirv.Store on integers.
248258class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
249259public:
@@ -528,6 +538,17 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
528538 if (!memrefType.getElementType ().isSignlessInteger ())
529539 return failure ();
530540
541+ auto memorySpaceAttr =
542+ dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace ());
543+ if (!memorySpaceAttr)
544+ return rewriter.notifyMatchFailure (
545+ loadOp, " missing memory space SPIR-V storage class attribute" );
546+
547+ if (memorySpaceAttr.getValue () == spirv::StorageClass::Image)
548+ return rewriter.notifyMatchFailure (
549+ loadOp,
550+ " failed to lower memref in image storage class to storage buffer" );
551+
531552 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
532553 Value accessChain =
533554 spirv::getElementPtr (typeConverter, memrefType, adaptor.getMemref (),
@@ -644,6 +665,18 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
644665 auto memrefType = cast<MemRefType>(loadOp.getMemref ().getType ());
645666 if (memrefType.getElementType ().isSignlessInteger ())
646667 return failure ();
668+
669+ auto memorySpaceAttr =
670+ dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace ());
671+ if (!memorySpaceAttr)
672+ return rewriter.notifyMatchFailure (
673+ loadOp, " missing memory space SPIR-V storage class attribute" );
674+
675+ if (memorySpaceAttr.getValue () == spirv::StorageClass::Image)
676+ return rewriter.notifyMatchFailure (
677+ loadOp,
678+ " failed to lower memref in image storage class to storage buffer" );
679+
647680 Value loadPtr = spirv::getElementPtr (
648681 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref (),
649682 adaptor.getIndices (), loadOp.getLoc (), rewriter);
@@ -662,6 +695,87 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
662695 return success ();
663696}
664697
698+ LogicalResult
699+ ImageLoadOpPattern::matchAndRewrite (memref::LoadOp loadOp, OpAdaptor adaptor,
700+ ConversionPatternRewriter &rewriter) const {
701+ auto memrefType = cast<MemRefType>(loadOp.getMemref ().getType ());
702+
703+ auto memorySpaceAttr =
704+ dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace ());
705+ if (!memorySpaceAttr)
706+ return rewriter.notifyMatchFailure (
707+ loadOp, " missing memory space SPIR-V storage class attribute" );
708+
709+ if (memorySpaceAttr.getValue () != spirv::StorageClass::Image)
710+ return rewriter.notifyMatchFailure (
711+ loadOp, " failed to lower memref in non-image storage class to image" );
712+
713+ Value loadPtr = adaptor.getMemref ();
714+ auto memoryRequirements = calculateMemoryRequirements (loadPtr, loadOp);
715+ if (failed (memoryRequirements))
716+ return rewriter.notifyMatchFailure (
717+ loadOp, " failed to determine memory requirements" );
718+
719+ const auto [memoryAccess, alignment] = *memoryRequirements;
720+
721+ if (!loadOp.getMemRefType ().hasRank ())
722+ return rewriter.notifyMatchFailure (
723+ loadOp, " cannot lower unranked memrefs to SPIR-V images" );
724+
725+ // We currently only support lowering of scalar memref elements to texels in
726+ // the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
727+ // elements to texels in richer formats.
728+ if (!isa<spirv::ScalarType>(loadOp.getMemRefType ().getElementType ()))
729+ return rewriter.notifyMatchFailure (
730+ loadOp,
731+ " cannot lower memrefs who's element type is not a SPIR-V scalar type"
732+ " to SPIR-V images" );
733+
734+ // We currently only support sampled images since OpImageFetch does not work
735+ // for plain images and the OpImageRead instruction needs to be materialized
736+ // instead or texels need to be accessed via atomics through a texel pointer.
737+ // Future work will generalize support to plain images.
738+ auto convertedPointeeType = cast<spirv::PointerType>(
739+ getTypeConverter ()->convertType (loadOp.getMemRefType ()));
740+ if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType ()))
741+ return rewriter.notifyMatchFailure (loadOp,
742+ " cannot lower memrefs which do not "
743+ " convert to SPIR-V sampled images" );
744+
745+ // Materialize the lowering.
746+ Location loc = loadOp->getLoc ();
747+ auto imageLoadOp =
748+ spirv::LoadOp::create (rewriter, loc, loadPtr, memoryAccess, alignment);
749+ // Extract the image from the sampled image.
750+ auto imageOp = spirv::ImageOp::create (rewriter, loc, imageLoadOp);
751+
752+ // Build a vector of coordinates or just a scalar index if we have a 1D image.
753+ Value coords;
754+ if (memrefType.getRank () != 1 ) {
755+ auto coordVectorType = VectorType::get ({loadOp.getMemRefType ().getRank ()},
756+ adaptor.getIndices ().getType ()[0 ]);
757+ coords = spirv::CompositeConstructOp::create (rewriter, loc, coordVectorType,
758+ adaptor.getIndices ());
759+ } else {
760+ coords = adaptor.getIndices ()[0 ];
761+ }
762+
763+ // Fetch the value out of the image.
764+ auto resultVectorType = VectorType::get ({4 }, loadOp.getType ());
765+ auto fetchOp = spirv::ImageFetchOp::create (
766+ rewriter, loc, resultVectorType, imageOp, coords,
767+ mlir::spirv::ImageOperandsAttr{}, ValueRange{});
768+
769+ // Note that because OpImageFetch returns a rank 4 vector we need to extract
770+ // the elements corresponding to the load which will since we only support the
771+ // R[16|32][f|i|ui] formats will always be the R(red) 0th vector element.
772+ auto compositeExtractOp =
773+ spirv::CompositeExtractOp::create (rewriter, loc, fetchOp, 0 );
774+
775+ rewriter.replaceOp (loadOp, compositeExtractOp);
776+ return success ();
777+ }
778+
665779LogicalResult
666780IntStoreOpPattern::matchAndRewrite (memref::StoreOp storeOp, OpAdaptor adaptor,
667781 ConversionPatternRewriter &rewriter) const {
@@ -953,11 +1067,11 @@ LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
9531067namespace mlir {
9541068void populateMemRefToSPIRVPatterns (const SPIRVTypeConverter &typeConverter,
9551069 RewritePatternSet &patterns) {
956- patterns
957- . add <AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern ,
958- DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
959- MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
960- CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
961- typeConverter, patterns.getContext ());
1070+ patterns. add <AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1071+ DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern ,
1072+ IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern ,
1073+ StoreOpPattern, ReinterpretCastPattern, CastPattern ,
1074+ ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
1075+ patterns.getContext ());
9621076}
9631077} // namespace mlir
0 commit comments