Skip to content

Commit 6ee3453

Browse files
committed
fixes for insert_slice
1 parent f474abb commit 6ee3453

File tree

4 files changed

+45
-19
lines changed

4 files changed

+45
-19
lines changed

build_tools/llvm_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0eeb79d76a8284fae3e5e3b4ebbbe98d02249235
1+
d8bb4e6495793fc6bbc38a75dbef52091139c68a

lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ static std::array<Value, 2> getShardSliceOffAndSz(
193193
ValueRange myIdx, int64_t dim, ArrayRef<int64_t> meshShape,
194194
ArrayRef<MeshAxesAttr> splitAxes, Value targetOffs,
195195
ArrayRef<int64_t> srcShape, const SmallVector<OpFoldResult> &slcOffs,
196+
const SmallVector<OpFoldResult> &slcSizes,
196197
const SmallVector<OpFoldResult> &slcStrides,
197198
const SmallVector<OpFoldResult> &haloSizes, const EasyI64 &zero,
198199
const EasyI64 &one, OpBuilder &builder, Location loc) {
@@ -214,8 +215,12 @@ static std::array<Value, 2> getShardSliceOffAndSz(
214215
std::tie(myOff, mySize) =
215216
getOffsetAndSize(myID, zero, one, targetOffs, currPos, builder, loc);
216217
} else {
217-
myOff = getBaseShardDimOff(myID, numShards, extend, zero).get();
218-
mySize = getBaseShardDimSize(myID, numShards, extend, one, zero).get();
218+
auto myOff_ = getBaseShardDimOff(myID, numShards, extend, zero);
219+
auto mySize_ = getBaseShardDimSize(myID, numShards, extend, one, zero);
220+
auto slcSz = easyI64(loc, builder, slcSizes[dim]);
221+
mySize_ = zero.max(slcSz - myOff_).min(mySize_);
222+
myOff = myOff_.get();
223+
mySize = mySize_.get();
219224
}
220225

221226
// the global offset of the local shard is slice offset plus the computed
@@ -290,7 +295,7 @@ getLocalOffSzAndStrFromSlice(OP op, ArrayRef<int64_t> srcShape,
290295
} else {
291296
auto offAndSz = getShardSliceOffAndSz(
292297
myIdx, dim, mesh.getShape(), splitAxes, targetOffs, srcShape, slcOffs,
293-
slcStrides, haloSizes, zero, one, builder, loc);
298+
slcSizes, slcStrides, haloSizes, zero, one, builder, loc);
294299
lShardOffs.emplace_back(offAndSz[0]);
295300
lShardSizes.emplace_back(offAndSz[1]);
296301
}
@@ -439,6 +444,7 @@ struct InsertSliceShardingInterface
439444
}
440445
auto dstSharding = mlir::mesh::MeshSharding::get(shardingOption.mesh, res);
441446
maybeInsertSourceShardingAnnotation(dstSharding, op->getOpOperand(0), b);
447+
maybeInsertTargetShardingAnnotation(dstSharding, op->getResult(0), b);
442448

