diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td index 79da81ba049dd..a441fd82546e3 100644 --- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td @@ -153,17 +153,25 @@ def CondBranchOp let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest, "ValueRange":$trueOperands, "Block *":$falseDest, - "ValueRange":$falseOperands), + "ValueRange":$falseOperands, + CArg<"ArrayRef", "{}">:$branchWeights), [{ - build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest, - falseDest); + DenseI32ArrayAttr weights; + if (!branchWeights.empty()) + weights = $_builder.getDenseI32ArrayAttr(branchWeights); + build($_builder, $_state, condition, trueOperands, falseOperands, + weights, trueDest, falseDest); }]>, OpBuilder<(ins "Value":$condition, "Block *":$trueDest, "Block *":$falseDest, - CArg<"ValueRange", "{}">:$falseOperands), + CArg<"ValueRange", "{}">:$falseOperands, + CArg<"ArrayRef", "{}">:$branchWeights), [{ - build($_builder, $_state, condition, trueDest, ValueRange(), falseDest, - falseOperands); + DenseI32ArrayAttr weights; + if (!branchWeights.empty()) + weights = $_builder.getDenseI32ArrayAttr(branchWeights); + build($_builder, $_state, condition, ValueRange(), falseOperands, + weights, trueDest, falseDest); }]>]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp index edd7f607f24f4..0c11c76cf1f71 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -265,9 +265,9 @@ struct SimplifyPassThroughCondBranch : public OpRewritePattern { return failure(); // Create a new branch with the collapsed successors. - rewriter.replaceOpWithNewOp(condbr, condbr.getCondition(), - trueDest, trueDestOperands, - falseDest, falseDestOperands); + rewriter.replaceOpWithNewOp( + condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest, + falseDestOperands, condbr.getWeights()); return success(); } }; diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir index 0ad6898fce86c..bf69935a00bf0 100644 --- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir +++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir @@ -102,6 +102,31 @@ func.func @cond_br_and_br_folding(%a : i32) { /// Test that pass-through successors of CondBranchOp get folded. +// Test that the weights are preserved: +// CHECK-LABEL: func.func @cond_br_passthrough_weights( +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i1) -> i32 { +func.func @cond_br_passthrough_weights(%arg0 : i32, %arg1 : i32, %cond : i1) -> i32 { +// CHECK: cf.cond_br %[[ARG2]] weights([30, 70]), ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: return %[[ARG0]] : i32 +// CHECK: ^bb2: +// CHECK: return %[[ARG1]] : i32 +// CHECK: } + cf.cond_br %cond weights([30,70]), ^bb1, ^bb3 + +^bb1: + cf.br ^bb2 + +^bb3: + cf.br ^bb4 + +^bb2: + return %arg0 : i32 + +^bb4: + return %arg1 : i32 +} + // CHECK-LABEL: func @cond_br_passthrough( // CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1 func.func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {