diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 90e05ce3d5ca6..27a6ca4ebdb4e 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2323,9 +2323,13 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments, }]; } -def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods, RecursiveMemoryEffects, - NoRegionArguments]> { +def fir_IfOp + : region_Op< + "if", [DeclareOpInterfaceMethods< + RegionBranchOpInterface, ["getRegionInvocationBounds", + "getEntrySuccessorRegions"]>, + RecursiveMemoryEffects, NoRegionArguments, + WeightedRegionBranchOpInterface]> { let summary = "if-then-else conditional operation"; let description = [{ Used to conditionally execute operations. This operation is the FIR @@ -2342,7 +2346,8 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods:$region_weights); let results = (outs Variadic:$results); let regions = (region @@ -2371,6 +2376,11 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods &results, unsigned resultNum); + + /// Returns the display name string for the region_weights attribute. + static constexpr llvm::StringRef getWeightsAttrAssemblyName() { + return "weights"; + } }]; } diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 6181e1fad4240..ecfa2939e96a6 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -4418,6 +4418,19 @@ mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser, parser.resolveOperand(cond, i1Type, result.operands)) return mlir::failure(); + if (mlir::succeeded( + parser.parseOptionalKeyword(getWeightsAttrAssemblyName()))) { + if (parser.parseLParen()) + return mlir::failure(); + mlir::DenseI32ArrayAttr weights; + if (parser.parseCustomAttributeWithFallback(weights, mlir::Type{})) + return mlir::failure(); + if (weights) + result.addAttribute(getRegionWeightsAttrName(result.name), weights); + if (parser.parseRParen()) + return mlir::failure(); + } + if (parser.parseOptionalArrowTypeList(result.types)) return mlir::failure(); @@ -4449,6 +4462,11 @@ llvm::LogicalResult fir::IfOp::verify() { void fir::IfOp::print(mlir::OpAsmPrinter &p) { bool printBlockTerminators = false; p << ' ' << getCondition(); + if (auto weights = getRegionWeightsAttr()) { + p << ' ' << getWeightsAttrAssemblyName() << '('; + p.printStrippedAttrOrType(weights); + p << ')'; + } if (!getResults().empty()) { p << " -> (" << getResultTypes() << ')'; printBlockTerminators = true; @@ -4464,7 +4482,8 @@ void fir::IfOp::print(mlir::OpAsmPrinter &p) { p.printRegion(otherReg, /*printEntryBlockArgs=*/false, printBlockTerminators); } - p.printOptionalAttrDict((*this)->getAttrs()); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elideAttrs=*/{getRegionWeightsAttrName()}); } void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl &results, diff --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp index 8a9e9b80134b8..3d35803e6a2d3 100644 --- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp +++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp @@ -212,9 +212,12 @@ class CfgIfConv : public mlir::OpRewritePattern { } rewriter.setInsertionPointToEnd(condBlock); - rewriter.create( + auto branchOp = rewriter.create( loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef(), otherwiseBlock, llvm::ArrayRef()); + llvm::ArrayRef weights = ifOp.getWeights(); + if (!weights.empty()) + branchOp.setWeights(weights); rewriter.replaceOp(ifOp, continueBlock->getArguments()); return success(); } diff --git a/flang/test/Fir/cfg-conversion-if.fir b/flang/test/Fir/cfg-conversion-if.fir new file mode 100644 index 0000000000000..1e30ee8e64f02 --- /dev/null +++ b/flang/test/Fir/cfg-conversion-if.fir @@ -0,0 +1,46 @@ +// RUN: fir-opt --split-input-file --cfg-conversion %s | FileCheck %s + +func.func private @callee() -> none + +// CHECK-LABEL: func.func @if_then( +// CHECK-SAME: %[[ARG0:.*]]: i1) { +// CHECK: cf.cond_br %[[ARG0]] weights([10, 90]), ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: %[[VAL_0:.*]] = fir.call @callee() : () -> none +// CHECK: cf.br ^bb2 +// CHECK: ^bb2: +// CHECK: return +// CHECK: } +func.func @if_then(%cond: i1) { + fir.if %cond weights([10, 90]) { + fir.call @callee() : () -> none + } + return +} + +// ----- + +// CHECK-LABEL: func.func @if_then_else( +// CHECK-SAME: %[[ARG0:.*]]: i1) -> i32 { +// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32 +// CHECK: cf.cond_br %[[ARG0]] weights([90, 10]), ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: cf.br ^bb3(%[[VAL_0]] : i32) +// CHECK: ^bb2: +// CHECK: cf.br ^bb3(%[[VAL_1]] : i32) +// CHECK: ^bb3(%[[VAL_2:.*]]: i32): +// CHECK: cf.br ^bb4 +// CHECK: ^bb4: +// CHECK: return %[[VAL_2]] : i32 +// CHECK: } +func.func @if_then_else(%cond: i1) -> i32 { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %result = fir.if %cond weights([90, 10]) -> i32 { + fir.result %c0 : i32 + } else { + fir.result %c1 : i32 + } + return %result : i32 +} diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir index 9c444d2f4e0bc..3585bf9efca3e 100644 --- a/flang/test/Fir/fir-ops.fir +++ b/flang/test/Fir/fir-ops.fir @@ -1015,3 +1015,19 @@ func.func @test_box_total_elements(%arg0: !fir.class> %6 = arith.addi %2, %5 : index return %6 : index } + +// CHECK-LABEL: func.func @test_if_weights( +// CHECK-SAME: %[[ARG0:.*]]: i1) { +func.func @test_if_weights(%cond: i1) { +// CHECK: fir.if %[[ARG0]] weights([99, 1]) { +// CHECK: } + fir.if %cond weights([99, 1]) { + } +// CHECK: fir.if %[[ARG0]] weights([99, 1]) { +// CHECK: } else { +// CHECK: } + fir.if %cond weights ([99,1]) { + } else { + } + return +} diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir index 45cae1f82cb8e..aca0ecc1abdc1 100644 --- a/flang/test/Fir/invalid.fir +++ b/flang/test/Fir/invalid.fir @@ -1393,3 +1393,31 @@ fir.local {type = local_init} @x.localizer : f32 init { ^bb0(%arg0: f32, %arg1: f32): fir.yield(%arg0 : f32) } + +// ----- + +func.func @wrong_weights_number_in_if_then(%cond: i1) { +// expected-error @below {{expects number of region weights to match number of regions: 1 vs 2}} + fir.if %cond weights([50]) { + } + return +} + +// ----- + +func.func @wrong_weights_number_in_if_then_else(%cond: i1) { +// expected-error @below {{expects number of region weights to match number of regions: 3 vs 2}} + fir.if %cond weights([50, 40, 10]) { + } else { + } + return +} + +// ----- + +func.func @negative_weight_in_if_then(%cond: i1) { +// expected-error @below {{weight #0 must be non-negative}} + fir.if %cond weights([-1, 101]) { + } + return +} diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td index 48f12b46a57f1..79da81ba049dd 100644 --- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td @@ -112,10 +112,11 @@ def BranchOp : CF_Op<"br", [ // CondBranchOp //===----------------------------------------------------------------------===// -def CondBranchOp : CF_Op<"cond_br", - [AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - Pure, Terminator]> { +def CondBranchOp + : CF_Op<"cond_br", [AttrSizedOperandSegments, + DeclareOpInterfaceMethods< + BranchOpInterface, ["getSuccessorForOperands"]>, + WeightedBranchOpInterface, Pure, Terminator]> { let summary = "Conditional branch operation"; let description = [{ The `cf.cond_br` terminator operation represents a conditional branch on a @@ -144,20 +145,23 @@ def CondBranchOp : CF_Op<"cond_br", ``` }]; - let arguments = (ins I1:$condition, - Variadic:$trueDestOperands, - Variadic:$falseDestOperands); + let arguments = (ins I1:$condition, Variadic:$trueDestOperands, + Variadic:$falseDestOperands, + OptionalAttr:$branch_weights); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); - let builders = [ - OpBuilder<(ins "Value":$condition, "Block *":$trueDest, - "ValueRange":$trueOperands, "Block *":$falseDest, - "ValueRange":$falseOperands), [{ - build($_builder, $_state, condition, trueOperands, falseOperands, trueDest, + let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest, + "ValueRange":$trueOperands, + "Block *":$falseDest, + "ValueRange":$falseOperands), + [{ + build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest, falseDest); }]>, - OpBuilder<(ins "Value":$condition, "Block *":$trueDest, - "Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{ + OpBuilder<(ins "Value":$condition, "Block *":$trueDest, + "Block *":$falseDest, + CArg<"ValueRange", "{}">:$falseOperands), + [{ build($_builder, $_state, condition, trueDest, ValueRange(), falseDest, falseOperands); }]>]; @@ -216,7 +220,7 @@ def CondBranchOp : CF_Op<"cond_br", let hasCanonicalizer = 1; let assemblyFormat = [{ - $condition `,` + $condition (`weights` `(` $branch_weights^ `)` )? `,` $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,` $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)? attr-dict diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td index 2824f09dab6ce..138170f8c8762 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -168,42 +168,6 @@ def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> { ]; } -def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> { - let description = [{ - An interface for operations that can carry branch weights metadata. It - provides setters and getters for the operation's branch weights attribute. - The default implementation of the interface methods expect the operation to - have an attribute of type DenseI32ArrayAttr named branch_weights. - }]; - - let cppNamespace = "::mlir::LLVM"; - - let methods = [ - InterfaceMethod< - /*desc=*/ "Returns the branch weights attribute or nullptr", - /*returnType=*/ "::mlir::DenseI32ArrayAttr", - /*methodName=*/ "getBranchWeightsOrNull", - /*args=*/ (ins), - /*methodBody=*/ [{}], - /*defaultImpl=*/ [{ - auto op = cast(this->getOperation()); - return op.getBranchWeightsAttr(); - }] - >, - InterfaceMethod< - /*desc=*/ "Sets the branch weights attribute", - /*returnType=*/ "void", - /*methodName=*/ "setBranchWeights", - /*args=*/ (ins "::mlir::DenseI32ArrayAttr":$attr), - /*methodBody=*/ [{}], - /*defaultImpl=*/ [{ - auto op = cast(this->getOperation()); - op.setBranchWeightsAttr(attr); - }] - > - ]; -} - def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> { let description = [{ An interface for memory operations that can carry access groups metadata. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 68fa620d239b9..939e7a09a73ad 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -660,12 +660,12 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc", LLVM_ScalarOrVectorOf>; // Call-related operations. -def LLVM_InvokeOp : LLVM_Op<"invoke", [ - AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - Terminator]> { +def LLVM_InvokeOp + : LLVM_Op<"invoke", [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + Terminator]> { let arguments = (ins OptionalAttr>:$var_callee_type, OptionalAttr:$callee, @@ -734,12 +734,12 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> { // CallOp //===----------------------------------------------------------------------===// -def LLVM_CallOp : LLVM_MemAccessOpBase<"call", - [AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { +def LLVM_CallOp + : LLVM_MemAccessOpBase< + "call", [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Call to an LLVM function."; let description = [{ In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect @@ -788,21 +788,16 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", OptionalAttr:$callee, Variadic:$callee_operands, DefaultValuedAttr:$fastmathFlags, - OptionalAttr:$branch_weights, DefaultValuedAttr:$CConv, DefaultValuedAttr:$TailCallKind, OptionalAttr:$memory_effects, - UnitAttr:$convergent, - UnitAttr:$no_unwind, - UnitAttr:$will_return, + UnitAttr:$convergent, UnitAttr:$no_unwind, UnitAttr:$will_return, VariadicOfVariadic:$op_bundle_operands, DenseI32ArrayAttr:$op_bundle_sizes, OptionalAttr:$op_bundle_tags, OptionalAttr:$arg_attrs, - OptionalAttr:$res_attrs, - UnitAttr:$no_inline, - UnitAttr:$always_inline, - UnitAttr:$inline_hint); + OptionalAttr:$res_attrs, UnitAttr:$no_inline, + UnitAttr:$always_inline, UnitAttr:$inline_hint); // Append the aliasing related attributes defined in LLVM_MemAccessOpBase. let arguments = !con(args, aliasAttrs); let results = (outs Optional:$result); @@ -1047,11 +1042,12 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br", LLVM_TerminatorPassthroughOpBuilder ]; } -def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", - [AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - Pure]> { +def LLVM_CondBrOp + : LLVM_TerminatorOp< + "cond_br", [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + Pure]> { let arguments = (ins I1:$condition, Variadic:$trueDestOperands, Variadic:$falseDestOperands, @@ -1136,11 +1132,12 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> { }]; } -def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", - [AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - Pure]> { +def LLVM_SwitchOp + : LLVM_TerminatorOp< + "switch", [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + Pure]> { let arguments = (ins AnySignlessInteger:$value, Variadic:$defaultOperands, diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index 7f6967f11444f..d63800c12d132 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -142,6 +142,26 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands); } // namespace detail +//===----------------------------------------------------------------------===// +// WeightedBranchOpInterface +//===----------------------------------------------------------------------===// + +namespace detail { +/// Verify that the branch weights attached to an operation +/// implementing WeightedBranchOpInterface are correct. +LogicalResult verifyBranchWeights(Operation *op); +} // namespace detail + +//===----------------------------------------------------------------------===// +// WeightedRegiobBranchOpInterface +//===----------------------------------------------------------------------===// + +namespace detail { +/// Verify that the region weights attached to an operation +/// implementing WeightedRegiobBranchOpInterface are correct. +LogicalResult verifyRegionBranchWeights(Operation *op); +} // namespace detail + //===----------------------------------------------------------------------===// // RegionBranchOpInterface //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index 69bce78e946c8..46ab0b9ebbc6b 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -375,6 +375,118 @@ def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> { ]; } +//===----------------------------------------------------------------------===// +// WeightedBranchOpInterface +//===----------------------------------------------------------------------===// + +def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> { + let description = [{ + This interface provides weight information for branching terminator + operations, i.e. terminator operations with successors. + + This interface provides methods for getting/setting integer non-negative + weight of each branch. The probability of executing a branch + is computed as the ratio between the branch's weight and the total + sum of the weights (which cannot be zero). + The weights are optional. If they are provided, then their number + must match the number of successors of the operation. + + The default implementations of the methods expect the operation + to have an attribute of type DenseI32ArrayAttr named branch_weights. + }]; + let cppNamespace = "::mlir"; + + let methods = [InterfaceMethod< + /*desc=*/"Returns the branch weights", + /*returnType=*/"::llvm::ArrayRef", + /*methodName=*/"getWeights", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImpl=*/[{ + auto op = cast(this->getOperation()); + if (auto attr = op.getBranchWeightsAttr()) + return attr.asArrayRef(); + return {}; + }]>, + InterfaceMethod< + /*desc=*/"Sets the branch weights", + /*returnType=*/"void", + /*methodName=*/"setWeights", + /*args=*/(ins "::llvm::ArrayRef":$weights), + /*methodBody=*/[{}], + /*defaultImpl=*/[{ + auto op = cast(this->getOperation()); + op.setBranchWeightsAttr(::mlir::DenseI32ArrayAttr::get(op->getContext(), weights)); + }]>, + ]; + + let verify = [{ + return ::mlir::detail::verifyBranchWeights($_op); + }]; +} + +//===----------------------------------------------------------------------===// +// WeightedRegionBranchOpInterface +//===----------------------------------------------------------------------===// + +// TODO: the probabilities of entering a particular region seem +// to correlate with the values returned by +// RegionBranchOpInterface::invocationBounds(), and we should probably +// verify that the values are consistent. In that case, should +// WeightedRegionBranchOpInterface extend RegionBranchOpInterface? +def WeightedRegionBranchOpInterface + : OpInterface<"WeightedRegionBranchOpInterface"> { + let description = [{ + This interface provides weight information for region operations + that exhibit branching behavior between held regions. + + This interface provides methods for getting/setting integer non-negative + weight of each branch. The probability of executing a region is computed + as the ratio between the region branch's weight and the total sum + of the weights (which cannot be zero). + The weights are optional. If they are provided, then their number + must match the number of regions held by the operation + (including empty regions). + + The weights specify the probability of branching to a particular + region when first executing the operation. + For example, for loop-like operations with a single region + the weight specifies the probability of entering the loop. + + The default implementations of the methods expect the operation + to have an attribute of type DenseI32ArrayAttr named branch_weights. + }]; + let cppNamespace = "::mlir"; + + let methods = [InterfaceMethod< + /*desc=*/"Returns the region weights", + /*returnType=*/"::llvm::ArrayRef", + /*methodName=*/"getWeights", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImpl=*/[{ + auto op = cast(this->getOperation()); + if (auto attr = op.getRegionWeightsAttr()) + return attr.asArrayRef(); + return {}; + }]>, + InterfaceMethod< + /*desc=*/"Sets the region weights", + /*returnType=*/"void", + /*methodName=*/"setWeights", + /*args=*/(ins "::llvm::ArrayRef":$weights), + /*methodBody=*/[{}], + /*defaultImpl=*/[{ + auto op = cast(this->getOperation()); + op.setRegionWeightsAttr(::mlir::DenseI32ArrayAttr::get(op->getContext(), weights)); + }]>, + ]; + + let verify = [{ + return ::mlir::detail::verifyRegionBranchWeights($_op); + }]; +} + //===----------------------------------------------------------------------===// // ControlFlow Traits //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 97ae14aa0d6af..0f136c5c46d79 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -189,7 +189,7 @@ class ModuleTranslation { llvm::Instruction *inst); /// Sets LLVM profiling metadata for operations that have branch weights. - void setBranchWeightsMetadata(BranchWeightOpInterface op); + void setBranchWeightsMetadata(WeightedBranchOpInterface op); /// Sets LLVM loop metadata for branch operations that have a loop annotation /// attribute. diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index debfd003bd5b5..d31d7d801e149 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -166,10 +166,15 @@ struct CondBranchOpLowering : public ConvertOpToLLVMPattern { TypeRange(adaptor.getFalseDestOperands())); if (failed(convertedFalseBlock)) return failure(); - Operation *newOp = rewriter.replaceOpWithNewOp( + auto newOp = rewriter.replaceOpWithNewOp( op, adaptor.getCondition(), *convertedTrueBlock, adaptor.getTrueDestOperands(), *convertedFalseBlock, adaptor.getFalseDestOperands()); + ArrayRef weights = op.getWeights(); + if (!weights.empty()) { + newOp.setWeights(weights); + op.removeBranchWeightsAttr(); + } // TODO: We should not just forward all attributes like that. But there are // existing Flang tests that depend on this behavior. newOp->setAttrs(op->getAttrDictionary()); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index c7528c970a4ba..a12aef0dfad38 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -589,10 +589,6 @@ LogicalResult SwitchOp::verify() { static_cast(getCaseDestinations().size()))) return emitOpError("expects number of case values to match number of " "case destinations"); - if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) - return emitError("expects number of branch weights to match number of " - "successors: ") - << getBranchWeights()->size() << " vs " << getNumSuccessors(); if (getCaseValues() && getValue().getType() != getCaseValues()->getElementType()) return emitError("expects case value type to match condition value type"); @@ -962,7 +958,6 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, assert(callee && "expected non-null callee in direct call builder"); build(builder, state, results, /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr, - /*branch_weights=*/nullptr, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, @@ -992,7 +987,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, build(builder, state, getCallOpResultTypes(calleeType), getCallOpVarCalleeType(calleeType), callee, args, /*fastmathFlags=*/nullptr, - /*branch_weights=*/nullptr, /*CConv=*/nullptr, + /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, @@ -1009,7 +1004,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, build(builder, state, getCallOpResultTypes(calleeType), getCallOpVarCalleeType(calleeType), /*callee=*/nullptr, args, - /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, + /*fastmathFlags=*/nullptr, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, @@ -1025,7 +1020,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, auto calleeType = func.getFunctionType(); build(builder, state, getCallOpResultTypes(calleeType), getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args, - /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, + /*fastmathFlags=*/nullptr, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index 2ae334b517a31..3a63db35eec0f 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -9,6 +9,7 @@ #include #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/SmallPtrSet.h" @@ -80,6 +81,51 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, return success(); } +//===----------------------------------------------------------------------===// +// WeightedBranchOpInterface +//===----------------------------------------------------------------------===// + +static LogicalResult verifyWeights(Operation *op, + llvm::ArrayRef weights, + std::size_t expectedWeightsNum, + llvm::StringRef weightAnchorName, + llvm::StringRef weightRefName) { + if (weights.empty()) + return success(); + + if (weights.size() != expectedWeightsNum) + return op->emitError() << "expects number of " << weightAnchorName + << " weights to match number of " << weightRefName + << ": " << weights.size() << " vs " + << expectedWeightsNum; + + for (auto [index, weight] : llvm::enumerate(weights)) + if (weight < 0) + return op->emitError() << "weight #" << index << " must be non-negative"; + + if (llvm::all_of(weights, [](int32_t value) { return value == 0; })) + return op->emitError() << "branch weights cannot all be zero"; + + return success(); +} + +LogicalResult detail::verifyBranchWeights(Operation *op) { + llvm::ArrayRef weights = + cast(op).getWeights(); + return verifyWeights(op, weights, op->getNumSuccessors(), "branch", + "successors"); +} + +//===----------------------------------------------------------------------===// +// WeightedRegionBranchOpInterface +//===----------------------------------------------------------------------===// + +LogicalResult detail::verifyRegionBranchWeights(Operation *op) { + llvm::ArrayRef weights = + cast(op).getWeights(); + return verifyWeights(op, weights, op->getNumRegions(), "region", "regions"); +} + //===----------------------------------------------------------------------===// // RegionBranchOpInterface //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp index 1b5ce868b5c77..e67aa892afe09 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -146,8 +146,15 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node, branchWeights.push_back(branchWeight->getZExtValue()); } - if (auto iface = dyn_cast(op)) { - iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights)); + if (auto iface = dyn_cast(op)) { + // LLVM allows attaching a single weight to call instructions. + // This is used for carrying the execution count information + // in PGO modes. MLIR WeightedBranchOpInterface does not allow this, + // so we drop the metadata in this case. + // LLVM should probably use the VP form of MD_prof metadata + // for such cases. + if (op->getNumSuccessors() != 0) + iface.setWeights(branchWeights); return success(); } return failure(); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index e5ca147ea98f8..3eaa24eb5c95b 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1055,7 +1055,7 @@ LogicalResult ModuleTranslation::convertBlockImpl(Block &bb, return failure(); // Set the branch weight metadata on the translated instruction. - if (auto iface = dyn_cast(op)) + if (auto iface = dyn_cast(op)) setBranchWeightsMetadata(iface); } @@ -2026,14 +2026,15 @@ void ModuleTranslation::setDereferenceableMetadata( inst->setMetadata(kindId, derefSizeNode); } -void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) { - DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull(); - if (!weightsAttr) +void ModuleTranslation::setBranchWeightsMetadata(WeightedBranchOpInterface op) { + SmallVector weights; + llvm::transform(op.getWeights(), std::back_inserter(weights), + [](int32_t value) { return static_cast(value); }); + if (weights.empty()) return; llvm::Instruction *inst = isa(op) ? lookupCall(op) : lookupBranch(op); assert(inst && "expected the operation to have a mapping to an instruction"); - SmallVector weights(weightsAttr.asArrayRef()); inst->setMetadata( llvm::LLVMContext::MD_prof, llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights)); diff --git a/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir b/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir index 9a0f2b7714544..7c78211d59010 100644 --- a/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir +++ b/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir @@ -67,3 +67,17 @@ func.func @unreachable_block() { ^bb1(%arg0: index): cf.br ^bb1(%arg0 : index) } + +// ----- + +// Test case for cf.cond_br with weights. + +// CHECK-LABEL: func.func @cf_cond_br_with_weights( +func.func @cf_cond_br_with_weights(%cond: i1, %a: index, %b: index) -> index { +// CHECK: llvm.cond_br %{{.*}} weights([90, 10]), ^bb1(%{{.*}} : i64), ^bb2(%{{.*}} : i64) + cf.cond_br %cond, ^bb1(%a : index), ^bb2(%b : index) {branch_weights = array} +^bb1(%arg1: index): + return %arg1 : index +^bb2(%arg2: index): + return %arg2 : index +} diff --git a/mlir/test/Dialect/ControlFlow/invalid.mlir b/mlir/test/Dialect/ControlFlow/invalid.mlir index b51d8095c9974..1b8de22a9ff9f 100644 --- a/mlir/test/Dialect/ControlFlow/invalid.mlir +++ b/mlir/test/Dialect/ControlFlow/invalid.mlir @@ -67,3 +67,39 @@ func.func @switch_missing_default(%flag : i32, %caseOperand : i32) { ^bb3(%bb3arg : i32): return } + +// ----- + +// CHECK-LABEL: func @wrong_weights_number +func.func @wrong_weights_number(%cond: i1) { + // expected-error@+1 {{expects number of branch weights to match number of successors: 1 vs 2}} + cf.cond_br %cond weights([100]), ^bb1, ^bb2 + ^bb1: + return + ^bb2: + return +} + +// ----- + +// CHECK-LABEL: func @negative_weight +func.func @wrong_total_weight(%cond: i1) { + // expected-error@+1 {{weight #0 must be non-negative}} + cf.cond_br %cond weights([-1, 101]), ^bb1, ^bb2 + ^bb1: + return + ^bb2: + return +} + +// ----- + +// CHECK-LABEL: func @zero_weights +func.func @wrong_total_weight(%cond: i1) { + // expected-error@+1 {{branch weights cannot all be zero}} + cf.cond_br %cond weights([0, 0]), ^bb1, ^bb2 + ^bb1: + return + ^bb2: + return +} diff --git a/mlir/test/Dialect/ControlFlow/ops.mlir b/mlir/test/Dialect/ControlFlow/ops.mlir index c9317c7613972..160534240e0fa 100644 --- a/mlir/test/Dialect/ControlFlow/ops.mlir +++ b/mlir/test/Dialect/ControlFlow/ops.mlir @@ -51,3 +51,13 @@ func.func @switch_result_number(%arg0: i32) { ^bb2: return } + +// CHECK-LABEL: func @cond_weights +func.func @cond_weights(%cond: i1) { +// CHECK: cf.cond_br %{{.*}} weights([60, 40]), ^{{.*}}, ^{{.*}} + cf.cond_br %cond weights([60, 40]), ^bb1, ^bb2 + ^bb1: + return + ^bb2: + return +} diff --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll index cc3b47a54dfe9..c623df0b605b2 100644 --- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll +++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll @@ -36,14 +36,17 @@ bbd: ; // ----- +; Verify that a single weight attached to a call is not translated. +; The MLIR WeightedBranchOpInterface does not support this case. + ; CHECK: llvm.func @fn() -declare void @fn() +declare i32 @fn() ; CHECK-LABEL: @call_branch_weights -define void @call_branch_weights() { - ; CHECK: llvm.call @fn() {branch_weights = array} - call void @fn(), !prof !0 - ret void +define i32 @call_branch_weights() { + ; CHECK: llvm.call @fn() : () -> i32 + %1 = call i32 @fn(), !prof !0 + ret i32 %1 } !0 = !{!"branch_weights", i32 42} diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir index 24a7b42557278..a8ef401fff27e 100644 --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -448,3 +448,19 @@ llvm.mlir.global external constant @const() {addr_space = 0 : i32, dso_local} : } llvm.func extern_weak @extern_func() + +// ----- + +llvm.func @invoke_branch_weights_callee() +llvm.func @__gxx_personality_v0(...) -> i32 + +llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} { + %0 = llvm.mlir.constant(1 : i32) : i32 + // expected-error @below{{expects number of branch weights to match number of successors: 1 vs 2}} + llvm.invoke @invoke_branch_weights_callee() to ^bb2 unwind ^bb1 {branch_weights = array} : () -> () +^bb1: // pred: ^bb0 + %1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)> + llvm.br ^bb2 +^bb2: // 2 preds: ^bb0, ^bb1 + llvm.return %0 : i32 +} diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 7742259e7a478..fc1993b50ba2d 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1906,32 +1906,6 @@ llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 { // ----- -llvm.func @fn() - -// CHECK-LABEL: @call_branch_weights -llvm.func @call_branch_weights() { - // CHECK: !prof ![[NODE:[0-9]+]] - llvm.call @fn() {branch_weights = array} : () -> () - llvm.return -} - -// CHECK: ![[NODE]] = !{!"branch_weights", i32 42} - -// ----- - -llvm.func @fn() -> i32 - -// CHECK-LABEL: @call_branch_weights -llvm.func @call_branch_weights() { - // CHECK: !prof ![[NODE:[0-9]+]] - %res = llvm.call @fn() {branch_weights = array} : () -> i32 - llvm.return -} - -// CHECK: ![[NODE]] = !{!"branch_weights", i32 42} - -// ----- - llvm.func @foo() llvm.func @__gxx_personality_v0(...) -> i32