Skip to content

Commit 901902b

Browse files
committed
[mlir][spirv] Rework type extension queries
* Fix infinite recursion with nested structs. * Move all extension logic to a new helper class -- this way `::getExtensions` functions can't diverge across concrete types and 'convenience types' like `CompositeType`. We should also fix `::getCapabilities` in a similar way and move the testcase to `vce-deduction.mlir`. Issue: #159963
1 parent 105fc90 commit 901902b

File tree

2 files changed

+120
-67
lines changed

2 files changed

+120
-67
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 108 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,73 @@
1414
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1515
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1616
#include "mlir/IR/BuiltinTypes.h"
17+
#include "mlir/Support/LLVM.h"
1718
#include "llvm/ADT/STLExtras.h"
1819
#include "llvm/ADT/TypeSwitch.h"
20+
#include "llvm/Support/ErrorHandling.h"
1921

2022
#include <cstdint>
2123

2224
using namespace mlir;
2325
using namespace mlir::spirv;
2426

27+
namespace {
28+
// Helper function to collect extensions implied by a type by visiting all its
29+
// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
30+
//
31+
// Serves as the source-of-truth for type extension information. All extension
32+
// logic should be added to this class, while
33+
// `*Type::getExtensions` functions should not handle extension-related logic
34+
// directly and only invoke `TypeExtensionVisitor::add(Type *)`.
35+
class TypeExtensionVisitor {
36+
SPIRVType::ExtensionArrayRefVector &extensions;
37+
std::optional<StorageClass> storage;
38+
DenseSet<Type> seen;
39+
40+
public:
41+
TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
42+
std::optional<StorageClass> storage)
43+
: extensions(extensions), storage(storage) {}
44+
45+
// Main visitor entry point. Adds all extensions to the vector. Saves `type`
46+
// as seen and dispatches to the right concrete `.add` function.
47+
void add(SPIRVType type) {
48+
if (auto [_it, inserted] = seen.insert(type); !inserted)
49+
return;
50+
51+
TypeSwitch<SPIRVType>(type)
52+
.Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType,
53+
VectorType, ArrayType, RuntimeArrayType, StructType, MatrixType,
54+
ImageType, SampledImageType>(
55+
[this](auto concreteType) { add(concreteType); })
56+
.Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
57+
}
58+
59+
// Convenience overloads for use in `T::getExtensions` functions.
60+
void add(Type type) { add(cast<SPIRVType>(type)); }
61+
void add(Type *type) { add(cast<SPIRVType>(*type)); }
62+
63+
// Types that add unique extensions.
64+
void add(ScalarType type);
65+
void add(PointerType type);
66+
void add(CooperativeMatrixType type);
67+
void add(TensorArmType type);
68+
69+
// Trivial passthrough without any new extensions.
70+
void add(VectorType type) { add(type.getElementType()); }
71+
void add(ArrayType type) { add(type.getElementType()); }
72+
void add(RuntimeArrayType type) { add(type.getElementType()); }
73+
void add(StructType type) {
74+
for (Type elementType : type.getElementTypes())
75+
add(elementType);
76+
}
77+
void add(MatrixType type) { add(type.getElementType()); }
78+
void add(ImageType type) { add(type.getElementType()); }
79+
void add(SampledImageType type) { add(type.getImageType()); }
80+
};
81+
82+
} // namespace
83+
2584
//===----------------------------------------------------------------------===//
2685
// ArrayType
2786
//===----------------------------------------------------------------------===//
@@ -67,7 +126,7 @@ unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
67126

68127
void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
69128
std::optional<StorageClass> storage) {
70-
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
129+
TypeExtensionVisitor{extensions, storage}.add(this);
71130
}
72131

