Skip to content

Commit 9999ea9

Browse files
committed
[mlir][memref][spirv]: Address Feedback
* Outline lambdas to helper functions. * Removed incorrect `const` usage. * Generalize error conditions to SPIR-V scalar types. * Create local variable for debug location. Signed-off-by: Jack Frankland <[email protected]>
1 parent a0c68e4 commit 9999ea9

File tree

2 files changed

+56
-57
lines changed

2 files changed

+56
-57
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -724,10 +724,11 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
724724
// We currently only support lowering of scalar memref elements to texels in
725725
// the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
726726
// elements to texels in richer formats.
727-
if (!loadOp.getMemRefType().getElementType().isIntOrFloat())
727+
if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
728728
return rewriter.notifyMatchFailure(
729-
loadOp, "cannot lower memrefs who's element type is not int or float "
730-
"to SPIR-V images");
729+
loadOp,
730+
"cannot lower memrefs who's element type is not a SPIR-V scalar type"
731+
"to SPIR-V images");
731732

732733
// We currently only support sampled images since OpImageFetch does not work
733734
// for plain images and the OpImageRead instruction needs to be materialized
@@ -741,34 +742,34 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
741742
"convert to SPIR-V sampled images");
742743

743744
// Materialize the lowering.
744-
auto imageLoadOp = spirv::LoadOp::create(rewriter, loadOp->getLoc(), loadPtr,
745-
memoryAccess, alignment);
745+
Location loc = loadOp->getLoc();
746+
auto imageLoadOp =
747+
spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
746748
// Extract the image from the sampled image.
747-
auto imageOp =
748-
spirv::ImageOp::create(rewriter, loadOp->getLoc(), imageLoadOp);
749+
auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
749750

750751
// Build a vector of coordinates or just a scalar index if we have a 1D image.
751752
Value coords;
752753
if (memrefType.getRank() != 1) {
753-
const auto coordVectorType = VectorType::get(
754-
{loadOp.getMemRefType().getRank()}, adaptor.getIndices().getType()[0]);
755-
coords = spirv::CompositeConstructOp::create(
756-
rewriter, loadOp->getLoc(), coordVectorType, adaptor.getIndices());
754+
auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
755+
adaptor.getIndices().getType()[0]);
756+
coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
757+
adaptor.getIndices());
757758
} else {
758759
coords = adaptor.getIndices()[0];
759760
}
760761

761762
// Fetch the value out of the image.
762763
auto resultVectorType = VectorType::get({4}, loadOp.getType());
763764
auto fetchOp = spirv::ImageFetchOp::create(
764-
rewriter, loadOp->getLoc(), resultVectorType, imageOp, coords,
765+
rewriter, loc, resultVectorType, imageOp, coords,
765766
mlir::spirv::ImageOperandsAttr{}, ValueRange{});
766767

767768
// Note that because OpImageFetch returns a rank 4 vector we need to extract
768769
// the elements corresponding to the load which will since we only support the
769770
// R[16|32][f|i|ui] formats will always be the R(red) 0th vector element.
770771
auto compositeExtractOp =
771-
spirv::CompositeExtractOp::create(rewriter, loadOp->getLoc(), fetchOp, 0);
772+
spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
772773

773774
rewriter.replaceOp(loadOp, compositeExtractOp);
774775
return success();

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,45 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
560560
return wrapInStructAndGetPointer(arrayType, storageClass);
561561
}
562562

