From 0a06aab127792542f7c3b4d70b2258bcdac3446f Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Wed, 21 Jan 2026 15:48:28 -0600 Subject: [PATCH 1/9] ReduceUnusedMultiSlice --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 129 ++++++++++++++++++ src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h | 4 + .../jax/TransformOps/TransformOps.td | 5 + 3 files changed, 138 insertions(+) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index c61a7d6545..1d65ccdff1 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -30279,6 +30279,123 @@ struct ReduceWindowWrapSimplify final } }; +// Pattern to reduce MultiSliceOp when some results are unused +struct ReduceUnusedMultiSlice final + : CheckedOpRewritePattern { + using CheckedOpRewritePattern::CheckedOpRewritePattern; + + LogicalResult matchAndRewriteImpl(enzymexla::MultiSliceOp op, + PatternRewriter &rewriter) const { + int32_t leftAmount = op.getLeftAmount(); + int32_t rightAmount = op.getRightAmount(); + int32_t totalResults = leftAmount + rightAmount + 1; + + // Check which results are actually used + SmallVector used(totalResults, false); + int usedCount = 0; + for (int i = 0; i < totalResults; i++) { + if (!op.getResult(i).use_empty()) { + used[i] = true; + usedCount++; + } + } + + // If all results are used, nothing to optimize + if (usedCount == totalResults) + return failure(); + + // If no results are used, this should be handled by dead code elimination + if (usedCount == 0) + return failure(); + + // Find the range of used results + int firstUsed = -1, lastUsed = -1; + for (int i = 0; i < totalResults; i++) { + if (used[i]) { + if (firstUsed == -1) + firstUsed = i; + lastUsed = i; + } + } + + // Calculate new left and right amounts + int centerIdx = leftAmount; + int newLeftAmount = centerIdx - firstUsed; + int newRightAmount = lastUsed - centerIdx; + + // If only one result is used, replace with a single SliceOp + if (usedCount == 1) { + int usedIdx = firstUsed; + int offset = usedIdx - centerIdx; // How much to shift the slice + + auto startIndices = SmallVector(op.getStartIndices()); + auto limitIndices = SmallVector(op.getLimitIndices()); + auto strides = SmallVector(op.getStrides()); + int32_t dim = op.getDimension(); + + // Adjust start and limit indices for the offset + if (dim >= 0 && dim < (int64_t)startIndices.size()) { + startIndices[dim] += offset; + limitIndices[dim] += offset; + } + + auto sliceOp = rewriter.create( + op.getLoc(), op.getOperand(), + rewriter.getDenseI64ArrayAttr(startIndices), + rewriter.getDenseI64ArrayAttr(limitIndices), + rewriter.getDenseI64ArrayAttr(strides)); + + rewriter.replaceAllUsesWith(op.getResult(usedIdx), sliceOp.getResult()); + rewriter.eraseOp(op); + return success(); + } + + // Otherwise, create a smaller MultiSliceOp + if (newLeftAmount != leftAmount || newRightAmount != rightAmount) { + // Adjust start indices for the new center + int offset = firstUsed - centerIdx; + auto startIndices = SmallVector(op.getStartIndices()); + auto limitIndices = SmallVector(op.getLimitIndices()); + int32_t dim = op.getDimension(); + + if (dim >= 0 && dim < (int64_t)startIndices.size()) { + startIndices[dim] += offset; + limitIndices[dim] += offset; + } + + // Determine result types + auto resultType = cast(op.getResultTypes().front()); + SmallVector resultTypes; + for (int i = 0; i < newLeftAmount + newRightAmount + 1; i++) { + resultTypes.push_back(resultType); // Will be properly typed by the op + } + + auto newOp = rewriter.create( + op.getLoc(), resultTypes, op.getOperand(), startIndices, limitIndices, + op.getStrides(), op.getDimension(), newLeftAmount, newRightAmount); + + // Map old results to new results + SmallVector replacements(totalResults); + int newIdx = 0; + for (int oldIdx = firstUsed; oldIdx <= lastUsed; oldIdx++) { + replacements[oldIdx] = newOp.getResult(newIdx++); + } + + // Replace uses + for (int i = 0; i < totalResults; i++) { + if (used[i]) { + op.getResult(i).replaceAllUsesWith(replacements[i]); + } + } + + rewriter.eraseOp(op); + return success(); + } + + return failure(); + } +}; + struct ScatterOpCanon final : CheckedOpRewritePattern { using CheckedOpRewritePattern::CheckedOpRewritePattern; @@ -30678,6 +30795,18 @@ void mlir::transform::addExtendLICM(RewritePatternSet &patterns, patterns.insert>(single_user, &context, benefit); } +void mlir::transform::addMultiSliceOpt(RewritePatternSet &patterns, + MLIRContext &context, + PatternBenefit benefit) { + patterns.insert(&context, benefit); +} +void mlir::transform::addMultiSliceLICM(RewritePatternSet &patterns, + bool single_user, MLIRContext &context, + PatternBenefit benefit) { + patterns.insert>(single_user, &context, + benefit); +} + void mlir::transform::addElementwiseLICM(RewritePatternSet &patterns, bool single_user, MLIRContext &context, PatternBenefit benefit) { diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h b/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h index cf63c2ea3e..4234d69e53 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h @@ -134,5 +134,9 @@ void addSelfMulToConvolutionLike(RewritePatternSet &patterns, MLIRContext &context, PatternBenefit benefit); void addEnzymeHLOUnroll(RewritePatternSet &patterns, int64_t maxNumIterations, MLIRContext &context, PatternBenefit benefit); +void addMultiSliceOpt(RewritePatternSet &patterns, MLIRContext &context, + PatternBenefit benefit); +void addMultiSliceLICM(RewritePatternSet &patterns, bool single_user, + MLIRContext &context, PatternBenefit benefit); } // namespace mlir::transform diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td index e831951521..fe806ba630 100644 --- a/src/enzyme_ad/jax/TransformOps/TransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -2172,6 +2172,11 @@ def TransposeRotate : EnzymeHLOPatternOp< let patterns = ["TransposeRotate"]; } +def ReduceUnusedMultiSlice : EnzymeHLOPatternOp< + "reduce_unused_multislice"> { + let patterns = ["ReduceUnusedMultiSlice"]; +} + def SelectPad : EnzymeHLOPatternOp< "select_pad"> { let patterns = ["SelectPad"]; From 0b8a86d89a74793c34d91f9bb83828cc92b7eb9b Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Wed, 21 Jan 2026 15:48:36 -0600 Subject: [PATCH 2/9] filecheck test --- test/lit_tests/reduce_unused_multi_slice.mlir | 140 ++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 test/lit_tests/reduce_unused_multi_slice.mlir diff --git a/test/lit_tests/reduce_unused_multi_slice.mlir b/test/lit_tests/reduce_unused_multi_slice.mlir new file mode 100644 index 0000000000..3f75c0b474 --- /dev/null +++ b/test/lit_tests/reduce_unused_multi_slice.mlir @@ -0,0 +1,140 @@ +// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reduce_unused_multislice" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s + +// Test 1: Only center result used - should become a regular slice +func.func @multi_slice_only_center(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> { + %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{ + start_indices = array, + limit_indices = array, + strides = array, + dimension = 2 : si32, + left_amount = 2 : si32, + right_amount = 3 : si32 + }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) + return %2 : tensor<1x8x72xf64> +} + +// CHECK-LABEL: func.func @multi_slice_only_center( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> { +// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 3:75] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64> +// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64> +// CHECK: } + + +// Test 2: Only left-most result used - should become a regular slice +func.func @multi_slice_only_left(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> { + %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{ + start_indices = array, + limit_indices = array, + strides = array, + dimension = 2 : si32, + left_amount = 2 : si32, + right_amount = 3 : si32 + }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) + return %0 : tensor<1x8x72xf64> +} + +// CHECK-LABEL: func.func @multi_slice_only_left( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> { +// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 1:73] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64> +// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64> +// CHECK: } + + +// Test 3: Only right-most result used - should become a regular slice +func.func @multi_slice_only_right(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> { + %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{ + start_indices = array, + limit_indices = array, + strides = array, + dimension = 2 : si32, + left_amount = 2 : si32, + right_amount = 3 : si32 + }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) + return %5 : tensor<1x8x72xf64> +} + +// CHECK-LABEL: func.func @multi_slice_only_right( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> { +// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 6:78] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64> +// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64> +// CHECK: } + + +// Test 4: Two consecutive results used - should become smaller multi_slice +func.func @multi_slice_consecutive(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) { + %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{ + start_indices = array, + limit_indices = array, + strides = array, + dimension = 2 : si32, + left_amount = 2 : si32, + right_amount = 3 : si32 + }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) + return %2, %3 : tensor<1x8x72xf64>, tensor<1x8x72xf64> +} + +// CHECK-LABEL: func.func @multi_slice_consecutive( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) { +// CHECK: %[[VAL_1:.*]]:2 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array, right_amount = 1 : si32, start_indices = array, strides = array}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) +// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#1 : tensor<1x8x72xf64>, tensor<1x8x72xf64> +// CHECK: } + + +// Test 5: Non-contiguous results used - should keep range between first and last used +func.func @multi_slice_non_contiguous(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) { + %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{ + start_indices = array, + limit_indices = array, + strides = array, + dimension = 2 : si32, + left_amount = 2 : si32, + right_amount = 3 : si32 + }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) + return %2, %5 : tensor<1x8x72xf64>, tensor<1x8x72xf64> +} + +// CHECK-LABEL: func.func @multi_slice_non_contiguous( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) { +// CHECK: %[[VAL_1:.*]]:4 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array, right_amount = 3 : si32, start_indices = array, strides = array}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) +// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#3 : tensor<1x8x72xf64>, tensor<1x8x72xf64> +// CHECK: } + + +// Test 6: All results used - should not change +func.func @multi_slice_all_used(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) { + %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{ + start_indices = array, + limit_indices = array, + strides = array, + dimension = 2 : si32, + left_amount = 2 : si32, + right_amount = 3 : si32 + }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) + return %0, %1, %2, %3, %4, %5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64> +} + +// CHECK-LABEL: func.func @multi_slice_all_used( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) { +// CHECK: %[[VAL_1:.*]]:6 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 2 : si32, limit_indices = array, right_amount = 3 : si32, start_indices = array, strides = array}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) +// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#1, %[[VAL_1]]#2, %[[VAL_1]]#3, %[[VAL_1]]#4, %[[VAL_1]]#5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64> +// CHECK: } + + +// Test 7: Different dimension - test on dimension 0 +func.func @multi_slice_dim0(%arg0: tensor<20x24x80xf64>) -> tensor<4x24x80xf64> { + %0, %1, %2, %3, %4 = "enzymexla.multi_slice"(%arg0) <{ + start_indices = array, + limit_indices = array, + strides = array, + dimension = 0 : si32, + left_amount = 2 : si32, + right_amount = 2 : si32 + }> : (tensor<20x24x80xf64>) -> (tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>) + return %2 : tensor<4x24x80xf64> +} + +// CHECK-LABEL: func.func @multi_slice_dim0( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<4x24x80xf64> { +// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [8:12, 0:24, 0:80] : (tensor<20x24x80xf64>) -> tensor<4x24x80xf64> +// CHECK: return %[[VAL_1]] : tensor<4x24x80xf64> +// CHECK: } From 43c04a96630d1e977ba6fd278d97fff9cbe4ef1f Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Thu, 22 Jan 2026 04:29:53 -0600 Subject: [PATCH 3/9] immediately erase --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 1d65ccdff1..31fa06dedb 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -30305,8 +30305,10 @@ struct ReduceUnusedMultiSlice final return failure(); // If no results are used, this should be handled by dead code elimination - if (usedCount == 0) + if (usedCount == 0) { + rewriter.eraseOp(op); return failure(); + } // Find the range of used results int firstUsed = -1, lastUsed = -1; From 51e8ca4f696a3a06fbc3e12498acb8475a6b27df Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Thu, 22 Jan 2026 04:30:03 -0600 Subject: [PATCH 4/9] keep sharding --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 31fa06dedb..3ed8466119 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -30346,6 +30346,10 @@ struct ReduceUnusedMultiSlice final rewriter.getDenseI64ArrayAttr(startIndices), rewriter.getDenseI64ArrayAttr(limitIndices), rewriter.getDenseI64ArrayAttr(strides)); + // Propagate sharding if present + if (auto shard = sdy::getShardingPerValue(op)) { + sdy::setShardings(sliceOp, shard); + } rewriter.replaceAllUsesWith(op.getResult(usedIdx), sliceOp.getResult()); rewriter.eraseOp(op); @@ -30375,6 +30379,10 @@ struct ReduceUnusedMultiSlice final auto newOp = rewriter.create( op.getLoc(), resultTypes, op.getOperand(), startIndices, limitIndices, op.getStrides(), op.getDimension(), newLeftAmount, newRightAmount); + // Propagate sharding if present + if (auto shard = sdy::getShardingPerValue(op)) { + sdy::setShardings(newOp, shard); + } // Map old results to new results SmallVector replacements(totalResults); From e59fee1f3af37be7b735a9aad988915ff972f480 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 23 Jan 2026 02:46:50 -0600 Subject: [PATCH 5/9] fixes Co-authored-by: William S. Moses --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index fbc8e31b1e..45f8074b0f 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -30321,7 +30321,7 @@ struct ReduceUnusedMultiSlice final // If no results are used, this should be handled by dead code elimination if (usedCount == 0) { rewriter.eraseOp(op); - return failure(); + return success(); } // Find the range of used results @@ -30355,18 +30355,17 @@ struct ReduceUnusedMultiSlice final limitIndices[dim] += offset; } - auto sliceOp = rewriter.create( - op.getLoc(), op.getOperand(), + auto shard = sdy::getShardingPerValue(op); + + auto sliceOp = rewriter.replaceOpWithNewOp( + op, op.getLoc(), op.getOperand(), rewriter.getDenseI64ArrayAttr(startIndices), rewriter.getDenseI64ArrayAttr(limitIndices), rewriter.getDenseI64ArrayAttr(strides)); // Propagate sharding if present - if (auto shard = sdy::getShardingPerValue(op)) { + if (shard) { sdy::setShardings(sliceOp, shard); } - - rewriter.replaceAllUsesWith(op.getResult(usedIdx), sliceOp.getResult()); - rewriter.eraseOp(op); return success(); } From 589414470063954b94fbd24fbe385f096a8431c3 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 24 Jan 2026 08:07:02 -0600 Subject: [PATCH 6/9] fix --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 45f8074b0f..551c0cab21 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -30358,8 +30358,7 @@ struct ReduceUnusedMultiSlice final auto shard = sdy::getShardingPerValue(op); auto sliceOp = rewriter.replaceOpWithNewOp( - op, op.getLoc(), op.getOperand(), - rewriter.getDenseI64ArrayAttr(startIndices), + op, op.getOperand(), rewriter.getDenseI64ArrayAttr(startIndices), rewriter.getDenseI64ArrayAttr(limitIndices), rewriter.getDenseI64ArrayAttr(strides)); // Propagate sharding if present From 4924e9f5b3350a1ada37719be43dcaa0374c1150 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 24 Jan 2026 09:26:35 -0600 Subject: [PATCH 7/9] revert use of replaceOpWithNewOp --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 551c0cab21..73e1d0ed2d 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -30355,16 +30355,18 @@ struct ReduceUnusedMultiSlice final limitIndices[dim] += offset; } - auto shard = sdy::getShardingPerValue(op); - - auto sliceOp = rewriter.replaceOpWithNewOp( - op, op.getOperand(), rewriter.getDenseI64ArrayAttr(startIndices), + auto sliceOp = rewriter.create( + op.getLoc(), op.getOperand(), + rewriter.getDenseI64ArrayAttr(startIndices), rewriter.getDenseI64ArrayAttr(limitIndices), rewriter.getDenseI64ArrayAttr(strides)); // Propagate sharding if present - if (shard) { + if (auto shard = sdy::getShardingPerValue(op)) { sdy::setShardings(sliceOp, shard); } + + rewriter.replaceAllUsesWith(op.getResult(usedIdx), sliceOp.getResult()); + rewriter.eraseOp(op); return success(); } From a4a2cdfa0e9e74c7c30203591647dbabe49cb8ab Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 24 Jan 2026 12:14:39 -0600 Subject: [PATCH 8/9] remove addMultiSliceOpt --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 5 ----- src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h | 2 -- 2 files changed, 7 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 73e1d0ed2d..9dc2f612dc 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -30819,11 +30819,6 @@ void mlir::transform::addExtendLICM(RewritePatternSet &patterns, patterns.insert>(single_user, &context, benefit); } -void mlir::transform::addMultiSliceOpt(RewritePatternSet &patterns, - MLIRContext &context, - PatternBenefit benefit) { - patterns.insert(&context, benefit); -} void mlir::transform::addMultiSliceLICM(RewritePatternSet &patterns, bool single_user, MLIRContext &context, PatternBenefit benefit) { diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h b/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h index 4234d69e53..11c9a74e08 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h @@ -134,8 +134,6 @@ void addSelfMulToConvolutionLike(RewritePatternSet &patterns, MLIRContext &context, PatternBenefit benefit); void addEnzymeHLOUnroll(RewritePatternSet &patterns, int64_t maxNumIterations, MLIRContext &context, PatternBenefit benefit); -void addMultiSliceOpt(RewritePatternSet &patterns, MLIRContext &context, - PatternBenefit benefit); void addMultiSliceLICM(RewritePatternSet &patterns, bool single_user, MLIRContext &context, PatternBenefit benefit); From 28882b678ed7a696303d0108f58d6c7378059337 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 24 Jan 2026 13:58:48 -0600 Subject: [PATCH 9/9] fix merge mistake --- src/enzyme_ad/jax/TransformOps/TransformOps.td | 1 + 1 file changed, 1 insertion(+) diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td index 573f77366f..76e3919f15 100644 --- a/src/enzyme_ad/jax/TransformOps/TransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -2175,6 +2175,7 @@ def TransposeRotate : EnzymeHLOPatternOp< def ReduceUnusedMultiSlice : EnzymeHLOPatternOp< "reduce_unused_multislice"> { let patterns = ["ReduceUnusedMultiSlice"]; +} def ReduceUnusedMultiRotate : EnzymeHLOPatternOp< "reduce_unused_multirotate"> {