Skip to content

Commit 61ffc4c

Browse files
committed
keep sharding
1 parent c69b285 commit 61ffc4c

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30346,6 +30346,10 @@ struct ReduceUnusedMultiSlice final
3034630346
rewriter.getDenseI64ArrayAttr(startIndices),
3034730347
rewriter.getDenseI64ArrayAttr(limitIndices),
3034830348
rewriter.getDenseI64ArrayAttr(strides));
30349+
// Propagate sharding if present
30350+
if (auto shard = sdy::getShardingPerValue(op)) {
30351+
sdy::setShardings(sliceOp, shard);
30352+
}
3034930353

3035030354
rewriter.replaceAllUsesWith(op.getResult(usedIdx), sliceOp.getResult());
3035130355
rewriter.eraseOp(op);
@@ -30375,6 +30379,10 @@ struct ReduceUnusedMultiSlice final
3037530379
auto newOp = rewriter.create<enzymexla::MultiSliceOp>(
3037630380
op.getLoc(), resultTypes, op.getOperand(), startIndices, limitIndices,
3037730381
op.getStrides(), op.getDimension(), newLeftAmount, newRightAmount);
30382+
// Propagate sharding if present
30383+
if (auto shard = sdy::getShardingPerValue(op)) {
30384+
sdy::setShardings(newOp, shard);
30385+
}
3037830386

3037930387
// Map old results to new results
3038030388
SmallVector<Value> replacements(totalResults);

0 commit comments

Comments
 (0)