@@ -203,25 +203,13 @@ Type CompositeType::getElementType(unsigned index) const {
203
203
}
204
204
205
205
unsigned CompositeType::getNumElements () const {
206
- if (auto arrayType = llvm::dyn_cast<ArrayType>(*this ))
207
- return arrayType.getNumElements ();
208
- if (auto matrixType = llvm::dyn_cast<MatrixType>(*this ))
209
- return matrixType.getNumColumns ();
210
- if (auto structType = llvm::dyn_cast<StructType>(*this ))
211
- return structType.getNumElements ();
212
- if (auto vectorType = llvm::dyn_cast<VectorType>(*this ))
213
- return vectorType.getNumElements ();
214
- if (auto tensorArmType = dyn_cast<TensorArmType>(*this ))
215
- return tensorArmType.getNumElements ();
216
- if (llvm::isa<CooperativeMatrixType>(*this )) {
217
- llvm_unreachable (
218
- " invalid to query number of elements of spirv Cooperative Matrix type" );
219
- }
220
- if (llvm::isa<RuntimeArrayType>(*this )) {
221
- llvm_unreachable (
222
- " invalid to query number of elements of spirv::RuntimeArray type" );
223
- }
224
- llvm_unreachable (" invalid composite type" );
206
+ return TypeSwitch<SPIRVType, unsigned >(*this )
207
+ .Case <ArrayType, StructType, TensorArmType, VectorType>(
208
+ [](auto type) { return type.getNumElements (); })
209
+ .Case <MatrixType>([](MatrixType type) { return type.getNumColumns (); })
210
+ .Default ([](SPIRVType) -> unsigned {
211
+ llvm_unreachable (" Invalid type for number of elements query" );
212
+ });
225
213
}
226
214
227
215
bool CompositeType::hasCompileTimeKnownNumElements () const {
0 commit comments