@@ -203,25 +203,13 @@ Type CompositeType::getElementType(unsigned index) const {
203203}
204204
205205unsigned CompositeType::getNumElements () const {
206- if (auto arrayType = llvm::dyn_cast<ArrayType>(*this ))
207- return arrayType.getNumElements ();
208- if (auto matrixType = llvm::dyn_cast<MatrixType>(*this ))
209- return matrixType.getNumColumns ();
210- if (auto structType = llvm::dyn_cast<StructType>(*this ))
211- return structType.getNumElements ();
212- if (auto vectorType = llvm::dyn_cast<VectorType>(*this ))
213- return vectorType.getNumElements ();
214- if (auto tensorArmType = dyn_cast<TensorArmType>(*this ))
215- return tensorArmType.getNumElements ();
216- if (llvm::isa<CooperativeMatrixType>(*this )) {
217- llvm_unreachable (
218- " invalid to query number of elements of spirv Cooperative Matrix type" );
219- }
220- if (llvm::isa<RuntimeArrayType>(*this )) {
221- llvm_unreachable (
222- " invalid to query number of elements of spirv::RuntimeArray type" );
223- }
224- llvm_unreachable (" invalid composite type" );
206+ return TypeSwitch<SPIRVType, unsigned >(*this )
207+ .Case <ArrayType, StructType, TensorArmType, VectorType>(
208+ [](auto type) { return type.getNumElements (); })
209+ .Case <MatrixType>([](MatrixType type) { return type.getNumColumns (); })
210+ .Default ([](SPIRVType) -> unsigned {
211+ llvm_unreachable (" Invalid type for number of elements query" );
212+ });
225213}
226214
227215bool CompositeType::hasCompileTimeKnownNumElements () const {
0 commit comments