Skip to content

Commit 7561596

Browse files
committed
Simplify public API
1 parent 9054751 commit 7561596

File tree

2 files changed

+3
-87
lines changed

2 files changed

+3
-87
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ class ScalarType : public SPIRVType {
8989
/// Returns true if the given float type is valid for the SPIR-V dialect.
9090
static bool isValid(IntegerType);
9191

92-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
93-
std::optional<StorageClass> storage = std::nullopt);
9492
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
9593
std::optional<StorageClass> storage = std::nullopt);
9694

@@ -118,8 +116,6 @@ class CompositeType : public SPIRVType {
118116
/// implementation dependent.
119117
bool hasCompileTimeKnownNumElements() const;
120118

121-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
122-
std::optional<StorageClass> storage = std::nullopt);
123119
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
124120
std::optional<StorageClass> storage = std::nullopt);
125121

@@ -148,8 +144,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
148144
/// type.
149145
unsigned getArrayStride() const;
150146

151-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
152-
std::optional<StorageClass> storage = std::nullopt);
153147
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
154148
std::optional<StorageClass> storage = std::nullopt);
155149

@@ -193,8 +187,6 @@ class ImageType
193187
ImageFormat getImageFormat() const;
194188
// TODO: Add support for Access qualifier
195189

196-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
197-
std::optional<StorageClass> storage = std::nullopt);
198190
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
199191
std::optional<StorageClass> storage = std::nullopt);
200192
};
@@ -213,8 +205,6 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
213205

214206
StorageClass getStorageClass() const;
215207

216-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
217-
std::optional<StorageClass> storage = std::nullopt);
218208
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
219209
std::optional<StorageClass> storage = std::nullopt);
220210
};
@@ -239,8 +229,6 @@ class RuntimeArrayType
239229
/// type.
240230
unsigned getArrayStride() const;
241231

242-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
243-
std::optional<StorageClass> storage = std::nullopt);
244232
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
245233
std::optional<StorageClass> storage = std::nullopt);
246234
};
@@ -265,8 +253,6 @@ class SampledImageType
265253

266254
Type getImageType() const;
267255

268-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
269-
std::optional<spirv::StorageClass> storage = std::nullopt);
270256
void
271257
getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
272258
std::optional<spirv::StorageClass> storage = std::nullopt);
@@ -420,8 +406,6 @@ class StructType
420406
ArrayRef<MemberDecorationInfo> memberDecorations = {},
421407
ArrayRef<StructDecorationInfo> structDecorations = {});
422408

423-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
424-
std::optional<StorageClass> storage = std::nullopt);
425409
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
426410
std::optional<StorageClass> storage = std::nullopt);
427411
};
@@ -456,8 +440,6 @@ class CooperativeMatrixType
456440
/// Returns the use parameter of the cooperative matrix.
457441
CooperativeMatrixUseKHR getUse() const;
458442

459-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
460-
std::optional<StorageClass> storage = std::nullopt);
461443
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
462444
std::optional<StorageClass> storage = std::nullopt);
463445

@@ -512,8 +494,6 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
512494
/// Returns the elements' type (i.e, single element type).
513495
Type getElementType() const;
514496

515-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
516-
std::optional<StorageClass> storage = std::nullopt);
517497
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
518498
std::optional<StorageClass> storage = std::nullopt);
519499
};
@@ -552,8 +532,6 @@ class TensorArmType
552532
bool hasRank() const { return !getShape().empty(); }
553533
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
554534

