1414#include " mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1515#include " mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1616#include " mlir/IR/BuiltinTypes.h"
17+ #include " mlir/Support/LLVM.h"
1718#include " llvm/ADT/STLExtras.h"
1819#include " llvm/ADT/TypeSwitch.h"
20+ #include " llvm/Support/ErrorHandling.h"
1921
2022#include < cstdint>
2123
2224using namespace mlir ;
2325using namespace mlir ::spirv;
2426
27+ namespace {
28+ // Helper function to collect extensions implied by a type by visiting all its
29+ // subtypes. Maintains a set of `seen` types to avoid recursion in structs.
30+ //
31+ // Serves as the source-of-truth for type extension information. All extension
32+ // logic should be added to this class, while the
33+ // `SPIRVType::getExtensions` function should not handle extension-related logic
34+ // directly and only invoke `TypeExtensionVisitor::add(Type *)`.
35+ class TypeExtensionVisitor {
36+ public:
37+ TypeExtensionVisitor (SPIRVType::ExtensionArrayRefVector &extensions,
38+ std::optional<StorageClass> storage)
39+ : extensions(extensions), storage(storage) {}
40+
41+ // Main visitor entry point. Adds all extensions to the vector. Saves `type`
42+ // as seen and dispatches to the right concrete `.add` function.
43+ void add (SPIRVType type) {
44+ if (auto [_it, inserted] = seen.insert ({type, storage}); !inserted)
45+ return ;
46+
47+ TypeSwitch<SPIRVType>(type)
48+ .Case <ScalarType, PointerType, CooperativeMatrixType, TensorArmType>(
49+ [this ](auto concreteType) { addConcrete (concreteType); })
50+ .Case <VectorType, ArrayType, RuntimeArrayType, MatrixType, ImageType>(
51+ [this ](auto concreteType) { add (concreteType.getElementType ()); })
52+ .Case <StructType>([this ](StructType concreteType) {
53+ for (Type elementType : concreteType.getElementTypes ())
54+ add (elementType);
55+ })
56+ .Case <SampledImageType>([this ](SampledImageType concreteType) {
57+ add (concreteType.getImageType ());
58+ })
59+ .Default ([](SPIRVType) { llvm_unreachable (" Unhandled type" ); });
60+ }
61+
62+ void add (Type type) { add (cast<SPIRVType>(type)); }
63+
64+ private:
65+ // Types that add unique extensions.
66+ void addConcrete (ScalarType type);
67+ void addConcrete (PointerType type);
68+ void addConcrete (CooperativeMatrixType type);
69+ void addConcrete (TensorArmType type);
70+
71+ SPIRVType::ExtensionArrayRefVector &extensions;
72+ std::optional<StorageClass> storage;
73+ llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
74+ };
75+
76+ } // namespace
77+
2578// ===----------------------------------------------------------------------===//
2679// ArrayType
2780// ===----------------------------------------------------------------------===//
@@ -65,11 +118,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
65118
66119unsigned ArrayType::getArrayStride () const { return getImpl ()->stride ; }
67120
68- void ArrayType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
69- std::optional<StorageClass> storage) {
70- llvm::cast<SPIRVType>(getElementType ()).getExtensions (extensions, storage);
71- }
72-
73121void ArrayType::getCapabilities (
74122 SPIRVType::CapabilityArrayRefVector &capabilities,
75123 std::optional<StorageClass> storage) {
@@ -140,27 +188,6 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
140188 return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this );
141189}
142190
143- void CompositeType::getExtensions (
144- SPIRVType::ExtensionArrayRefVector &extensions,
145- std::optional<StorageClass> storage) {
146- TypeSwitch<Type>(*this )
147- .Case <ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
148- StructType>(
149- [&](auto type) { type.getExtensions (extensions, storage); })
150- .Case <VectorType>([&](VectorType type) {
151- return llvm::cast<ScalarType>(type.getElementType ())
152- .getExtensions (extensions, storage);
153- })
154- .Case <TensorArmType>([&](TensorArmType type) {
155- static constexpr Extension ext{Extension::SPV_ARM_tensors};
156- extensions.push_back (ext);
157- return llvm::cast<ScalarType>(type.getElementType ())
158- .getExtensions (extensions, storage);
159- })
160-
161- .Default ([](Type) { llvm_unreachable (" invalid composite type" ); });
162- }
163-
164191void CompositeType::getCapabilities (
165192 SPIRVType::CapabilityArrayRefVector &capabilities,
166193 std::optional<StorageClass> storage) {
@@ -284,12 +311,10 @@ CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
284311 return getImpl ()->use ;
285312}
286313
287- void CooperativeMatrixType::getExtensions (
288- SPIRVType::ExtensionArrayRefVector &extensions,
289- std::optional<StorageClass> storage) {
290- llvm::cast<SPIRVType>(getElementType ()).getExtensions (extensions, storage);
291- static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
292- extensions.push_back (exts);
314+ void TypeExtensionVisitor::addConcrete (CooperativeMatrixType type) {
315+ add (type.getElementType ());
316+ static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
317+ extensions.push_back (ext);
293318}
294319
295320void CooperativeMatrixType::getCapabilities (
@@ -403,11 +428,6 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
403428
404429ImageFormat ImageType::getImageFormat () const { return getImpl ()->format ; }
405430
406- void ImageType::getExtensions (SPIRVType::ExtensionArrayRefVector &,
407- std::optional<StorageClass>) {
408- // Image types do not require extra extensions thus far.
409- }
410-
411431void ImageType::getCapabilities (
412432 SPIRVType::CapabilityArrayRefVector &capabilities,
413433 std::optional<StorageClass>) {
@@ -454,14 +474,15 @@ StorageClass PointerType::getStorageClass() const {
454474 return getImpl ()->storageClass ;
455475}
456476
457- void PointerType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
458- std::optional<StorageClass> storage) {
477+ void TypeExtensionVisitor::addConcrete (PointerType type) {
459478 // Use this pointer type's storage class because this pointer indicates we are
460479 // using the pointee type in that specific storage class.
461- llvm::cast<SPIRVType>(getPointeeType ())
462- .getExtensions (extensions, getStorageClass ());
480+ std::optional<StorageClass> oldStorageClass = storage;
481+ storage = type.getStorageClass ();
482+ add (type.getPointeeType ());
483+ storage = oldStorageClass;
463484
464- if (auto scExts = spirv::getExtensions (getStorageClass ()))
485+ if (auto scExts = spirv::getExtensions (type. getStorageClass ()))
465486 extensions.push_back (*scExts);
466487}
467488
@@ -513,12 +534,6 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
513534
514535unsigned RuntimeArrayType::getArrayStride () const { return getImpl ()->stride ; }
515536
516- void RuntimeArrayType::getExtensions (
517- SPIRVType::ExtensionArrayRefVector &extensions,
518- std::optional<StorageClass> storage) {
519- llvm::cast<SPIRVType>(getElementType ()).getExtensions (extensions, storage);
520- }
521-
522537void RuntimeArrayType::getCapabilities (
523538 SPIRVType::CapabilityArrayRefVector &capabilities,
524539 std::optional<StorageClass> storage) {
@@ -553,10 +568,9 @@ bool ScalarType::isValid(IntegerType type) {
553568 return llvm::is_contained ({1u , 8u , 16u , 32u , 64u }, type.getWidth ());
554569}
555570
556- void ScalarType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
557- std::optional<StorageClass> storage) {
558- if (isa<BFloat16Type>(*this )) {
559- static const Extension ext = Extension::SPV_KHR_bfloat16;
571+ void TypeExtensionVisitor::addConcrete (ScalarType type) {
572+ if (isa<BFloat16Type>(type)) {
573+ static constexpr auto ext = Extension::SPV_KHR_bfloat16;
560574 extensions.push_back (ext);
561575 }
562576
@@ -570,18 +584,16 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
570584 case StorageClass::PushConstant:
571585 case StorageClass::StorageBuffer:
572586 case StorageClass::Uniform:
573- if (getIntOrFloatBitWidth () == 8 ) {
574- static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
575- ArrayRef<Extension> ref (exts, std::size (exts));
576- extensions.push_back (ref);
587+ if (type.getIntOrFloatBitWidth () == 8 ) {
588+ static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
589+ extensions.push_back (ext);
577590 }
578591 [[fallthrough]];
579592 case StorageClass::Input:
580593 case StorageClass::Output:
581- if (getIntOrFloatBitWidth () == 16 ) {
582- static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
583- ArrayRef<Extension> ref (exts, std::size (exts));
584- extensions.push_back (ref);
594+ if (type.getIntOrFloatBitWidth () == 16 ) {
595+ static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
596+ extensions.push_back (ext);
585597 }
586598 break ;
587599 default :
@@ -722,23 +734,7 @@ bool SPIRVType::isScalarOrVector() {
722734
723735void SPIRVType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
724736 std::optional<StorageClass> storage) {
725- if (auto scalarType = llvm::dyn_cast<ScalarType>(*this )) {
726- scalarType.getExtensions (extensions, storage);
727- } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this )) {
728- compositeType.getExtensions (extensions, storage);
729- } else if (auto imageType = llvm::dyn_cast<ImageType>(*this )) {
730- imageType.getExtensions (extensions, storage);
731- } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this )) {
732- sampledImageType.getExtensions (extensions, storage);
733- } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this )) {
734- matrixType.getExtensions (extensions, storage);
735- } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this )) {
736- ptrType.getExtensions (extensions, storage);
737- } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this )) {
738- tensorArmType.getExtensions (extensions, storage);
739- } else {
740- llvm_unreachable (" invalid SPIR-V Type to getExtensions" );
741- }
737+ TypeExtensionVisitor{extensions, storage}.add (*this );
742738}
743739
744740void SPIRVType::getCapabilities (
@@ -818,12 +814,6 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
818814 return success ();
819815}
820816
821- void SampledImageType::getExtensions (
822- SPIRVType::ExtensionArrayRefVector &extensions,
823- std::optional<StorageClass> storage) {
824- llvm::cast<ImageType>(getImageType ()).getExtensions (extensions, storage);
825- }
826-
827817void SampledImageType::getCapabilities (
828818 SPIRVType::CapabilityArrayRefVector &capabilities,
829819 std::optional<StorageClass> storage) {
@@ -1182,12 +1172,6 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
11821172 structDecorations);
11831173}
11841174
1185- void StructType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
1186- std::optional<StorageClass> storage) {
1187- for (Type elementType : getElementTypes ())
1188- llvm::cast<SPIRVType>(elementType).getExtensions (extensions, storage);
1189- }
1190-
11911175void StructType::getCapabilities (
11921176 SPIRVType::CapabilityArrayRefVector &capabilities,
11931177 std::optional<StorageClass> storage) {
@@ -1287,11 +1271,6 @@ unsigned MatrixType::getNumElements() const {
12871271 return (getImpl ()->columnCount ) * getNumRows ();
12881272}
12891273
1290- void MatrixType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
1291- std::optional<StorageClass> storage) {
1292- llvm::cast<SPIRVType>(getColumnType ()).getExtensions (extensions, storage);
1293- }
1294-
12951274void MatrixType::getCapabilities (
12961275 SPIRVType::CapabilityArrayRefVector &capabilities,
12971276 std::optional<StorageClass> storage) {
@@ -1347,12 +1326,9 @@ TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
13471326Type TensorArmType::getElementType () const { return getImpl ()->elementType ; }
13481327ArrayRef<int64_t > TensorArmType::getShape () const { return getImpl ()->shape ; }
13491328
1350- void TensorArmType::getExtensions (
1351- SPIRVType::ExtensionArrayRefVector &extensions,
1352- std::optional<StorageClass> storage) {
1353-
1354- llvm::cast<SPIRVType>(getElementType ()).getExtensions (extensions, storage);
1355- static constexpr Extension ext{Extension::SPV_ARM_tensors};
1329+ void TypeExtensionVisitor::addConcrete (TensorArmType type) {
1330+ add (type.getElementType ());
1331+ static constexpr auto ext = Extension::SPV_ARM_tensors;
13561332 extensions.push_back (ext);
13571333}
13581334
0 commit comments