@@ -46,18 +46,20 @@ struct ConstantShardingInterface
4646 FailureOr<ShardingOption>
4747 getShardingOption (Operation *op, ArrayRef<MeshSharding> operandShardings,
4848 ArrayRef<MeshSharding> resultShardings) const {
49- if (!resultShardings[0 ]) {
49+ assert (resultShardings.size () == 1 &&
50+ " Expecting exactly one result sharding for arith.constant" );
51+ auto resultSharding = resultShardings[0 ];
52+ if (!resultSharding) {
5053 return failure ();
5154 }
5255 if (auto type = dyn_cast<RankedTensorType>(op->getResult (0 ).getType ())) {
53- ShardingArray axesArray (resultShardings[0 ].getSplitAxes ().size ());
54- for (auto [i, axes] :
55- llvm::enumerate (resultShardings[0 ].getSplitAxes ())) {
56+ ShardingArray axesArray (resultSharding.getSplitAxes ().size ());
57+ for (auto [i, axes] : llvm::enumerate (resultSharding.getSplitAxes ())) {
5658 axesArray[i].append (axes.asArrayRef ().begin (), axes.asArrayRef ().end ());
5759 }
58- return ShardingOption (axesArray, resultShardings[ 0 ] .getMeshAttr ());
60+ return ShardingOption (axesArray, resultSharding .getMeshAttr ());
5961 }
60- return ShardingOption ({}, resultShardings[ 0 ] .getMeshAttr ());
62+ return ShardingOption ({}, resultSharding .getMeshAttr ());
6163 }
6264
6365 LogicalResult spmdize (Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -67,8 +69,7 @@ struct ConstantShardingInterface
6769 SymbolTableCollection &symbolTable,
6870 OpBuilder &builder) const {
6971 auto cOp = cast<ConstantOp>(op);
70- auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue ());
71- if (value) {
72+ if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue ())) {
7273 if (!value.isSplat () || !resultShardings[0 ]) {
7374 // Currently non-splat constants are not supported.
7475 return failure ();
0 commit comments