Skip to content

Commit 28d68f9

Browse files
authored
[mlir][spirv] Simplify CompositeType::getNumElements. NFC. (#160202)
Use a type switch to simplify the implementation.
1 parent 01b60df commit 28d68f9

File tree

1 file changed

+7
-19
lines changed

1 file changed

+7
-19
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -203,25 +203,13 @@ Type CompositeType::getElementType(unsigned index) const {
203203
}
204204

205205
unsigned CompositeType::getNumElements() const {
206-
if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
207-
return arrayType.getNumElements();
208-
if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
209-
return matrixType.getNumColumns();
210-
if (auto structType = llvm::dyn_cast<StructType>(*this))
211-
return structType.getNumElements();
212-
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
213-
return vectorType.getNumElements();
214-
if (auto tensorArmType = dyn_cast<TensorArmType>(*this))
215-
return tensorArmType.getNumElements();
216-
if (llvm::isa<CooperativeMatrixType>(*this)) {
217-
llvm_unreachable(
218-
"invalid to query number of elements of spirv Cooperative Matrix type");
219-
}
220-
if (llvm::isa<RuntimeArrayType>(*this)) {
221-
llvm_unreachable(
222-
"invalid to query number of elements of spirv::RuntimeArray type");
223-
}
224-
llvm_unreachable("invalid composite type");
206+
return TypeSwitch<SPIRVType, unsigned>(*this)
207+
.Case<ArrayType, StructType, TensorArmType, VectorType>(
208+
[](auto type) { return type.getNumElements(); })
209+
.Case<MatrixType>([](MatrixType type) { return type.getNumColumns(); })
210+
.Default([](SPIRVType) -> unsigned {
211+
llvm_unreachable("Invalid type for number of elements query");
212+
});
225213
}
226214

227215
bool CompositeType::hasCompileTimeKnownNumElements() const {

0 commit comments

Comments
 (0)