@@ -560,6 +560,45 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
560560 return wrapInStructAndGetPointer (arrayType, storageClass);
561561}
562562
563+ static spirv::Dim convertRank (int64_t rank) {
564+ switch (rank) {
565+ case 1 :
566+ return spirv::Dim::Dim1D;
567+ case 2 :
568+ return spirv::Dim::Dim2D;
569+ case 3 :
570+ return spirv::Dim::Dim3D;
571+ default :
572+ llvm_unreachable (" Invalid memref rank!" );
573+ }
574+ }
575+
576+ static spirv::ImageFormat getImageFormat (Type elementType) {
577+ return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType)
578+ .Case <Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; })
579+ .Case <Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; })
580+ .Case <IntegerType>([](IntegerType intType) {
581+ auto const isSigned = intType.isSigned () || intType.isSignless ();
582+ #define BIT_WIDTH_CASE (BIT_WIDTH ) \
583+ case BIT_WIDTH: \
584+ return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
585+ : spirv::ImageFormat::R##BIT_WIDTH##ui
586+
587+ switch (intType.getWidth ()) {
588+ BIT_WIDTH_CASE (16 );
589+ BIT_WIDTH_CASE (32 );
590+ default :
591+ llvm_unreachable (" Unhandled integer type!" );
592+ }
593+ })
594+ .Default ([](Type) {
595+ llvm_unreachable (" Unhandled element type!" );
596+ // We need to return something here to satisfy the type switch.
597+ return spirv::ImageFormat::R32f;
598+ });
599+ #undef BIT_WIDTH_CASE
600+ }
601+
563602static Type convertMemrefType (const spirv::TargetEnv &targetEnv,
564603 const SPIRVConversionOptions &options,
565604 MemRefType type) {
@@ -587,64 +626,23 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
587626 return nullptr ;
588627 }
589628
590- const auto dim = [rank]() {
591- switch (rank) {
592- case 1 :
593- return spirv::Dim::Dim1D;
594- case 2 :
595- return spirv::Dim::Dim2D;
596- case 3 :
597- return spirv::Dim::Dim3D;
598- default :
599- llvm_unreachable (" Invalid memref rank!" );
600- }
601- }();
602-
603629 // Note that we currently only support lowering to single element texels
604630 // e.g. R32f.
605631 auto elementType = type.getElementType ();
606- if (!elementType. isIntOrFloat ( )) {
632+ if (!isa<spirv::ScalarType>(elementType )) {
607633 LLVM_DEBUG (llvm::dbgs () << type << " illegal: cannot lower memref of "
608634 << elementType << " to a SPIR-V Image\n " );
609635 return nullptr ;
610636 }
611637
612- const auto imageFormat = [&elementType]() {
613- return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType)
614- .Case <Float16Type>(
615- [](Float16Type) { return spirv::ImageFormat::R16f; })
616- .Case <Float32Type>(
617- [](Float32Type) { return spirv::ImageFormat::R32f; })
618- .Case <IntegerType>([](IntegerType intType) {
619- auto const isSigned = intType.isSigned () || intType.isSignless ();
620- #define BIT_WIDTH_CASE (BIT_WIDTH ) \
621- case BIT_WIDTH: \
622- return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
623- : spirv::ImageFormat::R##BIT_WIDTH##ui
624-
625- switch (intType.getWidth ()) {
626- BIT_WIDTH_CASE (16 );
627- BIT_WIDTH_CASE (32 );
628- default :
629- llvm_unreachable (" Unhandled integer type!" );
630- }
631- })
632- .Default ([](Type) {
633- llvm_unreachable (" Unhandled element type!" );
634- // We need to return something here to satisfy the type switch.
635- return spirv::ImageFormat::R32f;
636- });
637- #undef BIT_WIDTH_CASE
638- }();
639-
640638 // Currently every memref in the image storage class is converted to a
641639 // sampled image so we can hardcode the NeedSampler field. Future work
642640 // will generalize this to support regular non-sampled images.
643641 auto spvImageType = spirv::ImageType::get (
644- elementType, dim , spirv::ImageDepthInfo::DepthUnknown,
642+ elementType, convertRank (rank) , spirv::ImageDepthInfo::DepthUnknown,
645643 spirv::ImageArrayedInfo::NonArrayed,
646644 spirv::ImageSamplingInfo::SingleSampled,
647- spirv::ImageSamplerUseInfo::NeedSampler, imageFormat );
645+ spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat (elementType) );
648646 auto spvSampledImageType = spirv::SampledImageType::get (spvImageType);
649647 auto imagePtrType = spirv::PointerType::get (
650648 spvSampledImageType, spirv::StorageClass::UniformConstant);
0 commit comments