443449
return success();
444450
}
@@ -449,7 +455,8 @@ struct InsertSliceShardingInterface
449455
IRMapping &spmdizationMap,
450456
SymbolTableCollection &symbolTableCollection,
451457
OpBuilder &builder) const {
452-
if (resultShardings.size() != 0) {
458+
if (resultShardings.size() != 1 || operandShardings.size() < 2 ||
459+
resultShardings[0] != operandShardings[0]) {
453460
return failure();
454461
}
455462

@@ -493,22 +500,29 @@ struct InsertSliceShardingInterface
493500
}
494501

495502
scf::IfOp ifOp = builder.create<scf::IfOp>(
496-
loc, hasSize.get(), [&](OpBuilder &b, Location loc) {
497-
(void)b.create<imex::ndarray::InsertSliceOp>(
503+
loc, hasSize.get(),
504+
[&](OpBuilder &b, Location loc) {
505+
auto res = b.create<imex::ndarray::InsertSliceOp>(
498506
loc, spmdizedOperands[0], spmdizedOperands[1], lShardOffs,
499507
lShardSizes, lShardStrides);
500-
b.create<scf::YieldOp>(loc);
508+
b.create<scf::YieldOp>(loc, res.getResult());
509+
},
510+
[&](OpBuilder &b, Location loc) {
511+
b.create<scf::YieldOp>(loc, spmdizedOperands[0]);
501512
});
502-
spmdizationMap.map(op, ifOp.getOperation());
503513

504-
builder.create<mlir::mesh::UpdateHaloOp>(
505-
loc, spmdizedOperands[0].getType(), spmdizedOperands[0],
514+
auto res = builder.create<mlir::mesh::UpdateHaloOp>(
515+
loc, spmdizedOperands[0].getType(), ifOp.getResult(0),
506516
dstSharding.getMeshAttr(),
507517
mlir::mesh::MeshAxesArrayAttr::get(op->getContext(),
508518
dstSharding.getSplitAxes()),
509519
dstSharding.getDynamicHaloSizes(),
510520
DenseI64ArrayAttr::get(op->getContext(),
511521
dstSharding.getStaticHaloSizes()));
522+
523+
spmdizationMap.map(op->getResult(0), res->getResult(0));
524+
spmdizationMap.map(op, res.getOperation());
525+
512526
return success();
513527
}
514528
};

lib/Dialect/NDArray/IR/InsertSliceOp.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,22 +132,31 @@ class InsertSliceOpConstantArgumentFolder final
132132
return mlir::failure();
133133

134134
auto sourceType = insertSliceOp.getSourceType();
135-
auto dstTnsrType = insertSliceOp.getDestinationType(); //.getTensorType();
135+
auto dstTnsrType = insertSliceOp.getDestinationType();
136+
136137
// Create the new op in canonical form.
137138
auto sourceTnsrType =
138139
mlir::tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
139140
insertSliceOp.getSourceType().getRank(), dstTnsrType, mixedOffsets,
140141
mixedSizes, mixedStrides);
141142
auto newSourceType = sourceType.cloneWith(sourceTnsrType.getShape(),
142143
sourceTnsrType.getElementType());
144+
143145
mlir::Value toInsert = insertSliceOp.getSource();
144146
if (newSourceType != sourceType) {
145-
if (newSourceType.getRank() != sourceType.getRank())
147+
if (sourceType.getRank() == 0) {
148+
if (newSourceType.getRank() > 1) {
149+
return mlir::failure();
150+
}
151+
} else if (newSourceType.getRank() != sourceType.getRank()) {
146152
return mlir::failure();
147-
mlir::OpBuilder::InsertionGuard g(rewriter);
148-
toInsert = rewriter.create<mlir::tensor::CastOp>(insertSliceOp.getLoc(),
149-
newSourceType, toInsert);
153+
} else {
154+
mlir::OpBuilder::InsertionGuard g(rewriter);
155+
toInsert = rewriter.create<mlir::tensor::CastOp>(
156+
insertSliceOp.getLoc(), newSourceType, toInsert);
157+
}
150158
}
159+
151160
rewriter.replaceOpWithNewOp<InsertOpTy>(
152161
insertSliceOp, insertSliceOp.getDestination(), toInsert, mixedOffsets,
153162
mixedSizes, mixedStrides);

lib/Dialect/NDArray/Transforms/CoalesceShardOps.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ struct CoalesceShardOpsPass
101101
return defOp;
102102
} else if (auto op = ::mlir::dyn_cast<::mlir::DestinationStyleOpInterface>(
103103
defOp)) {
104-
return op.getNumDpsInputs() == 1 ? op.getDpsInits()[0].getDefiningOp()
105-
: defOp;
104+
return op.getNumDpsInits() == 1 ? getBaseArray(op.getDpsInits()[0])
105+
: defOp;
106106
} else if (auto op = ::mlir::dyn_cast<::imex::ndarray::SubviewOp>(defOp)) {
107107
return getBaseArray(op.getSource());
108108
} else if (auto op =
@@ -479,7 +479,10 @@ struct CoalesceShardOpsPass
479479

480480
// update shardOps of dependent Subview/InsertSliceOps
481481
for (auto svShardOp : shardOps) {
482-
svShardOp.getSrcMutable().assign(newShardOp.getResult());
482+
assert(svShardOp->hasOneUse());
483+
if (mlir::isa<::imex::ndarray::SubviewOp>(*svShardOp->user_begin())) {
484+
svShardOp.getSrcMutable().assign(newShardOp.getResult());
485+
}
483486
svShardOp.getShardingMutable().assign(newSharding);
484487
}
485488
// barriers/halo-updates get inserted when InsertSliceOps (or other write

0 commit comments

Comments
 (0)