Skip to content

Commit 96c8b9e

Browse files
[mlir][memref][spirv] Add SPIR-V Image Lowering (#150978)
Adds an initial conversion in the Memref -> SPIR-V lowering for images. Any memref in the "Image" storage-class/address-space will be considered for lowering to the `!spirv.image` type during Memref to SPIR-V conversion. Initially only the reading of sampled images are support and images are read via the `OpImageFetch` instruction. Future work should expand the conversion patterns to target non-sampled images and add support for image write operations. Images are supported for fp32, fp16, int32, uint32, int16 and uint16 types and lit tests have been added to verify this is the case along with negative testing to check the cases where images aren't supported. --------- Signed-off-by: Jack Frankland <[email protected]>
1 parent 19803d8 commit 96c8b9e

File tree

3 files changed

+382
-6
lines changed

3 files changed

+382
-6
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 120 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,16 @@ class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
244244
ConversionPatternRewriter &rewriter) const override;
245245
};
246246

247+
/// Converts memref.load to spirv.Image + spirv.ImageFetch
248+
class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
249+
public:
250+
using OpConversionPattern::OpConversionPattern;
251+
252+
LogicalResult
253+
matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
254+
ConversionPatternRewriter &rewriter) const override;
255+
};
256+
247257
/// Converts memref.store to spirv.Store on integers.
248258
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
249259
public:
@@ -528,6 +538,17 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
528538
if (!memrefType.getElementType().isSignlessInteger())
529539
return failure();
530540

541+
auto memorySpaceAttr =
542+
dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
543+
if (!memorySpaceAttr)
544+
return rewriter.notifyMatchFailure(
545+
loadOp, "missing memory space SPIR-V storage class attribute");
546+
547+
if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
548+
return rewriter.notifyMatchFailure(
549+
loadOp,
550+
"failed to lower memref in image storage class to storage buffer");
551+
531552
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
532553
Value accessChain =
533554
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
@@ -644,6 +665,18 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
644665
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
645666
if (memrefType.getElementType().isSignlessInteger())
646667
return failure();
668+
669+
auto memorySpaceAttr =
670+
dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
671+
if (!memorySpaceAttr)
672+
return rewriter.notifyMatchFailure(
673+
loadOp, "missing memory space SPIR-V storage class attribute");
674+
675+
if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
676+
return rewriter.notifyMatchFailure(
677+
loadOp,
678+
"failed to lower memref in image storage class to storage buffer");
679+
647680
Value loadPtr = spirv::getElementPtr(
648681
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
649682
adaptor.getIndices(), loadOp.getLoc(), rewriter);
@@ -662,6 +695,87 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
662695
return success();
663696
}
664697

698+
LogicalResult
699+
ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
700+
ConversionPatternRewriter &rewriter) const {
701+
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
702+
703+
auto memorySpaceAttr =
704+
dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
705+
if (!memorySpaceAttr)
706+
return rewriter.notifyMatchFailure(
707+
loadOp, "missing memory space SPIR-V storage class attribute");
708+
709+
if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
710+
return rewriter.notifyMatchFailure(
711+
loadOp, "failed to lower memref in non-image storage class to image");
712+
713+
Value loadPtr = adaptor.getMemref();
714+
auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
715+
if (failed(memoryRequirements))
716+
return rewriter.notifyMatchFailure(
717+
loadOp, "failed to determine memory requirements");
718+
719+
const auto [memoryAccess, alignment] = *memoryRequirements;
720+
721+
if (!loadOp.getMemRefType().hasRank())
722+
return rewriter.notifyMatchFailure(
723+
loadOp, "cannot lower unranked memrefs to SPIR-V images");
724+
725+
// We currently only support lowering of scalar memref elements to texels in
726+
// the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
727+
// elements to texels in richer formats.
728+
if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
729+
return rewriter.notifyMatchFailure(
730+
loadOp,
731+
"cannot lower memrefs who's element type is not a SPIR-V scalar type"
732+
"to SPIR-V images");
733+
734+
// We currently only support sampled images since OpImageFetch does not work
735+
// for plain images and the OpImageRead instruction needs to be materialized
736+
// instead or texels need to be accessed via atomics through a texel pointer.
737+
// Future work will generalize support to plain images.
738+
auto convertedPointeeType = cast<spirv::PointerType>(
739+
getTypeConverter()->convertType(loadOp.getMemRefType()));
740+
if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
741+
return rewriter.notifyMatchFailure(loadOp,
742+
"cannot lower memrefs which do not "
743+
"convert to SPIR-V sampled images");
744+
745+
// Materialize the lowering.
746+
Location loc = loadOp->getLoc();
747+
auto imageLoadOp =
748+
spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
749+
// Extract the image from the sampled image.
750+
auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
751+
752+
// Build a vector of coordinates or just a scalar index if we have a 1D image.
753+
Value coords;
754+
if (memrefType.getRank() != 1) {
755+
auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
756+
adaptor.getIndices().getType()[0]);
757+
coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
758+
adaptor.getIndices());
759+
} else {
760+
coords = adaptor.getIndices()[0];
761+
}
762+
763+
// Fetch the value out of the image.
764+
auto resultVectorType = VectorType::get({4}, loadOp.getType());
765+
auto fetchOp = spirv::ImageFetchOp::create(
766+
rewriter, loc, resultVectorType, imageOp, coords,
767+
mlir::spirv::ImageOperandsAttr{}, ValueRange{});
768+
769+
// Note that because OpImageFetch returns a rank 4 vector we need to extract
770+
// the elements corresponding to the load which will since we only support the
771+
// R[16|32][f|i|ui] formats will always be the R(red) 0th vector element.
772+
auto compositeExtractOp =
773+
spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
774+
775+
rewriter.replaceOp(loadOp, compositeExtractOp);
776+
return success();
777+
}
778+
665779
LogicalResult
666780
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
667781
ConversionPatternRewriter &rewriter) const {
@@ -953,11 +1067,11 @@ LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
9531067
namespace mlir {
9541068
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
9551069
RewritePatternSet &patterns) {
956-
patterns
957-
.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
958-
DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
959-
MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
960-
CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
961-
typeConverter, patterns.getContext());
1070+
patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1071+
DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
1072+
IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
1073+
StoreOpPattern, ReinterpretCastPattern, CastPattern,
1074+
ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
1075+
patterns.getContext());
9621076
}
9631077
} // namespace mlir

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,45 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
608608
return wrapInStructAndGetPointer(arrayType, storageClass);
609609
}
610610

