From 608cf7bb0488d94edaace1674450cf01143e3f45 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 30 Oct 2025 14:50:23 -0400 Subject: [PATCH] [mlir] Simplify Default cases in type switches Use default values instead of lambdas when possible. `std::nullopt` and `nullptr` can be used now because of https://github.com/llvm/llvm-project/pull/165724. --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 2 +- mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp | 2 +- mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp | 2 +- mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 2 +- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 2 +- mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp | 6 +++--- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 2 +- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 2 +- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 2 +- mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 2 +- mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 2 +- .../lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp | 2 +- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 2 +- mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp | 2 +- mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp | 2 +- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 4 ++-- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 2 +- mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 +- mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp | 2 +- mlir/lib/TableGen/Type.cpp | 2 +- mlir/lib/Target/LLVMIR/DebugTranslation.cpp | 6 +++--- 21 files changed, 26 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 41e333c621eda..3a307a0756d93 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -935,7 +935,7 @@ static std::optional mfmaTypeSelectCode(Type mlirElemType) { .Case([](Float6E2M3FNType) { return 2u; }) .Case([](Float6E3M2FNType) { return 3u; }) .Case([](Float4E2M1FNType) { return 4u; }) - .Default([](Type) { return std::nullopt; }); + .Default(std::nullopt); } /// If there is a scaled MFMA instruction for the input element types `aType` diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 247dba101cfc1..cfdcd9cc2d86d 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -432,7 +432,7 @@ static Value getOriginalVectorValue(Value value) { current = op.getSource(); return false; }) - .Default([](Operation *) { return false; }); + .Default(false); if (!skipOp) { break; diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index 25f1e1b184d61..425594b3382f0 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -259,7 +259,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { } return std::nullopt; }) - .Default([](auto) { return std::nullopt; }); + .Default(std::nullopt); } static std::optional getFuncName(gpu::ShuffleMode mode, diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index e2c7d803e5a5e..91c1aa55fdb4e 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -46,7 +46,7 @@ static bool isZeroConstant(Value val) { [](auto floatAttr) { return floatAttr.getValue().isZero(); }) .Case( [](auto intAttr) { return intAttr.getValue().isZero(); }) - .Default([](auto) { return false; }); + .Default(false); } static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter, diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 898d76ce8d9b5..980442efdf708 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2751,7 +2751,7 @@ std::optional mlir::arith::getNeutralElement(Operation *op) { .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; }) .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; }) .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; }) - .Default([](Operation *op) { return std::nullopt; }); + .Default(std::nullopt); if (!maybeKind) { return std::nullopt; } diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp index d2c2138d61638..025d1acf8d6ba 100644 --- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp @@ -330,7 +330,7 @@ static Value getBase(Value v) { v = op.getSrc(); return true; }) - .Default([](Operation *) { return false; }); + .Default(false); if (!shouldContinue) break; } @@ -354,7 +354,7 @@ static Value propagatesCapture(Operation *op) { .Case([](memref::TransposeOp transpose) { return transpose.getIn(); }) .Case( [](auto op) { return op.getSrc(); }) - .Default([](Operation *) { return Value(); }); + .Default(nullptr); } /// Returns `true` if the given operation is known to capture the given value, @@ -371,7 +371,7 @@ static std::optional getKnownCapturingStatus(Operation *op, Value v) { // These operations are known not to capture. .Case([](memref::DeallocOp) { return false; }) // By default, we don't know anything. - .Default([](Operation *) { return std::nullopt; }); + .Default(std::nullopt); } /// Returns `true` if the value may be captured by any of its users, i.e., if diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 3eae67f4c1f98..2731069d6ef54 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -698,7 +698,7 @@ static void destructureIndices(Type currType, ArrayRef indices, return structType.getBody()[memberIndex]; return nullptr; }) - .Default(Type(nullptr)); + .Default(nullptr); } } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index cee943d2d86c6..7d9058c262562 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -1111,7 +1111,7 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot, .Case([](auto type) { return type.getWidth() % 8 == 0 && type.getWidth() > 0; }) - .Default([](Type) { return false; }); + .Default(false); if (!canConvertType) return false; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index ac35eea66e9d6..ce93d18f56d39 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -798,7 +798,7 @@ static bool isCompatibleImpl(Type type, DenseSet &compatibleTypes) { // clang-format on .Case( [](Type type) { return isCompatiblePtrType(type); }) - .Default([](Type) { return false; }); + .Default(false); if (!result) compatibleTypes.erase(type); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 8b89244486339..b09112bcf0bb7 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -4499,7 +4499,7 @@ DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne( maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op); return true; }) - .Default([&](Operation *op) { return false; }); + .Default(false); if (!supported) { DiagnosedSilenceableFailure diag = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index f05ffa8334d9c..6519c4f64dd05 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -322,7 +322,7 @@ promoteSubViews(ImplicitLocOpBuilder &b, tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0)); return complex::CreateOp::create(b, t, tmp, tmp); }) - .Default([](auto) { return Value(); }); + .Default(nullptr); if (!fillVal) return failure(); linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView); diff --git a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp index 27ccf3c2ba148..6becc1f29afbd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp @@ -89,7 +89,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, ValueRange{input, collapsedKernel, iZp, kZp}, ValueRange{collapsedInit}, stride, dilation); }) - .Default([](Operation *op) { return nullptr; }); + .Default(nullptr); if (!newConv) return failure(); for (auto attr : preservedAttrs) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 0f317eac8fa41..cb6199f026e03 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -656,7 +656,7 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) { [&](auto op) { return CombiningKind::MUL; }) .Case([&](auto op) { return CombiningKind::OR; }) .Case([&](auto op) { return CombiningKind::XOR; }) - .Default([&](auto op) { return std::nullopt; }); + .Default(std::nullopt); } /// Check whether `outputOperand` is a reduction with a single combiner diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 1208fddf37e0b..e6850890bf8fe 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -104,7 +104,7 @@ static Value getTargetMemref(Operation *op) { vector::MaskedStoreOp, vector::TransferReadOp, vector::TransferWriteOp>( [](auto op) { return op.getBase(); }) - .Default([](auto) { return Value{}; }); + .Default(nullptr); } template diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index 4ebd90dbcc1d5..d380c46f7fbee 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -55,7 +55,7 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) { ? forOp.getInitArgs()[opResult.getResultNumber()] : Value(); }) - .Default([&](auto op) { return Value(); }); + .Default(nullptr); } return false; } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 0c8114d5e957e..938952ed273cd 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -346,7 +346,7 @@ LogicalResult spirv::CompositeConstructOp::verify() { llvm::TypeSwitch(getType()) .Case( [](auto coopType) { return coopType.getElementType(); }) - .Default([](Type) { return nullptr; }); + .Default(nullptr); // Case 1. -- matrices. if (coopElementType) { @@ -1708,7 +1708,7 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() { llvm::TypeSwitch(getMatrix().getType()) .Case( [](auto matrixType) { return matrixType.getElementType(); }) - .Default([](Type) { return nullptr; }); + .Default(nullptr); assert(elementType && "Unhandled type"); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index f895807ea1d18..d1e275d590f78 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -731,7 +731,7 @@ std::optional SPIRVType::getSizeInBytes() { return *elementSize * type.getNumElements(); return std::nullopt; }) - .Default(std::optional()); + .Default(std::nullopt); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 88e1ab6ab1e4d..cb9b7f6ec2fd2 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1467,7 +1467,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) { return TypeSwitch>>(op) .Case( [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); }) - .Default([](Operation *) { return std::nullopt; }); + .Default(std::nullopt); } LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) { diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp index 69e649d2eebe8..bc4f5a5ac7f23 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp @@ -189,7 +189,7 @@ struct PadOpToConstant final : public OpRewritePattern { return constantFoldPadOp( rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad); }) - .Default(Value()); + .Default(nullptr); if (!newOp) return rewriter.notifyMatchFailure(padTensorOp, diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp index b31377e0de3e9..0f1bf83d1987b 100644 --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -56,7 +56,7 @@ std::optional TypeConstraint::getBuilderCall() const { StringRef value = init->getValue(); return value.empty() ? std::optional() : value; }) - .Default([](auto *) { return std::nullopt; }); + .Default(std::nullopt); } // Return the C++ type for this type (which may just be ::mlir::Type). diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp index eeb87253e5eb8..e3bcf2749be13 100644 --- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp @@ -390,7 +390,7 @@ llvm::DISubrange *DebugTranslation::translateImpl(DISubrangeAttr attr) { .Case<>([&](LLVM::DIGlobalVariableAttr global) { return translate(global); }) - .Default([&](Attribute attr) { return nullptr; }); + .Default(nullptr); return metadata; }; return llvm::DISubrange::get(llvmCtx, getMetadataOrNull(attr.getCount()), @@ -420,10 +420,10 @@ DebugTranslation::translateImpl(DIGenericSubrangeAttr attr) { .Case([&](LLVM::DILocalVariableAttr local) { return translate(local); }) - .Case<>([&](LLVM::DIGlobalVariableAttr global) { + .Case([&](LLVM::DIGlobalVariableAttr global) { return translate(global); }) - .Default([&](Attribute attr) { return nullptr; }); + .Default(nullptr); return metadata; }; return llvm::DIGenericSubrange::get(llvmCtx,