563+
static spirv::Dim convertRank(int64_t rank) {
564+
switch (rank) {
565+
case 1:
566+
return spirv::Dim::Dim1D;
567+
case 2:
568+
return spirv::Dim::Dim2D;
569+
case 3:
570+
return spirv::Dim::Dim3D;
571+
default:
572+
llvm_unreachable("Invalid memref rank!");
573+
}
574+
}
575+
576+
static spirv::ImageFormat getImageFormat(Type elementType) {
577+
return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType)
578+
.Case<Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; })
579+
.Case<Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; })
580+
.Case<IntegerType>([](IntegerType intType) {
581+
auto const isSigned = intType.isSigned() || intType.isSignless();
582+
#define BIT_WIDTH_CASE(BIT_WIDTH) \
583+
case BIT_WIDTH: \
584+
return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
585+
: spirv::ImageFormat::R##BIT_WIDTH##ui
586+
587+
switch (intType.getWidth()) {
588+
BIT_WIDTH_CASE(16);
589+
BIT_WIDTH_CASE(32);
590+
default:
591+
llvm_unreachable("Unhandled integer type!");
592+
}
593+
})
594+
.Default([](Type) {
595+
llvm_unreachable("Unhandled element type!");
596+
// We need to return something here to satisfy the type switch.
597+
return spirv::ImageFormat::R32f;
598+
});
599+
#undef BIT_WIDTH_CASE
600+
}
601+
563602
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
564603
const SPIRVConversionOptions &options,
565604
MemRefType type) {
@@ -587,64 +626,23 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
587626
return nullptr;
588627
}
589628

590-
const auto dim = [rank]() {
591-
switch (rank) {
592-
case 1:
593-
return spirv::Dim::Dim1D;
594-
case 2:
595-
return spirv::Dim::Dim2D;
596-
case 3:
597-
return spirv::Dim::Dim3D;
598-
default:
599-
llvm_unreachable("Invalid memref rank!");
600-
}
601-
}();
602-
603629
// Note that we currently only support lowering to single element texels
604630
// e.g. R32f.
605631
auto elementType = type.getElementType();
606-
if (!elementType.isIntOrFloat()) {
632+
if (!isa<spirv::ScalarType>(elementType)) {
607633
LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
608634
<< elementType << " to a SPIR-V Image\n");
609635
return nullptr;
610636
}
611637

612-
const auto imageFormat = [&elementType]() {
613-
return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType)
614-
.Case<Float16Type>(
615-
[](Float16Type) { return spirv::ImageFormat::R16f; })
616-
.Case<Float32Type>(
617-
[](Float32Type) { return spirv::ImageFormat::R32f; })
618-
.Case<IntegerType>([](IntegerType intType) {
619-
auto const isSigned = intType.isSigned() || intType.isSignless();
620-
#define BIT_WIDTH_CASE(BIT_WIDTH) \
621-
case BIT_WIDTH: \
622-
return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
623-
: spirv::ImageFormat::R##BIT_WIDTH##ui
624-
625-
switch (intType.getWidth()) {
626-
BIT_WIDTH_CASE(16);
627-
BIT_WIDTH_CASE(32);
628-
default:
629-
llvm_unreachable("Unhandled integer type!");
630-
}
631-
})
632-
.Default([](Type) {
633-
llvm_unreachable("Unhandled element type!");
634-
// We need to return something here to satisfy the type switch.
635-
return spirv::ImageFormat::R32f;
636-
});
637-
#undef BIT_WIDTH_CASE
638-
}();
639-
640638
// Currently every memref in the image storage class is converted to a
641639
// sampled image so we can hardcode the NeedSampler field. Future work
642640
// will generalize this to support regular non-sampled images.
643641
auto spvImageType = spirv::ImageType::get(
644-
elementType, dim, spirv::ImageDepthInfo::DepthUnknown,
642+
elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
645643
spirv::ImageArrayedInfo::NonArrayed,
646644
spirv::ImageSamplingInfo::SingleSampled,
647-
spirv::ImageSamplerUseInfo::NeedSampler, imageFormat);
645+
spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
648646
auto spvSampledImageType = spirv::SampledImageType::get(spvImageType);
649647
auto imagePtrType = spirv::PointerType::get(
650648
spvSampledImageType, spirv::StorageClass::UniformConstant);

0 commit comments

Comments
 (0)