@@ -193,6 +193,7 @@ static std::array<Value, 2> getShardSliceOffAndSz(
193193 ValueRange myIdx, int64_t dim, ArrayRef<int64_t > meshShape,
194194 ArrayRef<MeshAxesAttr> splitAxes, Value targetOffs,
195195 ArrayRef<int64_t > srcShape, const SmallVector<OpFoldResult> &slcOffs,
196+ const SmallVector<OpFoldResult> &slcSizes,
196197 const SmallVector<OpFoldResult> &slcStrides,
197198 const SmallVector<OpFoldResult> &haloSizes, const EasyI64 &zero,
198199 const EasyI64 &one, OpBuilder &builder, Location loc) {
@@ -214,8 +215,12 @@ static std::array<Value, 2> getShardSliceOffAndSz(
214215 std::tie (myOff, mySize) =
215216 getOffsetAndSize (myID, zero, one, targetOffs, currPos, builder, loc);
216217 } else {
217- myOff = getBaseShardDimOff (myID, numShards, extend, zero).get ();
218- mySize = getBaseShardDimSize (myID, numShards, extend, one, zero).get ();
218+ auto myOff_ = getBaseShardDimOff (myID, numShards, extend, zero);
219+ auto mySize_ = getBaseShardDimSize (myID, numShards, extend, one, zero);
220+ auto slcSz = easyI64 (loc, builder, slcSizes[dim]);
221+ mySize_ = zero.max (slcSz - myOff_).min (mySize_);
222+ myOff = myOff_.get ();
223+ mySize = mySize_.get ();
219224 }
220225
221226 // the global offset of the local shard is slice offset plus the computed
@@ -290,7 +295,7 @@ getLocalOffSzAndStrFromSlice(OP op, ArrayRef<int64_t> srcShape,
290295 } else {
291296 auto offAndSz = getShardSliceOffAndSz (
292297 myIdx, dim, mesh.getShape (), splitAxes, targetOffs, srcShape, slcOffs,
293- slcStrides, haloSizes, zero, one, builder, loc);
298+ slcSizes, slcStrides, haloSizes, zero, one, builder, loc);
294299 lShardOffs.emplace_back (offAndSz[0 ]);
295300 lShardSizes.emplace_back (offAndSz[1 ]);
296301 }
@@ -439,6 +444,7 @@ struct InsertSliceShardingInterface
439444 }
440445 auto dstSharding = mlir::mesh::MeshSharding::get (shardingOption.mesh , res);
441446 maybeInsertSourceShardingAnnotation (dstSharding, op->getOpOperand (0 ), b);
447+ maybeInsertTargetShardingAnnotation (dstSharding, op->getResult (0 ), b);
442448
443449 return success ();
444450 }
@@ -449,7 +455,8 @@ struct InsertSliceShardingInterface
449455 IRMapping &spmdizationMap,
450456 SymbolTableCollection &symbolTableCollection,
451457 OpBuilder &builder) const {
452- if (resultShardings.size () != 0 ) {
458+ if (resultShardings.size () != 1 || operandShardings.size () < 2 ||
459+ resultShardings[0 ] != operandShardings[0 ]) {
453460 return failure ();
454461 }
455462
@@ -493,22 +500,29 @@ struct InsertSliceShardingInterface
493500 }
494501
495502 scf::IfOp ifOp = builder.create <scf::IfOp>(
496- loc, hasSize.get (), [&](OpBuilder &b, Location loc) {
497- (void )b.create <imex::ndarray::InsertSliceOp>(
503+ loc, hasSize.get (),
504+ [&](OpBuilder &b, Location loc) {
505+ auto res = b.create <imex::ndarray::InsertSliceOp>(
498506 loc, spmdizedOperands[0 ], spmdizedOperands[1 ], lShardOffs,
499507 lShardSizes, lShardStrides);
500- b.create <scf::YieldOp>(loc);
508+ b.create <scf::YieldOp>(loc, res.getResult ());
509+ },
510+ [&](OpBuilder &b, Location loc) {
511+ b.create <scf::YieldOp>(loc, spmdizedOperands[0 ]);
501512 });
502- spmdizationMap.map (op, ifOp.getOperation ());
503513
504- builder.create <mlir::mesh::UpdateHaloOp>(
505- loc, spmdizedOperands[0 ].getType (), spmdizedOperands[ 0 ] ,
514+ auto res = builder.create <mlir::mesh::UpdateHaloOp>(
515+ loc, spmdizedOperands[0 ].getType (), ifOp. getResult ( 0 ) ,
506516 dstSharding.getMeshAttr (),
507517 mlir::mesh::MeshAxesArrayAttr::get (op->getContext (),
508518 dstSharding.getSplitAxes ()),
509519 dstSharding.getDynamicHaloSizes (),
510520 DenseI64ArrayAttr::get (op->getContext (),
511521 dstSharding.getStaticHaloSizes ()));
522+
523+ spmdizationMap.map (op->getResult (0 ), res->getResult (0 ));
524+ spmdizationMap.map (op, res.getOperation ());
525+
512526 return success ();
513527 }
514528};
0 commit comments