Skip to content

Commit 32b1f16

Browse files
authored
[mlir][spirv] Rework type extension queries (#160020)
* Fix infinite recursion with nested structs. * Drop `::getExtensions` function from derived types, so that there's only one entry point that queries type extensions. * Move all extension logic to a new helper class -- this way the `::getExtensions` functions can't diverge across concrete types and 'convenience types' like `CompositeType`. We should also fix `::getCapabilities` in a similar way and move the testcase to `vce-deduction.mlir`. Issue: #159963
1 parent dfd50f9 commit 32b1f16

File tree

3 files changed

+88
-123
lines changed

3 files changed

+88
-123
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ class ScalarType : public SPIRVType {
8989
/// Returns true if the given float type is valid for the SPIR-V dialect.
9090
static bool isValid(IntegerType);
9191

92-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
93-
std::optional<StorageClass> storage = std::nullopt);
9492
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
9593
std::optional<StorageClass> storage = std::nullopt);
9694

@@ -118,8 +116,6 @@ class CompositeType : public SPIRVType {
118116
/// implementation dependent.
119117
bool hasCompileTimeKnownNumElements() const;
120118

121-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
122-
std::optional<StorageClass> storage = std::nullopt);
123119
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
124120
std::optional<StorageClass> storage = std::nullopt);
125121

@@ -148,8 +144,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
148144
/// type.
149145
unsigned getArrayStride() const;
150146

151-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
152-
std::optional<StorageClass> storage = std::nullopt);
153147
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
154148
std::optional<StorageClass> storage = std::nullopt);
155149

@@ -193,8 +187,6 @@ class ImageType
193187
ImageFormat getImageFormat() const;
194188
// TODO: Add support for Access qualifier
195189

196-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
197-
std::optional<StorageClass> storage = std::nullopt);
198190
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
199191
std::optional<StorageClass> storage = std::nullopt);
200192
};
@@ -213,8 +205,6 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
213205

214206
StorageClass getStorageClass() const;
215207

216-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
217-
std::optional<StorageClass> storage = std::nullopt);
218208
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
219209
std::optional<StorageClass> storage = std::nullopt);
220210
};
@@ -239,8 +229,6 @@ class RuntimeArrayType
239229
/// type.
240230
unsigned getArrayStride() const;
241231

242-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
243-
std::optional<StorageClass> storage = std::nullopt);
244232
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
245233
std::optional<StorageClass> storage = std::nullopt);
246234
};
@@ -265,8 +253,6 @@ class SampledImageType
265253

266254
Type getImageType() const;
267255

268-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
269-
std::optional<spirv::StorageClass> storage = std::nullopt);
270256
void
271257
getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
272258
std::optional<spirv::StorageClass> storage = std::nullopt);
@@ -420,8 +406,6 @@ class StructType
420406
ArrayRef<MemberDecorationInfo> memberDecorations = {},
421407
ArrayRef<StructDecorationInfo> structDecorations = {});
422408

423-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
424-
std::optional<StorageClass> storage = std::nullopt);
425409
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
426410
std::optional<StorageClass> storage = std::nullopt);
427411
};
@@ -456,8 +440,6 @@ class CooperativeMatrixType
456440
/// Returns the use parameter of the cooperative matrix.
457441
CooperativeMatrixUseKHR getUse() const;
458442

459-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
460-
std::optional<StorageClass> storage = std::nullopt);
461443
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
462444
std::optional<StorageClass> storage = std::nullopt);
463445

@@ -512,8 +494,6 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
512494
/// Returns the elements' type (i.e, single element type).
513495
Type getElementType() const;
514496

515-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
516-
std::optional<StorageClass> storage = std::nullopt);
517497
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
518498
std::optional<StorageClass> storage = std::nullopt);
519499
};
@@ -552,8 +532,6 @@ class TensorArmType
552532
bool hasRank() const { return !getShape().empty(); }
553533
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
554534

555-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
556-
std::optional<StorageClass> storage = std::nullopt);
557535
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
558536
std::optional<StorageClass> storage = std::nullopt);
559537
};

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 76 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,67 @@
1414
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1515
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1616
#include "mlir/IR/BuiltinTypes.h"
17+
#include "mlir/Support/LLVM.h"
1718
#include "llvm/ADT/STLExtras.h"
1819
#include "llvm/ADT/TypeSwitch.h"
20+
#include "llvm/Support/ErrorHandling.h"
1921

2022
#include <cstdint>
2123

2224
using namespace mlir;
2325
using namespace mlir::spirv;
2426

