1414#include " mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1515#include " mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1616#include " mlir/IR/BuiltinTypes.h"
17+ #include " mlir/Support/LLVM.h"
1718#include " llvm/ADT/STLExtras.h"
1819#include " llvm/ADT/TypeSwitch.h"
20+ #include " llvm/Support/ErrorHandling.h"
1921
2022#include < cstdint>
2123
2224using namespace mlir ;
2325using namespace mlir ::spirv;
2426
27+ namespace {
28+ // Helper function to collect extensions implied by a type by visiting all its
29+ // subtypes. Maintains a set of `seen` types to avoid recursion in structs.
30+ //
31+ // Serves as the source-of-truth for type extension information. All extension
32+ // logic should be added to this class, while
33+ // `*Type::getExtensions` functions should not handle extension-related logic
34+ // directly and only invoke `TypeExtensionVisitor::add(Type *)`.
35+ class TypeExtensionVisitor {
36+ SPIRVType::ExtensionArrayRefVector &extensions;
37+ std::optional<StorageClass> storage;
38+ DenseSet<Type> seen;
39+
40+ public:
41+ TypeExtensionVisitor (SPIRVType::ExtensionArrayRefVector &extensions,
42+ std::optional<StorageClass> storage)
43+ : extensions(extensions), storage(storage) {}
44+
45+ // Main visitor entry point. Adds all extensions to the vector. Saves `type`
46+ // as seen and dispatches to the right concrete `.add` function.
47+ void add (SPIRVType type) {
48+ if (auto [_it, inserted] = seen.insert (type); !inserted)
49+ return ;
50+
51+ TypeSwitch<SPIRVType>(type)
52+ .Case <ScalarType, PointerType, CooperativeMatrixType, TensorArmType,
53+ VectorType, ArrayType, RuntimeArrayType, StructType, MatrixType,
54+ ImageType, SampledImageType>(
55+ [this ](auto concreteType) { add (concreteType); })
56+ .Default ([](SPIRVType) { llvm_unreachable (" Unhandled type" ); });
57+ }
58+
59+ // Convenience overloads for use in `T::getExtensions` functions.
60+ void add (Type type) { add (cast<SPIRVType>(type)); }
61+ void add (Type *type) { add (cast<SPIRVType>(*type)); }
62+
63+ // Types that add unique extensions.
64+ void add (ScalarType type);
65+ void add (PointerType type);
66+ void add (CooperativeMatrixType type);
67+ void add (TensorArmType type);
68+
69+ // Trivial passthrough without any new extensions.
70+ void add (VectorType type) { add (type.getElementType ()); }
71+ void add (ArrayType type) { add (type.getElementType ()); }
72+ void add (RuntimeArrayType type) { add (type.getElementType ()); }
73+ void add (StructType type) {
74+ for (Type elementType : type.getElementTypes ())
75+ add (elementType);
76+ }
77+ void add (MatrixType type) { add (type.getElementType ()); }
78+ void add (ImageType type) { add (type.getElementType ()); }
79+ void add (SampledImageType type) { add (type.getImageType ()); }
80+ };
81+
82+ } // namespace
83+
2584// ===----------------------------------------------------------------------===//
2685// ArrayType
2786// ===----------------------------------------------------------------------===//
@@ -67,7 +126,7 @@ unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
67126
68127void ArrayType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
69128 std::optional<StorageClass> storage) {
70- llvm::cast<SPIRVType>( getElementType ()). getExtensions ( extensions, storage);
129+ TypeExtensionVisitor{ extensions, storage}. add ( this );
71130}
72131
73132void ArrayType::getCapabilities (
@@ -143,22 +202,7 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
143202void CompositeType::getExtensions (
144203 SPIRVType::ExtensionArrayRefVector &extensions,
145204 std::optional<StorageClass> storage) {
146- TypeSwitch<Type>(*this )
147- .Case <ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
148- StructType>(
149- [&](auto type) { type.getExtensions (extensions, storage); })
150- .Case <VectorType>([&](VectorType type) {
151- return llvm::cast<ScalarType>(type.getElementType ())
152- .getExtensions (extensions, storage);
153- })
154- .Case <TensorArmType>([&](TensorArmType type) {
155- static constexpr Extension ext{Extension::SPV_ARM_tensors};
156- extensions.push_back (ext);
157- return llvm::cast<ScalarType>(type.getElementType ())
158- .getExtensions (extensions, storage);
159- })
160-
161- .Default ([](Type) { llvm_unreachable (" invalid composite type" ); });
205+ TypeExtensionVisitor{extensions, storage}.add (cast<SPIRVType>(*this ));
162206}
163207
164208void CompositeType::getCapabilities (
@@ -284,12 +328,16 @@ CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
284328 return getImpl ()->use ;
285329}
286330
331+ void TypeExtensionVisitor::add (CooperativeMatrixType type) {
332+ add (type.getElementType ());
333+ static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
334+ extensions.push_back (ext);
335+ }
336+
287337void CooperativeMatrixType::getExtensions (
288338 SPIRVType::ExtensionArrayRefVector &extensions,
289339 std::optional<StorageClass> storage) {
290- llvm::cast<SPIRVType>(getElementType ()).getExtensions (extensions, storage);
291- static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
292- extensions.push_back (exts);
340+ TypeExtensionVisitor{extensions, storage}.add (this );
293341}
294342
295343void CooperativeMatrixType::getCapabilities (
@@ -403,9 +451,9 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
403451
404452ImageFormat ImageType::getImageFormat () const { return getImpl ()->format ; }
405453
406- void ImageType::getExtensions (SPIRVType::ExtensionArrayRefVector &,
407- std::optional<StorageClass>) {
408- // Image types do not require extra extensions thus far.
454+ void ImageType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions ,
455+ std::optional<StorageClass> storage ) {
456+ TypeExtensionVisitor{ extensions, storage}. add ( this );
409457}
410458
411459void ImageType::getCapabilities (
@@ -454,17 +502,23 @@ StorageClass PointerType::getStorageClass() const {
454502 return getImpl ()->storageClass ;
455503}
456504
457- void PointerType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
458- std::optional<StorageClass> storage) {
505+ void TypeExtensionVisitor::add (PointerType type) {
459506 // Use this pointer type's storage class because this pointer indicates we are
460507 // using the pointee type in that specific storage class.
461- llvm::cast<SPIRVType>(getPointeeType ())
462- .getExtensions (extensions, getStorageClass ());
508+ std::optional<StorageClass> oldStorageClass = storage;
509+ storage = type.getStorageClass ();
510+ add (type.getPointeeType ());
511+ storage = oldStorageClass;
463512
464- if (auto scExts = spirv::getExtensions (getStorageClass ()))
513+ if (auto scExts = spirv::getExtensions (type. getStorageClass ()))
465514 extensions.push_back (*scExts);
466515}
467516
517+ void PointerType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
518+ std::optional<StorageClass> storage) {
519+ TypeExtensionVisitor{extensions, storage}.add (this );
520+ }
521+
468522void PointerType::getCapabilities (
469523 SPIRVType::CapabilityArrayRefVector &capabilities,
470524 std::optional<StorageClass> storage) {
@@ -516,7 +570,7 @@ unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
516570void RuntimeArrayType::getExtensions (
517571 SPIRVType::ExtensionArrayRefVector &extensions,
518572 std::optional<StorageClass> storage) {
519- llvm::cast<SPIRVType>( getElementType ()). getExtensions ( extensions, storage);
573+ TypeExtensionVisitor{ extensions, storage}. add ( this );
520574}
521575
522576void RuntimeArrayType::getCapabilities (
@@ -553,10 +607,9 @@ bool ScalarType::isValid(IntegerType type) {
553607 return llvm::is_contained ({1u , 8u , 16u , 32u , 64u }, type.getWidth ());
554608}
555609
556- void ScalarType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
557- std::optional<StorageClass> storage) {
558- if (isa<BFloat16Type>(*this )) {
559- static const Extension ext = Extension::SPV_KHR_bfloat16;
610+ void TypeExtensionVisitor::add (ScalarType type) {
611+ if (isa<BFloat16Type>(type)) {
612+ static constexpr auto ext = Extension::SPV_KHR_bfloat16;
560613 extensions.push_back (ext);
561614 }
562615
@@ -570,25 +623,28 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
570623 case StorageClass::PushConstant:
571624 case StorageClass::StorageBuffer:
572625 case StorageClass::Uniform:
573- if (getIntOrFloatBitWidth () == 8 ) {
574- static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
575- ArrayRef<Extension> ref (exts, std::size (exts));
576- extensions.push_back (ref);
626+ if (type.getIntOrFloatBitWidth () == 8 ) {
627+ static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
628+ extensions.push_back (ext);
577629 }
578630 [[fallthrough]];
579631 case StorageClass::Input:
580632 case StorageClass::Output:
581- if (getIntOrFloatBitWidth () == 16 ) {
582- static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
583- ArrayRef<Extension> ref (exts, std::size (exts));
584- extensions.push_back (ref);
633+ if (type.getIntOrFloatBitWidth () == 16 ) {
634+ static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
635+ extensions.push_back (ext);
585636 }
586637 break ;
587638 default :
588639 break ;
589640 }
590641}
591642
643+ void ScalarType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
644+ std::optional<StorageClass> storage) {
645+ TypeExtensionVisitor{extensions, storage}.add (this );
646+ }
647+
592648void ScalarType::getCapabilities (
593649 SPIRVType::CapabilityArrayRefVector &capabilities,
594650 std::optional<StorageClass> storage) {
@@ -722,23 +778,7 @@ bool SPIRVType::isScalarOrVector() {
722778
723779void SPIRVType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
724780 std::optional<StorageClass> storage) {
725- if (auto scalarType = llvm::dyn_cast<ScalarType>(*this )) {
726- scalarType.getExtensions (extensions, storage);
727- } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this )) {
728- compositeType.getExtensions (extensions, storage);
729- } else if (auto imageType = llvm::dyn_cast<ImageType>(*this )) {
730- imageType.getExtensions (extensions, storage);
731- } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this )) {
732- sampledImageType.getExtensions (extensions, storage);
733- } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this )) {
734- matrixType.getExtensions (extensions, storage);
735- } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this )) {
736- ptrType.getExtensions (extensions, storage);
737- } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this )) {
738- tensorArmType.getExtensions (extensions, storage);
739- } else {
740- llvm_unreachable (" invalid SPIR-V Type to getExtensions" );
741- }
781+ TypeExtensionVisitor{extensions, storage}.add (this );
742782}
743783
744784void SPIRVType::getCapabilities (
@@ -821,7 +861,7 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
821861void SampledImageType::getExtensions (
822862 SPIRVType::ExtensionArrayRefVector &extensions,
823863 std::optional<StorageClass> storage) {
824- llvm::cast<ImageType>( getImageType ()). getExtensions ( extensions, storage);
864+ TypeExtensionVisitor{ extensions, storage}. add ( this );
825865}
826866
827867void SampledImageType::getCapabilities (
@@ -1184,8 +1224,7 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
11841224
11851225void StructType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
11861226 std::optional<StorageClass> storage) {
1187- for (Type elementType : getElementTypes ())
1188- llvm::cast<SPIRVType>(elementType).getExtensions (extensions, storage);
1227+ TypeExtensionVisitor{extensions, storage}.add (this );
11891228}
11901229
11911230void StructType::getCapabilities (
@@ -1289,7 +1328,7 @@ unsigned MatrixType::getNumElements() const {
12891328
12901329void MatrixType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
12911330 std::optional<StorageClass> storage) {
1292- llvm::cast<SPIRVType>( getColumnType ()). getExtensions ( extensions, storage);
1331+ TypeExtensionVisitor{ extensions, storage}. add ( this );
12931332}
12941333
12951334void MatrixType::getCapabilities (
@@ -1347,13 +1386,16 @@ TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
13471386Type TensorArmType::getElementType () const { return getImpl ()->elementType ; }
13481387ArrayRef<int64_t > TensorArmType::getShape () const { return getImpl ()->shape ; }
13491388
1389+ void TypeExtensionVisitor::add (TensorArmType type) {
1390+ add (type.getElementType ());
1391+ static constexpr auto ext = Extension::SPV_ARM_tensors;
1392+ extensions.push_back (ext);
1393+ }
1394+
13501395void TensorArmType::getExtensions (
13511396 SPIRVType::ExtensionArrayRefVector &extensions,
13521397 std::optional<StorageClass> storage) {
1353-
1354- llvm::cast<SPIRVType>(getElementType ()).getExtensions (extensions, storage);
1355- static constexpr Extension ext{Extension::SPV_ARM_tensors};
1356- extensions.push_back (ext);
1398+ TypeExtensionVisitor{extensions, storage}.add (this );
13571399}
13581400
13591401void TensorArmType::getCapabilities (
0 commit comments