555-
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
556-
std::optional<StorageClass> storage = std::nullopt);
557535
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
558536
std::optional<StorageClass> storage = std::nullopt);
559537
};

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 3 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ namespace {
2929
// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
3030
//
3131
// Serves as the source-of-truth for type extension information. All extension
32-
// logic should be added to this class, while
33-
// `*Type::getExtensions` functions should not handle extension-related logic
32+
// logic should be added to this class, while the
33+
// `SPIRVType::getExtensions` function should not handle extension-related logic
3434
// directly and only invoke `TypeExtensionVisitor::add(Type *)`.
3535
class TypeExtensionVisitor {
3636
public:
@@ -59,9 +59,7 @@ class TypeExtensionVisitor {
5959
.Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
6060
}
6161

62-
// Convenience overloads for use in `T::getExtensions` functions.
6362
void add(Type type) { add(cast<SPIRVType>(type)); }
64-
void add(Type *type) { add(cast<SPIRVType>(*type)); }
6563

6664
private:
6765
// Types that add unique extensions.
@@ -120,11 +118,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
120118

121119
unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
122120

123-
void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
124-
std::optional<StorageClass> storage) {
125-
TypeExtensionVisitor{extensions, storage}.add(this);
126-
}
127-
128121
void ArrayType::getCapabilities(
129122
SPIRVType::CapabilityArrayRefVector &capabilities,
130123
std::optional<StorageClass> storage) {
@@ -195,12 +188,6 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
195188
return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
196189
}
197190

198-
void CompositeType::getExtensions(
199-
SPIRVType::ExtensionArrayRefVector &extensions,
200-
std::optional<StorageClass> storage) {
201-
TypeExtensionVisitor{extensions, storage}.add(cast<SPIRVType>(*this));
202-
}
203-
204191
void CompositeType::getCapabilities(
205192
SPIRVType::CapabilityArrayRefVector &capabilities,
206193
std::optional<StorageClass> storage) {
@@ -330,12 +317,6 @@ void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
330317
extensions.push_back(ext);
331318
}
332319

333-
void CooperativeMatrixType::getExtensions(
334-
SPIRVType::ExtensionArrayRefVector &extensions,
335-
std::optional<StorageClass> storage) {
336-
TypeExtensionVisitor{extensions, storage}.add(this);
337-
}
338-
339320
void CooperativeMatrixType::getCapabilities(
340321
SPIRVType::CapabilityArrayRefVector &capabilities,
341322
std::optional<StorageClass> storage) {
@@ -447,11 +428,6 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
447428

448429
ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
449430

450-
void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
451-
std::optional<StorageClass> storage) {
452-
TypeExtensionVisitor{extensions, storage}.add(this);
453-
}
454-
455431
void ImageType::getCapabilities(
456432
SPIRVType::CapabilityArrayRefVector &capabilities,
457433
std::optional<StorageClass>) {
@@ -510,11 +486,6 @@ void TypeExtensionVisitor::addConcrete(PointerType type) {
510486
extensions.push_back(*scExts);
511487
}
512488

513-
void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
514-
std::optional<StorageClass> storage) {
515-
TypeExtensionVisitor{extensions, storage}.add(this);
516-
}
517-
518489
void PointerType::getCapabilities(
519490
SPIRVType::CapabilityArrayRefVector &capabilities,
520491
std::optional<StorageClass> storage) {
@@ -563,12 +534,6 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
563534

564535
unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
565536

566-
void RuntimeArrayType::getExtensions(
567-
SPIRVType::ExtensionArrayRefVector &extensions,
568-
std::optional<StorageClass> storage) {
569-
TypeExtensionVisitor{extensions, storage}.add(this);
570-
}
571-
572537
void RuntimeArrayType::getCapabilities(
573538
SPIRVType::CapabilityArrayRefVector &capabilities,
574539
std::optional<StorageClass> storage) {
@@ -636,11 +601,6 @@ void TypeExtensionVisitor::addConcrete(ScalarType type) {
636601
}
637602
}
638603

639-
void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
640-
std::optional<StorageClass> storage) {
641-
TypeExtensionVisitor{extensions, storage}.add(this);
642-
}
643-
644604
void ScalarType::getCapabilities(
645605
SPIRVType::CapabilityArrayRefVector &capabilities,
646606
std::optional<StorageClass> storage) {
@@ -774,7 +734,7 @@ bool SPIRVType::isScalarOrVector() {
774734

775735
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
776736
std::optional<StorageClass> storage) {
777-
TypeExtensionVisitor{extensions, storage}.add(this);
737+
TypeExtensionVisitor{extensions, storage}.add(*this);
778738
}
779739

780740
void SPIRVType::getCapabilities(
@@ -854,12 +814,6 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
854814
return success();
855815
}
856816

857-
void SampledImageType::getExtensions(
858-
SPIRVType::ExtensionArrayRefVector &extensions,
859-
std::optional<StorageClass> storage) {
860-
TypeExtensionVisitor{extensions, storage}.add(this);
861-
}
862-
863817
void SampledImageType::getCapabilities(
864818
SPIRVType::CapabilityArrayRefVector &capabilities,
865819
std::optional<StorageClass> storage) {
@@ -1218,11 +1172,6 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
12181172
structDecorations);
12191173
}
12201174

1221-
void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1222-
std::optional<StorageClass> storage) {
1223-
TypeExtensionVisitor{extensions, storage}.add(this);
1224-
}
1225-
12261175
void StructType::getCapabilities(
12271176
SPIRVType::CapabilityArrayRefVector &capabilities,
12281177
std::optional<StorageClass> storage) {
@@ -1322,11 +1271,6 @@ unsigned MatrixType::getNumElements() const {
13221271
return (getImpl()->columnCount) * getNumRows();
13231272
}
13241273

1325-
void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1326-
std::optional<StorageClass> storage) {
1327-
TypeExtensionVisitor{extensions, storage}.add(this);
1328-
}
1329-
13301274
void MatrixType::getCapabilities(
13311275
SPIRVType::CapabilityArrayRefVector &capabilities,
13321276
std::optional<StorageClass> storage) {
@@ -1388,12 +1332,6 @@ void TypeExtensionVisitor::addConcrete(TensorArmType type) {
13881332
extensions.push_back(ext);
13891333
}
13901334

1391-
void TensorArmType::getExtensions(
1392-
SPIRVType::ExtensionArrayRefVector &extensions,
1393-
std::optional<StorageClass> storage) {
1394-
TypeExtensionVisitor{extensions, storage}.add(this);
1395-
}
1396-
13971335
void TensorArmType::getCapabilities(
13981336
SPIRVType::CapabilityArrayRefVector &capabilities,
13991337
std::optional<StorageClass> storage) {

0 commit comments

Comments
 (0)