Skip to content

Commit b425572

Browse files
kuhargithub-actions[bot]
authored andcommitted
Automerge: [mlir][spirv] Rework type capability queries (#160113)
* Fix infinite recursion with nested structs. * Drop `::getCapbilities` function from derived types, so that there's only one entry point that queries type extensions. * Move all capability logic to a new helper class -- this way the `::getCapabilities` functions can't diverge across concrete types and 'convenience types' like CompositeType. Fixes: #159963
2 parents 53dbdef + ca7c058 commit b425572

File tree

4 files changed

+121
-177
lines changed

4 files changed

+121
-177
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,6 @@ class ScalarType : public SPIRVType {
8989
/// Returns true if the given float type is valid for the SPIR-V dialect.
9090
static bool isValid(IntegerType);
9191

92-
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
93-
std::optional<StorageClass> storage = std::nullopt);
94-
9592
std::optional<int64_t> getSizeInBytes();
9693
};
9794

@@ -116,9 +113,6 @@ class CompositeType : public SPIRVType {
116113
/// implementation dependent.
117114
bool hasCompileTimeKnownNumElements() const;
118115

119-
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
120-
std::optional<StorageClass> storage = std::nullopt);
121-
122116
std::optional<int64_t> getSizeInBytes();
123117
};
124118

@@ -144,9 +138,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
144138
/// type.
145139
unsigned getArrayStride() const;
146140

147-
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
148-
std::optional<StorageClass> storage = std::nullopt);
149-
150141
/// Returns the array size in bytes. Since array type may have an explicit
151142
/// stride declaration (in bytes), we also include it in the calculation.
152143
std::optional<int64_t> getSizeInBytes();
@@ -186,9 +177,6 @@ class ImageType
186177
ImageSamplerUseInfo getSamplerUseInfo() const;
187178
ImageFormat getImageFormat() const;
188179
// TODO: Add support for Access qualifier
189-
190-
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
191-
std::optional<StorageClass> storage = std::nullopt);
192180
};
193181

194182
// SPIR-V pointer type
@@ -204,9 +192,6 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
204192
Type getPointeeType() const;
205193

206194
StorageClass getStorageClass() const;
207-
208-
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
209-
std::optional<StorageClass> storage = std::nullopt);
210195
};
211196

212197
// SPIR-V run-time array type
@@ -228,9 +213,6 @@ class RuntimeArrayType
228213
/// Returns the array stride in bytes. 0 means no stride decorated on this
229214
/// type.
230215
unsigned getArrayStride() const;
231-
232-
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
233-
std::optional<StorageClass> storage = std::nullopt);
234216
};
235217

236218
// SPIR-V sampled image type
@@ -252,10 +234,6 @@ class SampledImageType
252234
Type imageType);
253235

254236
Type getImageType() const;
255-
256-
void
257-
getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
258-
std::optional<spirv::StorageClass> storage = std::nullopt);
259237
};
260238

261239
/// SPIR-V struct type. Two kinds of struct types are supported:
@@ -405,9 +383,6 @@ class StructType
405383
trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
406384
ArrayRef<MemberDecorationInfo> memberDecorations = {},
407385
ArrayRef<StructDecorationInfo> structDecorations = {});
408-
409-
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
410-
std::optional<StorageClass> storage = std::nullopt);
411386
};
412387

413388
llvm::hash_code
@@ -440,9 +415,6 @@ class CooperativeMatrixType
440415
/// Returns the use parameter of the cooperative matrix.
441416
CooperativeMatrixUseKHR getUse() const;
442417

443-
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
444-
std::optional<StorageClass> storage = std::nullopt);
445-
446418
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
447419

448420
ArrayRef<int64_t> getShape() const;
@@ -493,9 +465,6 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
493465

494466
/// Returns the elements' type (i.e, single element type).
495467
Type getElementType() const;
496-
497-
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
498-
std::optional<StorageClass> storage = std::nullopt);
499468
};
500469

501470
/// SPIR-V TensorARM Type
@@ -531,9 +500,6 @@ class TensorArmType
531500
ArrayRef<int64_t> getShape() const;
532501
bool hasRank() const { return !getShape().empty(); }
533502
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
534-
535-
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
536-
std::optional<StorageClass> storage = std::nullopt);
537503
};
538504

539505
} // namespace spirv

0 commit comments

Comments
 (0)