73132
void ArrayType::getCapabilities(
@@ -143,22 +202,7 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
143202
void CompositeType::getExtensions(
144203
SPIRVType::ExtensionArrayRefVector &extensions,
145204
std::optional<StorageClass> storage) {
146-
TypeSwitch<Type>(*this)
147-
.Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
148-
StructType>(
149-
[&](auto type) { type.getExtensions(extensions, storage); })
150-
.Case<VectorType>([&](VectorType type) {
151-
return llvm::cast<ScalarType>(type.getElementType())
152-
.getExtensions(extensions, storage);
153-
})
154-
.Case<TensorArmType>([&](TensorArmType type) {
155-
static constexpr Extension ext{Extension::SPV_ARM_tensors};
156-
extensions.push_back(ext);
157-
return llvm::cast<ScalarType>(type.getElementType())
158-
.getExtensions(extensions, storage);
159-
})
160-
161-
.Default([](Type) { llvm_unreachable("invalid composite type"); });
205+
TypeExtensionVisitor{extensions, storage}.add(cast<SPIRVType>(*this));
162206
}
163207

164208
void CompositeType::getCapabilities(
@@ -284,12 +328,16 @@ CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
284328
return getImpl()->use;
285329
}
286330

331+
void TypeExtensionVisitor::add(CooperativeMatrixType type) {
332+
add(type.getElementType());
333+
static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
334+
extensions.push_back(ext);
335+
}
336+
287337
void CooperativeMatrixType::getExtensions(
288338
SPIRVType::ExtensionArrayRefVector &extensions,
289339
std::optional<StorageClass> storage) {
290-
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
291-
static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
292-
extensions.push_back(exts);
340+
TypeExtensionVisitor{extensions, storage}.add(this);
293341
}
294342

295343
void CooperativeMatrixType::getCapabilities(
@@ -403,9 +451,9 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
403451

404452
ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
405453

406-
void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
407-
std::optional<StorageClass>) {
408-
// Image types do not require extra extensions thus far.
454+
void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
455+
std::optional<StorageClass> storage) {
456+
TypeExtensionVisitor{extensions, storage}.add(this);
409457
}
410458

411459
void ImageType::getCapabilities(
@@ -454,17 +502,23 @@ StorageClass PointerType::getStorageClass() const {
454502
return getImpl()->storageClass;
455503
}
456504

457-
void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
458-
std::optional<StorageClass> storage) {
505+
void TypeExtensionVisitor::add(PointerType type) {
459506
// Use this pointer type's storage class because this pointer indicates we are
460507
// using the pointee type in that specific storage class.
461-
llvm::cast<SPIRVType>(getPointeeType())
462-
.getExtensions(extensions, getStorageClass());
508+
std::optional<StorageClass> oldStorageClass = storage;
509+
storage = type.getStorageClass();
510+
add(type.getPointeeType());
511+
storage = oldStorageClass;
463512

464-
if (auto scExts = spirv::getExtensions(getStorageClass()))
513+
if (auto scExts = spirv::getExtensions(type.getStorageClass()))
465514
extensions.push_back(*scExts);
466515
}
467516

517+
void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
518+
std::optional<StorageClass> storage) {
519+
TypeExtensionVisitor{extensions, storage}.add(this);
520+
}
521+
468522
void PointerType::getCapabilities(
469523
SPIRVType::CapabilityArrayRefVector &capabilities,
470524
std::optional<StorageClass> storage) {
@@ -516,7 +570,7 @@ unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
516570
void RuntimeArrayType::getExtensions(
517571
SPIRVType::ExtensionArrayRefVector &extensions,
518572
std::optional<StorageClass> storage) {
519-
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
573+
TypeExtensionVisitor{extensions, storage}.add(this);
520574
}
521575

522576
void RuntimeArrayType::getCapabilities(
@@ -553,10 +607,9 @@ bool ScalarType::isValid(IntegerType type) {
553607
return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
554608
}
555609

556-
void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
557-
std::optional<StorageClass> storage) {
558-
if (isa<BFloat16Type>(*this)) {
559-
static const Extension ext = Extension::SPV_KHR_bfloat16;
610+
void TypeExtensionVisitor::add(ScalarType type) {
611+
if (isa<BFloat16Type>(type)) {
612+
static constexpr auto ext = Extension::SPV_KHR_bfloat16;
560613
extensions.push_back(ext);
561614
}
562615

@@ -570,25 +623,28 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
570623
case StorageClass::PushConstant:
571624
case StorageClass::StorageBuffer:
572625
case StorageClass::Uniform:
573-
if (getIntOrFloatBitWidth() == 8) {
574-
static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
575-
ArrayRef<Extension> ref(exts, std::size(exts));
576-
extensions.push_back(ref);
626+
if (type.getIntOrFloatBitWidth() == 8) {
627+
static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
628+
extensions.push_back(ext);
577629
}
578630
[[fallthrough]];
579631
case StorageClass::Input:
580632
case StorageClass::Output:
581-
if (getIntOrFloatBitWidth() == 16) {
582-
static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
583-
ArrayRef<Extension> ref(exts, std::size(exts));
584-
extensions.push_back(ref);
633+
if (type.getIntOrFloatBitWidth() == 16) {
634+
static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
635+
extensions.push_back(ext);
585636
}
586637
break;
587638
default:
588639
break;
589640
}
590641
}
591642

643+
void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
644+
std::optional<StorageClass> storage) {
645+
TypeExtensionVisitor{extensions, storage}.add(this);
646+
}
647+
592648
void ScalarType::getCapabilities(
593649
SPIRVType::CapabilityArrayRefVector &capabilities,
594650
std::optional<StorageClass> storage) {
@@ -722,23 +778,7 @@ bool SPIRVType::isScalarOrVector() {
722778

723779
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
724780
std::optional<StorageClass> storage) {
725-
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
726-
scalarType.getExtensions(extensions, storage);
727-
} else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
728-
compositeType.getExtensions(extensions, storage);
729-
} else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
730-
imageType.getExtensions(extensions, storage);
731-
} else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
732-
sampledImageType.getExtensions(extensions, storage);
733-
} else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
734-
matrixType.getExtensions(extensions, storage);
735-
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
736-
ptrType.getExtensions(extensions, storage);
737-
} else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
738-
tensorArmType.getExtensions(extensions, storage);
739-
} else {
740-
llvm_unreachable("invalid SPIR-V Type to getExtensions");
741-
}
781+
TypeExtensionVisitor{extensions, storage}.add(this);
742782
}
743783

