Skip to content

Commit ec4a18e

Browse files
committed
assert expected ArrayRerf argument size
1 parent 6f42ee5 commit ec4a18e

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,20 @@ struct ConstantShardingInterface
4646
FailureOr<ShardingOption>
4747
getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
4848
ArrayRef<MeshSharding> resultShardings) const {
49-
if (!resultShardings[0]) {
49+
assert(resultShardings.size() == 1 &&
50+
"Expecting exactly one result sharding for arith.constant");
51+
auto resultSharding = resultShardings[0];
52+
if (!resultSharding) {
5053
return failure();
5154
}
5255
if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
53-
ShardingArray axesArray(resultShardings[0].getSplitAxes().size());
54-
for (auto [i, axes] :
55-
llvm::enumerate(resultShardings[0].getSplitAxes())) {
56+
ShardingArray axesArray(resultSharding.getSplitAxes().size());
57+
for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
5658
axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
5759
}
58-
return ShardingOption(axesArray, resultShardings[0].getMeshAttr());
60+
return ShardingOption(axesArray, resultSharding.getMeshAttr());
5961
}
60-
return ShardingOption({}, resultShardings[0].getMeshAttr());
62+
return ShardingOption({}, resultSharding.getMeshAttr());
6163
}
6264

6365
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -67,8 +69,7 @@ struct ConstantShardingInterface
6769
SymbolTableCollection &symbolTable,
6870
OpBuilder &builder) const {
6971
auto cOp = cast<ConstantOp>(op);
70-
auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue());
71-
if (value) {
72+
if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
7273
if (!value.isSplat() || !resultShardings[0]) {
7374
// Currently non-splat constants are not supported.
7475
return failure();

0 commit comments

Comments
 (0)