27+
namespace {
28+
// Helper function to collect extensions implied by a type by visiting all its
29+
// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
30+
//
31+
// Serves as the source-of-truth for type extension information. All extension
32+
// logic should be added to this class, while the
33+
// `SPIRVType::getExtensions` function should not handle extension-related logic
34+
// directly and only invoke `TypeExtensionVisitor::add(Type *)`.
35+
class TypeExtensionVisitor {
36+
public:
37+
TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
38+
std::optional<StorageClass> storage)
39+
: extensions(extensions), storage(storage) {}
40+
41+
// Main visitor entry point. Adds all extensions to the vector. Saves `type`
42+
// as seen and dispatches to the right concrete `.add` function.
43+
void add(SPIRVType type) {
44+
if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
45+
return;
46+
47+
TypeSwitch<SPIRVType>(type)
48+
.Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType>(
49+
[this](auto concreteType) { addConcrete(concreteType); })
50+
.Case<VectorType, ArrayType, RuntimeArrayType, MatrixType, ImageType>(
51+
[this](auto concreteType) { add(concreteType.getElementType()); })
52+
.Case<StructType>([this](StructType concreteType) {
53+
for (Type elementType : concreteType.getElementTypes())
54+
add(elementType);
55+
})
56+
.Case<SampledImageType>([this](SampledImageType concreteType) {
57+
add(concreteType.getImageType());
58+
})
59+
.Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
60+
}
61+
62+
void add(Type type) { add(cast<SPIRVType>(type)); }
63+
64+
private:
65+
// Types that add unique extensions.
66+
void addConcrete(ScalarType type);
67+
void addConcrete(PointerType type);
68+
void addConcrete(CooperativeMatrixType type);
69+
void addConcrete(TensorArmType type);
70+
71+
SPIRVType::ExtensionArrayRefVector &extensions;
72+
std::optional<StorageClass> storage;
73+
llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
74+
};
75+
76+
} // namespace
77+
2578
//===----------------------------------------------------------------------===//
2679
// ArrayType
2780
//===----------------------------------------------------------------------===//
@@ -65,11 +118,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
65118

66119
unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
67120

68-
void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
69-
std::optional<StorageClass> storage) {
70-
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
71-
}
72-
73121
void ArrayType::getCapabilities(
74122
SPIRVType::CapabilityArrayRefVector &capabilities,
75123
std::optional<StorageClass> storage) {
@@ -140,27 +188,6 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
140188
return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
141189
}
142190

143-
void CompositeType::getExtensions(
144-
SPIRVType::ExtensionArrayRefVector &extensions,
145-
std::optional<StorageClass> storage) {
146-
TypeSwitch<Type>(*this)
147-
.Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
148-
StructType>(
149-
[&](auto type) { type.getExtensions(extensions, storage); })
150-
.Case<VectorType>([&](VectorType type) {
151-
return llvm::cast<ScalarType>(type.getElementType())
152-
.getExtensions(extensions, storage);
153-
})
154-
.Case<TensorArmType>([&](TensorArmType type) {
155-
static constexpr Extension ext{Extension::SPV_ARM_tensors};
156-
extensions.push_back(ext);
157-
return llvm::cast<ScalarType>(type.getElementType())
158-
.getExtensions(extensions, storage);
159-
})
160-
161-
.Default([](Type) { llvm_unreachable("invalid composite type"); });
162-
}
163-
164191
void CompositeType::getCapabilities(
165192
SPIRVType::CapabilityArrayRefVector &capabilities,
166193
std::optional<StorageClass> storage) {
@@ -284,12 +311,10 @@ CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
284311
return getImpl()->use;
285312
}
286313

287-
void CooperativeMatrixType::getExtensions(
288-
SPIRVType::ExtensionArrayRefVector &extensions,
289-
std::optional<StorageClass> storage) {
290-
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
291-
static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
292-
extensions.push_back(exts);
314+
void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
315+
add(type.getElementType());
316+
static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
317+
extensions.push_back(ext);
293318
}
294319

