Skip to content

Commit 6963e7d

Browse files
committed
Addressed review comments.
1 parent 4ab5cac commit 6963e7d

File tree

6 files changed

+47
-38
lines changed

6 files changed

+47
-38
lines changed

flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
215215
auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
216216
loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
217217
otherwiseBlock, llvm::ArrayRef<mlir::Value>());
218-
if (auto weights = ifOp.getRegionWeightsOrNull())
219-
branchOp.setBranchWeights(weights);
218+
if (auto weights = ifOp.getWeights())
219+
branchOp.setWeights(weights);
220220
rewriter.replaceOp(ifOp, continueBlock->getArguments());
221221
return success();
222222
}

mlir/include/mlir/Interfaces/ControlFlowInterfaces.td

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -388,34 +388,37 @@ def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
388388
weight of each branch. The probability of executing a branch
389389
is computed as the ratio between the branch's weight and the total
390390
sum of the weights.
391-
The number of weights must match the number of successors of the operation,
391+
The weights are optional. If they are provided, then their number
392+
must match the number of successors of the operation,
392393
with one exception for CallOpInterface operations, which may only
393-
have on weight when they do not have any successors.
394+
have one weight when they do not have any successors.
394395

395396
The default implementations of the methods expect the operation
396397
to have an attribute of type DenseI32ArrayAttr named branch_weights.
397398
}];
398399
let cppNamespace = "::mlir";
399400

400401
let methods = [InterfaceMethod<
401-
/*desc=*/"Returns the branch weights attribute or nullptr",
402-
/*returnType=*/"::mlir::DenseI32ArrayAttr",
403-
/*methodName=*/"getBranchWeightsOrNull",
402+
/*desc=*/"Returns the branch weights",
403+
/*returnType=*/"::llvm::ArrayRef<int32_t>",
404+
/*methodName=*/"getWeights",
404405
/*args=*/(ins),
405406
/*methodBody=*/[{}],
406407
/*defaultImpl=*/[{
407408
auto op = cast<ConcreteOp>(this->getOperation());
408-
return op.getBranchWeightsAttr();
409+
if (auto attr = op.getBranchWeightsAttr())
410+
return attr.asArrayRef();
411+
return {};
409412
}]>,
410413
InterfaceMethod<
411-
/*desc=*/"Sets the branch weights attribute",
414+
/*desc=*/"Sets the branch weights",
412415
/*returnType=*/"void",
413-
/*methodName=*/"setBranchWeights",
414-
/*args=*/(ins "::mlir::DenseI32ArrayAttr":$attr),
416+
/*methodName=*/"setWeights",
417+
/*args=*/(ins "::llvm::ArrayRef<int32_t>":$weights),
415418
/*methodBody=*/[{}],
416419
/*defaultImpl=*/[{
417420
auto op = cast<ConcreteOp>(this->getOperation());
418-
op.setBranchWeightsAttr(attr);
421+
op.setBranchWeightsAttr(::mlir::DenseI32ArrayAttr::get(op->getContext(), weights));
419422
}]>,
420423
];
421424

@@ -443,8 +446,9 @@ def WeightedRegionBranchOpInterface
443446
weight of each branch. The probability of executing a region is computed
444447
as the ratio between the region branch's weight and the total sum
445448
of the weights.
446-
The number of weights must match the number of regions
447-
held by the operation (including empty regions).
449+
The weights are optional. If they are provided, then their number
450+
must match the number of regions held by the operation
451+
(including empty regions).
448452

449453
The weights specify the probability of branching to a particular
450454
region when first executing the operation.
@@ -457,24 +461,26 @@ def WeightedRegionBranchOpInterface
457461
let cppNamespace = "::mlir";
458462

459463
let methods = [InterfaceMethod<
460-
/*desc=*/"Returns the region weights attribute or nullptr",
461-
/*returnType=*/"::mlir::DenseI32ArrayAttr",
462-
/*methodName=*/"getRegionWeightsOrNull",
464+
/*desc=*/"Returns the region weights",
465+
/*returnType=*/"::llvm::ArrayRef<int32_t>",
466+
/*methodName=*/"getWeights",
463467
/*args=*/(ins),
464468
/*methodBody=*/[{}],
465469
/*defaultImpl=*/[{
466470
auto op = cast<ConcreteOp>(this->getOperation());
467-
return op.getRegionWeightsAttr();
471+
if (auto attr = op.getRegionWeightsAttr())
472+
return attr.asArrayRef();
473+
return {};
468474
}]>,
469475
InterfaceMethod<
470-
/*desc=*/"Sets the region weights attribute",
476+
/*desc=*/"Sets the region weights",
471477
/*returnType=*/"void",
472-
/*methodName=*/"setRegionWeights",
473-
/*args=*/(ins "::mlir::DenseI32ArrayAttr":$attr),
478+
/*methodName=*/"setWeights",
479+
/*args=*/(ins "::llvm::ArrayRef<int32_t>":$weights),
474480
/*methodBody=*/[{}],
475481
/*defaultImpl=*/[{
476482
auto op = cast<ConcreteOp>(this->getOperation());
477-
op.setRegionWeightsAttr(attr);
483+
op.setRegionWeightsAttr(::mlir::DenseI32ArrayAttr::get(op->getContext(), weights));
478484
}]>,
479485
];
480486

mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,9 @@ struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
170170
op, adaptor.getCondition(), *convertedTrueBlock,
171171
adaptor.getTrueDestOperands(), *convertedFalseBlock,
172172
adaptor.getFalseDestOperands());
173-
if (auto weights = op.getBranchWeightsOrNull()) {
174-
newOp.setBranchWeights(weights);
173+
ArrayRef<int32_t> weights = op.getWeights();
174+
if (!weights.empty()) {
175+
newOp.setWeights(weights);
175176
op.removeBranchWeightsAttr();
176177
}
177178
// TODO: We should not just forward all attributes like that. But there are

mlir/lib/Interfaces/ControlFlowInterfaces.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,12 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
8585
// WeightedBranchOpInterface
8686
//===----------------------------------------------------------------------===//
8787

88-
static LogicalResult verifyWeights(Operation *op, DenseI32ArrayAttr weights,
89-
int64_t expectedWeightsNum,
88+
static LogicalResult verifyWeights(Operation *op,
89+
llvm::ArrayRef<int32_t> weights,
90+
std::size_t expectedWeightsNum,
9091
llvm::StringRef weightAnchorName,
9192
llvm::StringRef weightRefName) {
92-
if (!weights)
93+
if (weights.empty())
9394
return success();
9495

9596
if (weights.size() != expectedWeightsNum)
@@ -98,22 +99,22 @@ static LogicalResult verifyWeights(Operation *op, DenseI32ArrayAttr weights,
9899
<< ": " << weights.size() << " vs "
99100
<< expectedWeightsNum;
100101

101-
for (auto [index, weight] : llvm::enumerate(weights.asArrayRef()))
102+
for (auto [index, weight] : llvm::enumerate(weights))
102103
if (weight < 0)
103104
return op->emitError() << "weight #" << index << " must be non-negative";
104105

105106
return success();
106107
}
107108

108109
LogicalResult detail::verifyBranchWeights(Operation *op) {
109-
auto weights = cast<WeightedBranchOpInterface>(op).getBranchWeightsOrNull();
110+
llvm::ArrayRef<int32_t> weights =
111+
cast<WeightedBranchOpInterface>(op).getWeights();
110112
unsigned successorsNum = op->getNumSuccessors();
111113
// CallOpInterface operations without successors may only have
112114
// one weight, though it seems to be redundant and indicate
113115
// 100% probability of calling the callee(s).
114-
// TODO: maybe we should remove this interface for calls without
115-
// successors.
116-
int64_t weightsNum =
116+
// TODO: maybe we should disallow weights for calls without successors.
117+
std::size_t weightsNum =
117118
(successorsNum == 0 && isa<CallOpInterface>(op)) ? 1 : successorsNum;
118119
return verifyWeights(op, weights, weightsNum, "branch", "successors");
119120
}
@@ -123,8 +124,8 @@ LogicalResult detail::verifyBranchWeights(Operation *op) {
123124
//===----------------------------------------------------------------------===//
124125

125126
LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
126-
auto weights =
127-
cast<WeightedRegionBranchOpInterface>(op).getRegionWeightsOrNull();
127+
llvm::ArrayRef<int32_t> weights =
128+
cast<WeightedRegionBranchOpInterface>(op).getWeights();
128129
return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
129130
}
130131

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
147147
}
148148

149149
if (auto iface = dyn_cast<WeightedBranchOpInterface>(op)) {
150-
iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights));
150+
iface.setWeights(branchWeights);
151151
return success();
152152
}
153153
return failure();

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2027,13 +2027,14 @@ void ModuleTranslation::setDereferenceableMetadata(
20272027
}
20282028

20292029
void ModuleTranslation::setBranchWeightsMetadata(WeightedBranchOpInterface op) {
2030-
DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
2031-
if (!weightsAttr)
2030+
SmallVector<uint32_t> weights;
2031+
llvm::transform(op.getWeights(), std::back_inserter(weights),
2032+
[](int32_t value) { return static_cast<uint32_t>(value); });
2033+
if (weights.empty())
20322034
return;
20332035

20342036
llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
20352037
assert(inst && "expected the operation to have a mapping to an instruction");
2036-
SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
20372038
inst->setMetadata(
20382039
llvm::LLVMContext::MD_prof,
20392040
llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));

0 commit comments

Comments
 (0)