Skip to content

Commit 6df5ab3

Browse files
committed
ReduceUnusedMultiSlice
1 parent 1e16f33 commit 6df5ab3

File tree

3 files changed

+140
-0
lines changed

3 files changed

+140
-0
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30279,6 +30279,125 @@ struct ReduceWindowWrapSimplify final
3027930279
}
3028030280
};
3028130281

30282+
// Pattern to reduce MultiSliceOp when some results are unused
30283+
struct ReduceUnusedMultiSlice final
30284+
: CheckedOpRewritePattern<enzymexla::MultiSliceOp, ReduceUnusedMultiSlice> {
30285+
using CheckedOpRewritePattern::CheckedOpRewritePattern;
30286+
30287+
LogicalResult matchAndRewriteImpl(enzymexla::MultiSliceOp op,
30288+
PatternRewriter &rewriter) const {
30289+
int32_t leftAmount = op.getLeftAmount();
30290+
int32_t rightAmount = op.getRightAmount();
30291+
int32_t totalResults = leftAmount + rightAmount + 1;
30292+
30293+
// Check which results are actually used
30294+
SmallVector<bool> used(totalResults, false);
30295+
int usedCount = 0;
30296+
for (int i = 0; i < totalResults; i++) {
30297+
if (!op.getResult(i).use_empty()) {
30298+
used[i] = true;
30299+
usedCount++;
30300+
}
30301+
}
30302+
30303+
// If all results are used, nothing to optimize
30304+
if (usedCount == totalResults)
30305+
return failure();
30306+
30307+
// If no results are used, this should be handled by dead code elimination
30308+
if (usedCount == 0)
30309+
return failure();
30310+
30311+
// Find the range of used results
30312+
int firstUsed = -1, lastUsed = -1;
30313+
for (int i = 0; i < totalResults; i++) {
30314+
if (used[i]) {
30315+
if (firstUsed == -1)
30316+
firstUsed = i;
30317+
lastUsed = i;
30318+
}
30319+
}
30320+
30321+
// Calculate new left and right amounts
30322+
int centerIdx = leftAmount;
30323+
int newLeftAmount = centerIdx - firstUsed;
30324+
int newRightAmount = lastUsed - centerIdx;
30325+
30326+
// If only one result is used, replace with a single SliceOp
30327+
if (usedCount == 1) {
30328+
int usedIdx = firstUsed;
30329+
int offset = usedIdx - centerIdx; // How much to shift the slice
30330+
30331+
auto startIndices = SmallVector<int64_t>(op.getStartIndices());
30332+
auto limitIndices = SmallVector<int64_t>(op.getLimitIndices());
30333+
auto strides = SmallVector<int64_t>(op.getStrides());
30334+
int32_t dim = op.getDimension();
30335+
30336+
// Adjust start and limit indices for the offset
30337+
if (dim >= 0 && dim < (int64_t)startIndices.size()) {
30338+
startIndices[dim] += offset;
30339+
limitIndices[dim] += offset;
30340+
}
30341+
30342+
auto sliceOp = rewriter.create<stablehlo::SliceOp>(
30343+
op.getLoc(), op.getOperand(),
30344+
rewriter.getDenseI64ArrayAttr(startIndices),
30345+
rewriter.getDenseI64ArrayAttr(limitIndices),
30346+
rewriter.getDenseI64ArrayAttr(strides));
30347+
30348+
rewriter.replaceAllUsesWith(op.getResult(usedIdx), sliceOp.getResult());
30349+
rewriter.eraseOp(op);
30350+
return success();
30351+
}
30352+
30353+
// Otherwise, create a smaller MultiSliceOp
30354+
if (newLeftAmount != leftAmount || newRightAmount != rightAmount) {
30355+
// Adjust start indices for the new center
30356+
int offset = firstUsed - centerIdx;
30357+
auto startIndices = SmallVector<int64_t>(op.getStartIndices());
30358+
auto limitIndices = SmallVector<int64_t>(op.getLimitIndices());
30359+
int32_t dim = op.getDimension();
30360+
30361+
if (dim >= 0 && dim < (int64_t)startIndices.size()) {
30362+
startIndices[dim] += offset;
30363+
limitIndices[dim] += offset;
30364+
}
30365+
30366+
// Determine result types
30367+
auto resultType = cast<RankedTensorType>(op.getResultTypes().front());
30368+
SmallVector<Type> resultTypes;
30369+
for (int i = 0; i < newLeftAmount + newRightAmount + 1; i++) {
30370+
resultTypes.push_back(resultType); // Will be properly typed by the op
30371+
}
30372+
30373+
auto newOp = rewriter.create<enzymexla::MultiSliceOp>(
30374+
op.getLoc(), resultTypes, op.getOperand(),
30375+
startIndices, limitIndices, op.getStrides(),
30376+
op.getDimension(), newLeftAmount, newRightAmount);
30377+
30378+
// Map old results to new results
30379+
SmallVector<Value> replacements(totalResults);
30380+
int newIdx = 0;
30381+
for (int oldIdx = firstUsed; oldIdx <= lastUsed; oldIdx++) {
30382+
replacements[oldIdx] = newOp.getResult(newIdx++);
30383+
}
30384+
30385+
// Replace uses
30386+
for (int i = 0; i < totalResults; i++) {
30387+
if (used[i]) {
30388+
op.getResult(i).replaceAllUsesWith(replacements[i]);
30389+
}
30390+
}
30391+
30392+
rewriter.eraseOp(op);
30393+
return success();
30394+
}
30395+
30396+
return failure();
30397+
}
30398+
};
30399+
30400+
3028230401
struct ScatterOpCanon final
3028330402
: CheckedOpRewritePattern<stablehlo::ScatterOp, ScatterOpCanon> {
3028430403
using CheckedOpRewritePattern::CheckedOpRewritePattern;
@@ -30678,6 +30797,18 @@ void mlir::transform::addExtendLICM(RewritePatternSet &patterns,
3067830797
patterns.insert<LICM<enzymexla::ExtendOp>>(single_user, &context, benefit);
3067930798
}
3068030799

30800+
void mlir::transform::addMultiSliceOpt(RewritePatternSet &patterns,
30801+
MLIRContext &context,
30802+
PatternBenefit benefit) {
30803+
patterns.insert<ReduceUnusedMultiSlice>(&context, benefit);
30804+
}
30805+
void mlir::transform::addMultiSliceLICM(RewritePatternSet &patterns,
30806+
bool single_user, MLIRContext &context,
30807+
PatternBenefit benefit) {
30808+
patterns.insert<LICM<enzymexla::MultiSliceOp>>(single_user, &context,
30809+
benefit);
30810+
}
30811+
3068130812
void mlir::transform::addElementwiseLICM(RewritePatternSet &patterns,
3068230813
bool single_user, MLIRContext &context,
3068330814
PatternBenefit benefit) {

src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,5 +134,9 @@ void addSelfMulToConvolutionLike(RewritePatternSet &patterns,
134134
MLIRContext &context, PatternBenefit benefit);
135135
void addEnzymeHLOUnroll(RewritePatternSet &patterns, int64_t maxNumIterations,
136136
MLIRContext &context, PatternBenefit benefit);
137+
void addMultiSliceOpt(RewritePatternSet &patterns, MLIRContext &context,
138+
PatternBenefit benefit);
139+
void addMultiSliceLICM(RewritePatternSet &patterns, bool single_user,
140+
MLIRContext &context, PatternBenefit benefit);
137141

138142
} // namespace mlir::transform

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2172,6 +2172,11 @@ def TransposeRotate : EnzymeHLOPatternOp<
21722172
let patterns = ["TransposeRotate"];
21732173
}
21742174

2175+
def ReduceUnusedMultiSlice : EnzymeHLOPatternOp<
2176+
"reduce_unused_multislice"> {
2177+
let patterns = ["ReduceUnusedMultiSlice"];
2178+
}
2179+
21752180
def SelectPad : EnzymeHLOPatternOp<
21762181
"select_pad"> {
21772182
let patterns = ["SelectPad"];

0 commit comments

Comments
 (0)