Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 100 additions & 6 deletions mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,16 @@ class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Image + spirv.ImageFetch
class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
Expand Down Expand Up @@ -527,6 +537,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
if (!memrefType.getElementType().isSignlessInteger())
return failure();

if (memrefType.getMemorySpace() ==
spirv::StorageClassAttr::get(rewriter.getContext(),
spirv::StorageClass::Image))
return failure();

const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Value accessChain =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
Expand Down Expand Up @@ -643,6 +658,12 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
if (memrefType.getElementType().isSignlessInteger())
return failure();

if (memrefType.getMemorySpace() ==
spirv::StorageClassAttr::get(rewriter.getContext(),
spirv::StorageClass::Image))
return failure();

Value loadPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
adaptor.getIndices(), loadOp.getLoc(), rewriter);
Expand All @@ -661,6 +682,79 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
return success();
}

LogicalResult
ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
if (memrefType.getMemorySpace() !=
spirv::StorageClassAttr::get(rewriter.getContext(),
spirv::StorageClass::Image))
return failure();

auto loadPtr = adaptor.getMemref();
auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
if (failed(memoryRequirements))
return rewriter.notifyMatchFailure(
loadOp, "failed to determine memory requirements");

const auto [memoryAccess, alignment] = *memoryRequirements;

if (!loadOp.getMemRefType().hasRank())
return rewriter.notifyMatchFailure(
loadOp, "cannot lower unranked memrefs to SPIR-V images");

// We currently only support lowering of scalar memref elements to texels in
// the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
// elements to texels in richer formats.
if (!loadOp.getMemRefType().getElementType().isIntOrFloat())
return rewriter.notifyMatchFailure(
loadOp, "cannot lower memrefs who's element type is not int or float "
"to SPIR-V images");

// We currently only support sampled images since OpImageFetch does not work
// for plain images and the OpImageRead instruction needs to be materialized
// instead or texels need to be accessed via atomics through a texel pointer.
// Future work will generalize support to plain images.
if (auto convertedPointeeType = cast<spirv::PointerType>(
getTypeConverter()->convertType(loadOp.getMemRefType()));
!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
return rewriter.notifyMatchFailure(loadOp,
"cannot lower memrefs which do not "
"convert to SPIR-V sampled images");

// Materialize the lowering.
auto imageLoadOp = rewriter.create<spirv::LoadOp>(loadOp->getLoc(), loadPtr,
memoryAccess, alignment);
// Extract the image from the sampled image.
auto imageOp = rewriter.create<spirv::ImageOp>(loadOp->getLoc(), imageLoadOp);

// Build a vector of coordinates or just a scalar index if we have a 1D image.
Value coords;
if (memrefType.getRank() != 1) {
const auto coordVectorType = VectorType::get(
{loadOp.getMemRefType().getRank()}, adaptor.getIndices().getType()[0]);
coords = rewriter.create<spirv::CompositeConstructOp>(
loadOp->getLoc(), coordVectorType, adaptor.getIndices());
} else {
coords = adaptor.getIndices()[0];
}

// Fetch the value out of the image.
auto resultVectorType = VectorType::get({4}, loadOp.getType());
auto fetchOp = rewriter.create<spirv::ImageFetchOp>(
loadOp->getLoc(), resultVectorType, imageOp, coords,
mlir::spirv::ImageOperandsAttr{}, ValueRange{});

// Note that because OpImageFetch returns a rank 4 vector we need to extract
// the elements corresponding to the load which will since we only support the
// R[16|32][f|i|ui] formats will always be the R(red) 0th vector element.
auto compositeExtractOp =
rewriter.create<spirv::CompositeExtractOp>(loadOp->getLoc(), fetchOp, 0);

rewriter.replaceOp(loadOp, compositeExtractOp);
return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -952,11 +1046,11 @@ LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns
.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
typeConverter, patterns.getContext());
patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
StoreOpPattern, ReinterpretCastPattern, CastPattern,
ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
patterns.getContext());
}
} // namespace mlir
77 changes: 77 additions & 0 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,83 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
}
spirv::StorageClass storageClass = attr.getValue();

// Images are a special case since they are an opaque type from which elements
// may be accessed via image specific ops or directly through a texture
// pointer.
if (storageClass == spirv::StorageClass::Image) {
const int64_t rank = type.getRank();
if (rank < 1 || rank > 3) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot lower memref of rank " << rank
<< " to a SPIR-V Image\n");
return nullptr;
}

const auto dim = [rank]() {
#define DIM_CASE(DIM) \
case DIM: \
return spirv::Dim::Dim##DIM##D
switch (rank) {
DIM_CASE(1);
DIM_CASE(2);
DIM_CASE(3);
default:
llvm_unreachable("Invalid memref rank!");
}
#undef DIM_CASE
}();

// Note that we currently only support lowering to single element texels
// e.g. R32f.
auto elementType = type.getElementType();
if (!elementType.isIntOrFloat()) {
LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
<< elementType << " to a SPIR-V Image\n");
return nullptr;
}

const auto imageFormat = [&elementType]() {
return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType)
.Case<Float16Type>(
[](Float16Type) { return spirv::ImageFormat::R16f; })
.Case<Float32Type>(
[](Float32Type) { return spirv::ImageFormat::R32f; })
.Case<IntegerType>([](IntegerType intType) {
auto const isSigned = intType.isSigned() || intType.isSignless();
#define BIT_WIDTH_CASE(BIT_WIDTH) \
case BIT_WIDTH: \
return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
: spirv::ImageFormat::R##BIT_WIDTH##ui

switch (intType.getWidth()) {
BIT_WIDTH_CASE(16);
BIT_WIDTH_CASE(32);
default:
llvm_unreachable("Unhandled integer type!");
}
})
.Default([](Type) {
llvm_unreachable("Unhandled element type!");
// We need to return something here to satisfy the type switch.
return spirv::ImageFormat::R32f;
});
#undef BIT_WIDTH_CASE
}();

// Currently every memref in the image storage class is converted to a
// sampled image so we can hardcode the NeedSampler field. Future work
// will generalize this to support regular non-sampled images.
auto spvImageType = spirv::ImageType::get(
elementType, dim, spirv::ImageDepthInfo::DepthUnknown,
spirv::ImageArrayedInfo::NonArrayed,
spirv::ImageSamplingInfo::SingleSampled,
spirv::ImageSamplerUseInfo::NeedSampler, imageFormat);
auto spvSampledImageType = spirv::SampledImageType::get(spvImageType);
auto imagePtrType = spirv::PointerType::get(
spvSampledImageType, spirv::StorageClass::UniformConstant);
return imagePtrType;
}

if (isa<IntegerType>(type.getElementType())) {
if (type.getElementTypeBitWidth() == 1)
return convertBoolMemrefType(targetEnv, options, type, storageClass);
Expand Down
Loading