295320
void CooperativeMatrixType::getCapabilities(
@@ -403,11 +428,6 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
403428

404429
ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
405430

406-
void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
407-
std::optional<StorageClass>) {
408-
// Image types do not require extra extensions thus far.
409-
}
410-
411431
void ImageType::getCapabilities(
412432
SPIRVType::CapabilityArrayRefVector &capabilities,
413433
std::optional<StorageClass>) {
@@ -454,14 +474,15 @@ StorageClass PointerType::getStorageClass() const {
454474
return getImpl()->storageClass;
455475
}
456476

457-
void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
458-
std::optional<StorageClass> storage) {
477+
void TypeExtensionVisitor::addConcrete(PointerType type) {
459478
// Use this pointer type's storage class because this pointer indicates we are
460479
// using the pointee type in that specific storage class.
461-
llvm::cast<SPIRVType>(getPointeeType())
462-
.getExtensions(extensions, getStorageClass());
480+
std::optional<StorageClass> oldStorageClass = storage;
481+
storage = type.getStorageClass();
482+
add(type.getPointeeType());
483+
storage = oldStorageClass;
463484

464-
if (auto scExts = spirv::getExtensions(getStorageClass()))
485+
if (auto scExts = spirv::getExtensions(type.getStorageClass()))
465486
extensions.push_back(*scExts);
466487
}
467488

@@ -513,12 +534,6 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
513534

514535
unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
515536

516-
void RuntimeArrayType::getExtensions(
517-
SPIRVType::ExtensionArrayRefVector &extensions,
518-
std::optional<StorageClass> storage) {
519-
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
520-
}
521-
522537
void RuntimeArrayType::getCapabilities(
523538
SPIRVType::CapabilityArrayRefVector &capabilities,
524539
std::optional<StorageClass> storage) {
@@ -553,10 +568,9 @@ bool ScalarType::isValid(IntegerType type) {
553568
return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
554569
}
555570

556-
void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
557-
std::optional<StorageClass> storage) {
558-
if (isa<BFloat16Type>(*this)) {
559-
static const Extension ext = Extension::SPV_KHR_bfloat16;
571+
void TypeExtensionVisitor::addConcrete(ScalarType type) {
572+
if (isa<BFloat16Type>(type)) {
573+
static constexpr auto ext = Extension::SPV_KHR_bfloat16;
560574
extensions.push_back(ext);
561575
}
562576

@@ -570,18 +584,16 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
570584
case StorageClass::PushConstant:
571585
case StorageClass::StorageBuffer:
572586
case StorageClass::Uniform:
573-
if (getIntOrFloatBitWidth() == 8) {
574-
static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
575-
ArrayRef<Extension> ref(exts, std::size(exts));
576-
extensions.push_back(ref);
587+
if (type.getIntOrFloatBitWidth() == 8) {
588+
static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
589+
extensions.push_back(ext);
577590
}
578591
[[fallthrough]];
579592
case StorageClass::Input:
580593
case StorageClass::Output:
581-
if (getIntOrFloatBitWidth() == 16) {
582-
static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
583-
ArrayRef<Extension> ref(exts, std::size(exts));
584-
extensions.push_back(ref);
594+
if (type.getIntOrFloatBitWidth() == 16) {
595+
static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
596+
extensions.push_back(ext);
585597
}
586598
break;
587599
default:
@@ -722,23 +734,7 @@ bool SPIRVType::isScalarOrVector() {
722734

723735
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
724736
std::optional<StorageClass> storage) {
725-
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
726-
scalarType.getExtensions(extensions, storage);
727-
} else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
728-
compositeType.getExtensions(extensions, storage);
729-
} else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
730-
imageType.getExtensions(extensions, storage);
731-
} else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
732-
sampledImageType.getExtensions(extensions, storage);
733-
} else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
734-
matrixType.getExtensions(extensions, storage);
735-
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
736-
ptrType.getExtensions(extensions, storage);
737-
} else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
738-
tensorArmType.getExtensions(extensions, storage);
739-
} else {
740-
llvm_unreachable("invalid SPIR-V Type to getExtensions");
741-
}
737+
TypeExtensionVisitor{extensions, storage}.add(*this);
742738
}
743739

744740
void SPIRVType::getCapabilities(
@@ -818,12 +814,6 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
818814
return success();
819815
}
820816

821-
void SampledImageType::getExtensions(
822-
SPIRVType::ExtensionArrayRefVector &extensions,
823-
std::optional<StorageClass> storage) {
824-
llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
825-
}
826-
827817
void SampledImageType::getCapabilities(
828818
SPIRVType::CapabilityArrayRefVector &capabilities,
829819
std::optional<StorageClass> storage) {
@@ -1182,12 +1172,6 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
11821172
structDecorations);
11831173
}
11841174

1185-
void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1186-
std::optional<StorageClass> storage) {
1187-
for (Type elementType : getElementTypes())
1188-
llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1189-
}
1190-
11911175
void StructType::getCapabilities(
11921176
SPIRVType::CapabilityArrayRefVector &capabilities,
11931177
std::optional<StorageClass> storage) {
@@ -1287,11 +1271,6 @@ unsigned MatrixType::getNumElements() const {
12871271
return (getImpl()->columnCount) * getNumRows();
12881272
}
12891273

1290-
void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1291-
std::optional<StorageClass> storage) {
1292-
llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
1293-
}
1294-
12951274
void MatrixType::getCapabilities(
12961275
SPIRVType::CapabilityArrayRefVector &capabilities,
12971276
std::optional<StorageClass> storage) {
@@ -1347,12 +1326,9 @@ TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
13471326
Type TensorArmType::getElementType() const { return getImpl()->elementType; }
13481327
ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
13491328

1350-
void TensorArmType::getExtensions(
1351-
SPIRVType::ExtensionArrayRefVector &extensions,
1352-
std::optional<StorageClass> storage) {
1353-
1354-
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
1355-
static constexpr Extension ext{Extension::SPV_ARM_tensors};
1329+
void TypeExtensionVisitor::addConcrete(TensorArmType type) {
1330+
add(type.getElementType());
1331+
static constexpr auto ext = Extension::SPV_ARM_tensors;
13561332
extensions.push_back(ext);
13571333
}
13581334

0 commit comments

Comments
 (0)