744784
void SPIRVType::getCapabilities(
@@ -821,7 +861,7 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
821861
void SampledImageType::getExtensions(
822862
SPIRVType::ExtensionArrayRefVector &extensions,
823863
std::optional<StorageClass> storage) {
824-
llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
864+
TypeExtensionVisitor{extensions, storage}.add(this);
825865
}
826866

827867
void SampledImageType::getCapabilities(
@@ -1184,8 +1224,7 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
11841224

11851225
void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
11861226
std::optional<StorageClass> storage) {
1187-
for (Type elementType : getElementTypes())
1188-
llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
1227+
TypeExtensionVisitor{extensions, storage}.add(this);
11891228
}
11901229

11911230
void StructType::getCapabilities(
@@ -1289,7 +1328,7 @@ unsigned MatrixType::getNumElements() const {
12891328

12901329
void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
12911330
std::optional<StorageClass> storage) {
1292-
llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
1331+
TypeExtensionVisitor{extensions, storage}.add(this);
12931332
}
12941333

12951334
void MatrixType::getCapabilities(
@@ -1347,13 +1386,16 @@ TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
13471386
Type TensorArmType::getElementType() const { return getImpl()->elementType; }
13481387
ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
13491388

1389+
void TypeExtensionVisitor::add(TensorArmType type) {
1390+
add(type.getElementType());
1391+
static constexpr auto ext = Extension::SPV_ARM_tensors;
1392+
extensions.push_back(ext);
1393+
}
1394+
13501395
void TensorArmType::getExtensions(
13511396
SPIRVType::ExtensionArrayRefVector &extensions,
13521397
std::optional<StorageClass> storage) {
1353-
1354-
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
1355-
static constexpr Extension ext{Extension::SPV_ARM_tensors};
1356-
extensions.push_back(ext);
1398+
TypeExtensionVisitor{extensions, storage}.add(this);
13571399
}
13581400

13591401
void TensorArmType::getCapabilities(

mlir/test/Conversion/SCFToSPIRV/unsupported.mlir

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
1+
// RUN: mlir-opt --convert-scf-to-spirv %s --verify-diagnostics --split-input-file | FileCheck %s
22

33
// `scf.parallel` conversion is not supported yet.
44
// Make sure that we do not accidentally invalidate this function by removing
@@ -19,3 +19,14 @@ func.func @func(%arg0: i64) {
1919
}
2020
return
2121
}
22+
23+
// -----
24+
25+
// Make sure we don't crash on recursive structs.
26+
// TODO(https://github.com/llvm/llvm-project/issues/159963): Promote this to a `vce-deduction.mlir` testcase.
27+
28+
// expected-error@below {{failed to legalize operation 'spirv.module' that was explicitly marked illegal}}
29+
spirv.module Physical64 GLSL450 {
30+
spirv.GlobalVariable @recursive:
31+
!spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
32+
}

0 commit comments

Comments
 (0)