14
14
#include " mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15
15
#include " mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
16
16
#include " mlir/IR/BuiltinTypes.h"
17
+ #include " mlir/Support/LLVM.h"
17
18
#include " llvm/ADT/STLExtras.h"
18
19
#include " llvm/ADT/TypeSwitch.h"
20
+ #include " llvm/Support/ErrorHandling.h"
19
21
20
22
#include < cstdint>
21
23
22
24
using namespace mlir ;
23
25
using namespace mlir ::spirv;
24
26
27
+ namespace {
28
+ // Helper function to collect extensions implied by a type by visiting all its
29
+ // subtypes. Maintains a set of `seen` types to avoid recursion in structs.
30
+ //
31
+ // Serves as the source-of-truth for type extension information. All extension
32
+ // logic should be added to this class, while the
33
+ // `SPIRVType::getExtensions` function should not handle extension-related logic
34
+ // directly and only invoke `TypeExtensionVisitor::add(Type *)`.
35
+ class TypeExtensionVisitor {
36
+ public:
37
+ TypeExtensionVisitor (SPIRVType::ExtensionArrayRefVector &extensions,
38
+ std::optional<StorageClass> storage)
39
+ : extensions(extensions), storage(storage) {}
40
+
41
+ // Main visitor entry point. Adds all extensions to the vector. Saves `type`
42
+ // as seen and dispatches to the right concrete `.add` function.
43
+ void add (SPIRVType type) {
44
+ if (auto [_it, inserted] = seen.insert ({type, storage}); !inserted)
45
+ return ;
46
+
47
+ TypeSwitch<SPIRVType>(type)
48
+ .Case <ScalarType, PointerType, CooperativeMatrixType, TensorArmType>(
49
+ [this ](auto concreteType) { addConcrete (concreteType); })
50
+ .Case <VectorType, ArrayType, RuntimeArrayType, MatrixType, ImageType>(
51
+ [this ](auto concreteType) { add (concreteType.getElementType ()); })
52
+ .Case <StructType>([this ](StructType concreteType) {
53
+ for (Type elementType : concreteType.getElementTypes ())
54
+ add (elementType);
55
+ })
56
+ .Case <SampledImageType>([this ](SampledImageType concreteType) {
57
+ add (concreteType.getImageType ());
58
+ })
59
+ .Default ([](SPIRVType) { llvm_unreachable (" Unhandled type" ); });
60
+ }
61
+
62
+ void add (Type type) { add (cast<SPIRVType>(type)); }
63
+
64
+ private:
65
+ // Types that add unique extensions.
66
+ void addConcrete (ScalarType type);
67
+ void addConcrete (PointerType type);
68
+ void addConcrete (CooperativeMatrixType type);
69
+ void addConcrete (TensorArmType type);
70
+
71
+ SPIRVType::ExtensionArrayRefVector &extensions;
72
+ std::optional<StorageClass> storage;
73
+ llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
74
+ };
75
+
76
+ } // namespace
77
+
25
78
// ===----------------------------------------------------------------------===//
26
79
// ArrayType
27
80
// ===----------------------------------------------------------------------===//
@@ -65,11 +118,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
65
118
66
119
unsigned ArrayType::getArrayStride () const { return getImpl ()->stride ; }
67
120
68
- void ArrayType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
69
- std::optional<StorageClass> storage) {
70
- llvm::cast<SPIRVType>(getElementType ()).getExtensions (extensions, storage);
71
- }
72
-
73
121
void ArrayType::getCapabilities (
74
122
SPIRVType::CapabilityArrayRefVector &capabilities,
75
123
std::optional<StorageClass> storage) {
@@ -140,27 +188,6 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
140
188
return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this );
141
189
}
142
190
143
- void CompositeType::getExtensions (
144
- SPIRVType::ExtensionArrayRefVector &extensions,
145
- std::optional<StorageClass> storage) {
146
- TypeSwitch<Type>(*this )
147
- .Case <ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
148
- StructType>(
149
- [&](auto type) { type.getExtensions (extensions, storage); })
150
- .Case <VectorType>([&](VectorType type) {
151
- return llvm::cast<ScalarType>(type.getElementType ())
152
- .getExtensions (extensions, storage);
153
- })
154
- .Case <TensorArmType>([&](TensorArmType type) {
155
- static constexpr Extension ext{Extension::SPV_ARM_tensors};
156
- extensions.push_back (ext);
157
- return llvm::cast<ScalarType>(type.getElementType ())
158
- .getExtensions (extensions, storage);
159
- })
160
-
161
- .Default ([](Type) { llvm_unreachable (" invalid composite type" ); });
162
- }
163
-
164
191
void CompositeType::getCapabilities (
165
192
SPIRVType::CapabilityArrayRefVector &capabilities,
166
193
std::optional<StorageClass> storage) {
@@ -284,12 +311,10 @@ CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
284
311
return getImpl ()->use ;
285
312
}
286
313
287
- void CooperativeMatrixType::getExtensions (
288
- SPIRVType::ExtensionArrayRefVector &extensions,
289
- std::optional<StorageClass> storage) {
290
- llvm::cast<SPIRVType>(getElementType ()).getExtensions (extensions, storage);
291
- static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
292
- extensions.push_back (exts);
314
+ void TypeExtensionVisitor::addConcrete (CooperativeMatrixType type) {
315
+ add (type.getElementType ());
316
+ static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
317
+ extensions.push_back (ext);
293
318
}
294
319
295
320
void CooperativeMatrixType::getCapabilities (
@@ -403,11 +428,6 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
403
428
404
429
ImageFormat ImageType::getImageFormat () const { return getImpl ()->format ; }
405
430
406
- void ImageType::getExtensions (SPIRVType::ExtensionArrayRefVector &,
407
- std::optional<StorageClass>) {
408
- // Image types do not require extra extensions thus far.
409
- }
410
-
411
431
void ImageType::getCapabilities (
412
432
SPIRVType::CapabilityArrayRefVector &capabilities,
413
433
std::optional<StorageClass>) {
@@ -454,14 +474,15 @@ StorageClass PointerType::getStorageClass() const {
454
474
return getImpl ()->storageClass ;
455
475
}
456
476
457
- void PointerType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
458
- std::optional<StorageClass> storage) {
477
+ void TypeExtensionVisitor::addConcrete (PointerType type) {
459
478
// Use this pointer type's storage class because this pointer indicates we are
460
479
// using the pointee type in that specific storage class.
461
- llvm::cast<SPIRVType>(getPointeeType ())
462
- .getExtensions (extensions, getStorageClass ());
480
+ std::optional<StorageClass> oldStorageClass = storage;
481
+ storage = type.getStorageClass ();
482
+ add (type.getPointeeType ());
483
+ storage = oldStorageClass;
463
484
464
- if (auto scExts = spirv::getExtensions (getStorageClass ()))
485
+ if (auto scExts = spirv::getExtensions (type. getStorageClass ()))
465
486
extensions.push_back (*scExts);
466
487
}
467
488
@@ -513,12 +534,6 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
513
534
514
535
unsigned RuntimeArrayType::getArrayStride () const { return getImpl ()->stride ; }
515
536
516
- void RuntimeArrayType::getExtensions (
517
- SPIRVType::ExtensionArrayRefVector &extensions,
518
- std::optional<StorageClass> storage) {
519
- llvm::cast<SPIRVType>(getElementType ()).getExtensions (extensions, storage);
520
- }
521
-
522
537
void RuntimeArrayType::getCapabilities (
523
538
SPIRVType::CapabilityArrayRefVector &capabilities,
524
539
std::optional<StorageClass> storage) {
@@ -553,10 +568,9 @@ bool ScalarType::isValid(IntegerType type) {
553
568
return llvm::is_contained ({1u , 8u , 16u , 32u , 64u }, type.getWidth ());
554
569
}
555
570
556
- void ScalarType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
557
- std::optional<StorageClass> storage) {
558
- if (isa<BFloat16Type>(*this )) {
559
- static const Extension ext = Extension::SPV_KHR_bfloat16;
571
+ void TypeExtensionVisitor::addConcrete (ScalarType type) {
572
+ if (isa<BFloat16Type>(type)) {
573
+ static constexpr auto ext = Extension::SPV_KHR_bfloat16;
560
574
extensions.push_back (ext);
561
575
}
562
576
@@ -570,18 +584,16 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
570
584
case StorageClass::PushConstant:
571
585
case StorageClass::StorageBuffer:
572
586
case StorageClass::Uniform:
573
- if (getIntOrFloatBitWidth () == 8 ) {
574
- static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
575
- ArrayRef<Extension> ref (exts, std::size (exts));
576
- extensions.push_back (ref);
587
+ if (type.getIntOrFloatBitWidth () == 8 ) {
588
+ static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
589
+ extensions.push_back (ext);
577
590
}
578
591
[[fallthrough]];
579
592
case StorageClass::Input:
580
593
case StorageClass::Output:
581
- if (getIntOrFloatBitWidth () == 16 ) {
582
- static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
583
- ArrayRef<Extension> ref (exts, std::size (exts));
584
- extensions.push_back (ref);
594
+ if (type.getIntOrFloatBitWidth () == 16 ) {
595
+ static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
596
+ extensions.push_back (ext);
585
597
}
586
598
break ;
587
599
default :
@@ -722,23 +734,7 @@ bool SPIRVType::isScalarOrVector() {
722
734
723
735
void SPIRVType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
724
736
std::optional<StorageClass> storage) {
725
- if (auto scalarType = llvm::dyn_cast<ScalarType>(*this )) {
726
- scalarType.getExtensions (extensions, storage);
727
- } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this )) {
728
- compositeType.getExtensions (extensions, storage);
729
- } else if (auto imageType = llvm::dyn_cast<ImageType>(*this )) {
730
- imageType.getExtensions (extensions, storage);
731
- } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this )) {
732
- sampledImageType.getExtensions (extensions, storage);
733
- } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this )) {
734
- matrixType.getExtensions (extensions, storage);
735
- } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this )) {
736
- ptrType.getExtensions (extensions, storage);
737
- } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this )) {
738
- tensorArmType.getExtensions (extensions, storage);
739
- } else {
740
- llvm_unreachable (" invalid SPIR-V Type to getExtensions" );
741
- }
737
+ TypeExtensionVisitor{extensions, storage}.add (*this );
742
738
}
743
739
744
740
void SPIRVType::getCapabilities (
@@ -818,12 +814,6 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
818
814
return success ();
819
815
}
820
816
821
- void SampledImageType::getExtensions (
822
- SPIRVType::ExtensionArrayRefVector &extensions,
823
- std::optional<StorageClass> storage) {
824
- llvm::cast<ImageType>(getImageType ()).getExtensions (extensions, storage);
825
- }
826
-
827
817
void SampledImageType::getCapabilities (
828
818
SPIRVType::CapabilityArrayRefVector &capabilities,
829
819
std::optional<StorageClass> storage) {
@@ -1182,12 +1172,6 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
1182
1172
structDecorations);
1183
1173
}
1184
1174
1185
- void StructType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
1186
- std::optional<StorageClass> storage) {
1187
- for (Type elementType : getElementTypes ())
1188
- llvm::cast<SPIRVType>(elementType).getExtensions (extensions, storage);
1189
- }
1190
-
1191
1175
void StructType::getCapabilities (
1192
1176
SPIRVType::CapabilityArrayRefVector &capabilities,
1193
1177
std::optional<StorageClass> storage) {
@@ -1287,11 +1271,6 @@ unsigned MatrixType::getNumElements() const {
1287
1271
return (getImpl ()->columnCount ) * getNumRows ();
1288
1272
}
1289
1273
1290
- void MatrixType::getExtensions (SPIRVType::ExtensionArrayRefVector &extensions,
1291
- std::optional<StorageClass> storage) {
1292
- llvm::cast<SPIRVType>(getColumnType ()).getExtensions (extensions, storage);
1293
- }
1294
-
1295
1274
void MatrixType::getCapabilities (
1296
1275
SPIRVType::CapabilityArrayRefVector &capabilities,
1297
1276
std::optional<StorageClass> storage) {
@@ -1347,12 +1326,9 @@ TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
1347
1326
Type TensorArmType::getElementType () const { return getImpl ()->elementType ; }
1348
1327
ArrayRef<int64_t > TensorArmType::getShape () const { return getImpl ()->shape ; }
1349
1328
1350
- void TensorArmType::getExtensions (
1351
- SPIRVType::ExtensionArrayRefVector &extensions,
1352
- std::optional<StorageClass> storage) {
1353
-
1354
- llvm::cast<SPIRVType>(getElementType ()).getExtensions (extensions, storage);
1355
- static constexpr Extension ext{Extension::SPV_ARM_tensors};
1329
+ void TypeExtensionVisitor::addConcrete (TensorArmType type) {
1330
+ add (type.getElementType ());
1331
+ static constexpr auto ext = Extension::SPV_ARM_tensors;
1356
1332
extensions.push_back (ext);
1357
1333
}
1358
1334
0 commit comments