611+
static spirv::Dim convertRank(int64_t rank) {
612+
switch (rank) {
613+
case 1:
614+
return spirv::Dim::Dim1D;
615+
case 2:
616+
return spirv::Dim::Dim2D;
617+
case 3:
618+
return spirv::Dim::Dim3D;
619+
default:
620+
llvm_unreachable("Invalid memref rank!");
621+
}
622+
}
623+
624+
static spirv::ImageFormat getImageFormat(Type elementType) {
625+
return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType)
626+
.Case<Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; })
627+
.Case<Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; })
628+
.Case<IntegerType>([](IntegerType intType) {
629+
auto const isSigned = intType.isSigned() || intType.isSignless();
630+
#define BIT_WIDTH_CASE(BIT_WIDTH) \
631+
case BIT_WIDTH: \
632+
return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
633+
: spirv::ImageFormat::R##BIT_WIDTH##ui
634+
635+
switch (intType.getWidth()) {
636+
BIT_WIDTH_CASE(16);
637+
BIT_WIDTH_CASE(32);
638+
default:
639+
llvm_unreachable("Unhandled integer type!");
640+
}
641+
})
642+
.Default([](Type) {
643+
llvm_unreachable("Unhandled element type!");
644+
// We need to return something here to satisfy the type switch.
645+
return spirv::ImageFormat::R32f;
646+
});
647+
#undef BIT_WIDTH_CASE
648+
}
649+
611650
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
612651
const SPIRVConversionOptions &options,
613652
MemRefType type) {
@@ -623,6 +662,41 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
623662
}
624663
spirv::StorageClass storageClass = attr.getValue();
625664

665+
// Images are a special case since they are an opaque type from which elements
666+
// may be accessed via image specific ops or directly through a texture
667+
// pointer.
668+
if (storageClass == spirv::StorageClass::Image) {
669+
const int64_t rank = type.getRank();
670+
if (rank < 1 || rank > 3) {
671+
LLVM_DEBUG(llvm::dbgs()
672+
<< type << " illegal: cannot lower memref of rank " << rank
673+
<< " to a SPIR-V Image\n");
674+
return nullptr;
675+
}
676+
677+
// Note that we currently only support lowering to single element texels
678+
// e.g. R32f.
679+
auto elementType = type.getElementType();
680+
if (!isa<spirv::ScalarType>(elementType)) {
681+
LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
682+
<< elementType << " to a SPIR-V Image\n");
683+
return nullptr;
684+
}
685+
686+
// Currently every memref in the image storage class is converted to a
687+
// sampled image so we can hardcode the NeedSampler field. Future work
688+
// will generalize this to support regular non-sampled images.
689+
auto spvImageType = spirv::ImageType::get(
690+
elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
691+
spirv::ImageArrayedInfo::NonArrayed,
692+
spirv::ImageSamplingInfo::SingleSampled,
693+
spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
694+
auto spvSampledImageType = spirv::SampledImageType::get(spvImageType);
695+
auto imagePtrType = spirv::PointerType::get(
696+
spvSampledImageType, spirv::StorageClass::UniformConstant);
697+
return imagePtrType;
698+
}
699+
626700
if (isa<IntegerType>(type.getElementType())) {
627701
if (type.getElementTypeBitWidth() == 1)
628702
return convertBoolMemrefType(targetEnv, options, type, storageClass);

0 commit comments

Comments
 (0)