@@ -244,6 +244,16 @@ class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
244
244
ConversionPatternRewriter &rewriter) const override ;
245
245
};
246
246
247
+ // / Converts memref.load to spirv.Image + spirv.ImageFetch
248
+ class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
249
+ public:
250
+ using OpConversionPattern::OpConversionPattern;
251
+
252
+ LogicalResult
253
+ matchAndRewrite (memref::LoadOp loadOp, OpAdaptor adaptor,
254
+ ConversionPatternRewriter &rewriter) const override ;
255
+ };
256
+
247
257
// / Converts memref.store to spirv.Store on integers.
248
258
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
249
259
public:
@@ -528,6 +538,17 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
528
538
if (!memrefType.getElementType ().isSignlessInteger ())
529
539
return failure ();
530
540
541
+ auto memorySpaceAttr =
542
+ dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace ());
543
+ if (!memorySpaceAttr)
544
+ return rewriter.notifyMatchFailure (
545
+ loadOp, " missing memory space SPIR-V storage class attribute" );
546
+
547
+ if (memorySpaceAttr.getValue () == spirv::StorageClass::Image)
548
+ return rewriter.notifyMatchFailure (
549
+ loadOp,
550
+ " failed to lower memref in image storage class to storage buffer" );
551
+
531
552
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
532
553
Value accessChain =
533
554
spirv::getElementPtr (typeConverter, memrefType, adaptor.getMemref (),
@@ -644,6 +665,18 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
644
665
auto memrefType = cast<MemRefType>(loadOp.getMemref ().getType ());
645
666
if (memrefType.getElementType ().isSignlessInteger ())
646
667
return failure ();
668
+
669
+ auto memorySpaceAttr =
670
+ dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace ());
671
+ if (!memorySpaceAttr)
672
+ return rewriter.notifyMatchFailure (
673
+ loadOp, " missing memory space SPIR-V storage class attribute" );
674
+
675
+ if (memorySpaceAttr.getValue () == spirv::StorageClass::Image)
676
+ return rewriter.notifyMatchFailure (
677
+ loadOp,
678
+ " failed to lower memref in image storage class to storage buffer" );
679
+
647
680
Value loadPtr = spirv::getElementPtr (
648
681
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref (),
649
682
adaptor.getIndices (), loadOp.getLoc (), rewriter);
@@ -662,6 +695,87 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
662
695
return success ();
663
696
}
664
697
698
+ LogicalResult
699
+ ImageLoadOpPattern::matchAndRewrite (memref::LoadOp loadOp, OpAdaptor adaptor,
700
+ ConversionPatternRewriter &rewriter) const {
701
+ auto memrefType = cast<MemRefType>(loadOp.getMemref ().getType ());
702
+
703
+ auto memorySpaceAttr =
704
+ dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace ());
705
+ if (!memorySpaceAttr)
706
+ return rewriter.notifyMatchFailure (
707
+ loadOp, " missing memory space SPIR-V storage class attribute" );
708
+
709
+ if (memorySpaceAttr.getValue () != spirv::StorageClass::Image)
710
+ return rewriter.notifyMatchFailure (
711
+ loadOp, " failed to lower memref in non-image storage class to image" );
712
+
713
+ Value loadPtr = adaptor.getMemref ();
714
+ auto memoryRequirements = calculateMemoryRequirements (loadPtr, loadOp);
715
+ if (failed (memoryRequirements))
716
+ return rewriter.notifyMatchFailure (
717
+ loadOp, " failed to determine memory requirements" );
718
+
719
+ const auto [memoryAccess, alignment] = *memoryRequirements;
720
+
721
+ if (!loadOp.getMemRefType ().hasRank ())
722
+ return rewriter.notifyMatchFailure (
723
+ loadOp, " cannot lower unranked memrefs to SPIR-V images" );
724
+
725
+ // We currently only support lowering of scalar memref elements to texels in
726
+ // the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
727
+ // elements to texels in richer formats.
728
+ if (!isa<spirv::ScalarType>(loadOp.getMemRefType ().getElementType ()))
729
+ return rewriter.notifyMatchFailure (
730
+ loadOp,
731
+ " cannot lower memrefs who's element type is not a SPIR-V scalar type"
732
+ " to SPIR-V images" );
733
+
734
+ // We currently only support sampled images since OpImageFetch does not work
735
+ // for plain images and the OpImageRead instruction needs to be materialized
736
+ // instead or texels need to be accessed via atomics through a texel pointer.
737
+ // Future work will generalize support to plain images.
738
+ auto convertedPointeeType = cast<spirv::PointerType>(
739
+ getTypeConverter ()->convertType (loadOp.getMemRefType ()));
740
+ if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType ()))
741
+ return rewriter.notifyMatchFailure (loadOp,
742
+ " cannot lower memrefs which do not "
743
+ " convert to SPIR-V sampled images" );
744
+
745
+ // Materialize the lowering.
746
+ Location loc = loadOp->getLoc ();
747
+ auto imageLoadOp =
748
+ spirv::LoadOp::create (rewriter, loc, loadPtr, memoryAccess, alignment);
749
+ // Extract the image from the sampled image.
750
+ auto imageOp = spirv::ImageOp::create (rewriter, loc, imageLoadOp);
751
+
752
+ // Build a vector of coordinates or just a scalar index if we have a 1D image.
753
+ Value coords;
754
+ if (memrefType.getRank () != 1 ) {
755
+ auto coordVectorType = VectorType::get ({loadOp.getMemRefType ().getRank ()},
756
+ adaptor.getIndices ().getType ()[0 ]);
757
+ coords = spirv::CompositeConstructOp::create (rewriter, loc, coordVectorType,
758
+ adaptor.getIndices ());
759
+ } else {
760
+ coords = adaptor.getIndices ()[0 ];
761
+ }
762
+
763
+ // Fetch the value out of the image.
764
+ auto resultVectorType = VectorType::get ({4 }, loadOp.getType ());
765
+ auto fetchOp = spirv::ImageFetchOp::create (
766
+ rewriter, loc, resultVectorType, imageOp, coords,
767
+ mlir::spirv::ImageOperandsAttr{}, ValueRange{});
768
+
769
+ // Note that because OpImageFetch returns a rank 4 vector we need to extract
770
+ // the elements corresponding to the load which will since we only support the
771
+ // R[16|32][f|i|ui] formats will always be the R(red) 0th vector element.
772
+ auto compositeExtractOp =
773
+ spirv::CompositeExtractOp::create (rewriter, loc, fetchOp, 0 );
774
+
775
+ rewriter.replaceOp (loadOp, compositeExtractOp);
776
+ return success ();
777
+ }
778
+
665
779
LogicalResult
666
780
IntStoreOpPattern::matchAndRewrite (memref::StoreOp storeOp, OpAdaptor adaptor,
667
781
ConversionPatternRewriter &rewriter) const {
@@ -953,11 +1067,11 @@ LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
953
1067
namespace mlir {
954
1068
void populateMemRefToSPIRVPatterns (const SPIRVTypeConverter &typeConverter,
955
1069
RewritePatternSet &patterns) {
956
- patterns
957
- . add <AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern ,
958
- DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
959
- MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
960
- CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
961
- typeConverter, patterns.getContext ());
1070
+ patterns. add <AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1071
+ DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern ,
1072
+ IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern ,
1073
+ StoreOpPattern, ReinterpretCastPattern, CastPattern ,
1074
+ ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
1075
+ patterns.getContext ());
962
1076
}
963
1077
} // namespace mlir
0 commit comments