@@ -30279,6 +30279,125 @@ struct ReduceWindowWrapSimplify final
3027930279 }
3028030280};
3028130281
30282+ // Pattern to reduce MultiSliceOp when some results are unused
30283+ struct ReduceUnusedMultiSlice final
30284+ : CheckedOpRewritePattern<enzymexla::MultiSliceOp, ReduceUnusedMultiSlice> {
30285+ using CheckedOpRewritePattern::CheckedOpRewritePattern;
30286+
30287+ LogicalResult matchAndRewriteImpl(enzymexla::MultiSliceOp op,
30288+ PatternRewriter &rewriter) const {
30289+ int32_t leftAmount = op.getLeftAmount();
30290+ int32_t rightAmount = op.getRightAmount();
30291+ int32_t totalResults = leftAmount + rightAmount + 1;
30292+
30293+ // Check which results are actually used
30294+ SmallVector<bool> used(totalResults, false);
30295+ int usedCount = 0;
30296+ for (int i = 0; i < totalResults; i++) {
30297+ if (!op.getResult(i).use_empty()) {
30298+ used[i] = true;
30299+ usedCount++;
30300+ }
30301+ }
30302+
30303+ // If all results are used, nothing to optimize
30304+ if (usedCount == totalResults)
30305+ return failure();
30306+
30307+ // If no results are used, this should be handled by dead code elimination
30308+ if (usedCount == 0)
30309+ return failure();
30310+
30311+ // Find the range of used results
30312+ int firstUsed = -1, lastUsed = -1;
30313+ for (int i = 0; i < totalResults; i++) {
30314+ if (used[i]) {
30315+ if (firstUsed == -1)
30316+ firstUsed = i;
30317+ lastUsed = i;
30318+ }
30319+ }
30320+
30321+ // Calculate new left and right amounts
30322+ int centerIdx = leftAmount;
30323+ int newLeftAmount = centerIdx - firstUsed;
30324+ int newRightAmount = lastUsed - centerIdx;
30325+
30326+ // If only one result is used, replace with a single SliceOp
30327+ if (usedCount == 1) {
30328+ int usedIdx = firstUsed;
30329+ int offset = usedIdx - centerIdx; // How much to shift the slice
30330+
30331+ auto startIndices = SmallVector<int64_t>(op.getStartIndices());
30332+ auto limitIndices = SmallVector<int64_t>(op.getLimitIndices());
30333+ auto strides = SmallVector<int64_t>(op.getStrides());
30334+ int32_t dim = op.getDimension();
30335+
30336+ // Adjust start and limit indices for the offset
30337+ if (dim >= 0 && dim < (int64_t)startIndices.size()) {
30338+ startIndices[dim] += offset;
30339+ limitIndices[dim] += offset;
30340+ }
30341+
30342+ auto sliceOp = rewriter.create<stablehlo::SliceOp>(
30343+ op.getLoc(), op.getOperand(),
30344+ rewriter.getDenseI64ArrayAttr(startIndices),
30345+ rewriter.getDenseI64ArrayAttr(limitIndices),
30346+ rewriter.getDenseI64ArrayAttr(strides));
30347+
30348+ rewriter.replaceAllUsesWith(op.getResult(usedIdx), sliceOp.getResult());
30349+ rewriter.eraseOp(op);
30350+ return success();
30351+ }
30352+
30353+ // Otherwise, create a smaller MultiSliceOp
30354+ if (newLeftAmount != leftAmount || newRightAmount != rightAmount) {
30355+ // Adjust start indices for the new center
30356+ int offset = firstUsed - centerIdx;
30357+ auto startIndices = SmallVector<int64_t>(op.getStartIndices());
30358+ auto limitIndices = SmallVector<int64_t>(op.getLimitIndices());
30359+ int32_t dim = op.getDimension();
30360+
30361+ if (dim >= 0 && dim < (int64_t)startIndices.size()) {
30362+ startIndices[dim] += offset;
30363+ limitIndices[dim] += offset;
30364+ }
30365+
30366+ // Determine result types
30367+ auto resultType = cast<RankedTensorType>(op.getResultTypes().front());
30368+ SmallVector<Type> resultTypes;
30369+ for (int i = 0; i < newLeftAmount + newRightAmount + 1; i++) {
30370+ resultTypes.push_back(resultType); // Will be properly typed by the op
30371+ }
30372+
30373+ auto newOp = rewriter.create<enzymexla::MultiSliceOp>(
30374+ op.getLoc(), resultTypes, op.getOperand(),
30375+ startIndices, limitIndices, op.getStrides(),
30376+ op.getDimension(), newLeftAmount, newRightAmount);
30377+
30378+ // Map old results to new results
30379+ SmallVector<Value> replacements(totalResults);
30380+ int newIdx = 0;
30381+ for (int oldIdx = firstUsed; oldIdx <= lastUsed; oldIdx++) {
30382+ replacements[oldIdx] = newOp.getResult(newIdx++);
30383+ }
30384+
30385+ // Replace uses
30386+ for (int i = 0; i < totalResults; i++) {
30387+ if (used[i]) {
30388+ op.getResult(i).replaceAllUsesWith(replacements[i]);
30389+ }
30390+ }
30391+
30392+ rewriter.eraseOp(op);
30393+ return success();
30394+ }
30395+
30396+ return failure();
30397+ }
30398+ };
30399+
30400+
3028230401struct ScatterOpCanon final
3028330402 : CheckedOpRewritePattern<stablehlo::ScatterOp, ScatterOpCanon> {
3028430403 using CheckedOpRewritePattern::CheckedOpRewritePattern;
@@ -30678,6 +30797,18 @@ void mlir::transform::addExtendLICM(RewritePatternSet &patterns,
3067830797 patterns.insert<LICM<enzymexla::ExtendOp>>(single_user, &context, benefit);
3067930798}
3068030799
30800+ void mlir::transform::addMultiSliceOpt(RewritePatternSet &patterns,
30801+ MLIRContext &context,
30802+ PatternBenefit benefit) {
30803+ patterns.insert<ReduceUnusedMultiSlice>(&context, benefit);
30804+ }
30805+ void mlir::transform::addMultiSliceLICM(RewritePatternSet &patterns,
30806+ bool single_user, MLIRContext &context,
30807+ PatternBenefit benefit) {
30808+ patterns.insert<LICM<enzymexla::MultiSliceOp>>(single_user, &context,
30809+ benefit);
30810+ }
30811+
3068130812void mlir::transform::addElementwiseLICM(RewritePatternSet &patterns,
3068230813 bool single_user, MLIRContext &context,
3068330814 PatternBenefit benefit) {
0 commit comments