diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index 6674eb6b5c..8d6e08077e 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -30304,4 +30304,109 @@ struct ReduceWindowWrapSimplify final
   }
 };
 
+// Pattern to shrink a MultiSliceOp when some of its results are unused.
+//
+// Result `i` of a multi_slice is the base slice [start, limit) translated by
+// `i - left_amount` along `dimension`, so only a contiguous run of results
+// can be kept. The op is trimmed to the span [firstUsed, lastUsed]:
+//   * no live result       -> the op is erased,
+//   * one result remaining -> a plain stablehlo.slice,
+//   * otherwise            -> a smaller multi_slice based at `firstUsed`.
+// Unused results strictly inside the span are kept; when both end results
+// are live nothing can be trimmed and the pattern does not apply.
+struct ReduceUnusedMultiSlice final
+    : CheckedOpRewritePattern<enzymexla::MultiSliceOp,
+                              ReduceUnusedMultiSlice> {
+  using CheckedOpRewritePattern::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(enzymexla::MultiSliceOp op,
+                                    PatternRewriter &rewriter) const {
+    const int32_t leftAmount = op.getLeftAmount();
+    const int32_t rightAmount = op.getRightAmount();
+    const int32_t totalResults = leftAmount + rightAmount + 1;
+
+    // Find the span of results that still have users.
+    int firstUsed = -1, lastUsed = -1;
+    for (int i = 0; i < totalResults; i++) {
+      if (op.getResult(i).use_empty())
+        continue;
+      if (firstUsed == -1)
+        firstUsed = i;
+      lastUsed = i;
+    }
+
+    // No result is used: drop the op altogether.
+    if (firstUsed == -1) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // Both end results are live, so there is nothing to trim.
+    if (firstUsed == 0 && lastUsed == totalResults - 1)
+      return failure();
+
+    SmallVector<int64_t> startIndices(op.getStartIndices());
+    SmallVector<int64_t> limitIndices(op.getLimitIndices());
+    const int32_t dim = op.getDimension();
+    // A dimension outside the index arrays cannot be shifted; leave such an
+    // op alone instead of emitting an unshifted (wrong) replacement.
+    if (dim < 0 || dim >= (int64_t)startIndices.size() ||
+        dim >= (int64_t)limitIndices.size())
+      return failure();
+
+    // Translate the base slice onto the first used result. The translation
+    // is applied exactly once: the replacement is based at `firstUsed`, so
+    // its left amount is 0 and its right amount is the length of the span.
+    const int64_t shift = (int64_t)firstUsed - leftAmount;
+    startIndices[dim] += shift;
+    limitIndices[dim] += shift;
+    const int32_t newLeftAmount = 0;
+    const int32_t newRightAmount = lastUsed - firstUsed;
+
+    // Keep only the sharding entries of the results that survive so the
+    // attribute matches the new result count. The sharding is dropped when
+    // it does not carry exactly one entry per old result.
+    sdy::TensorShardingPerValueAttr newSharding;
+    if (auto shard = sdy::getShardingPerValue(op)) {
+      auto perResult = shard.getShardings();
+      if ((int64_t)perResult.size() == totalResults)
+        newSharding = sdy::TensorShardingPerValueAttr::get(
+            op.getContext(), perResult.slice(firstUsed, newRightAmount + 1));
+    }
+
+    // A single surviving result is just a regular slice.
+    if (firstUsed == lastUsed) {
+      auto sliceOp = rewriter.create<stablehlo::SliceOp>(
+          op.getLoc(), op.getOperand(),
+          rewriter.getDenseI64ArrayAttr(startIndices),
+          rewriter.getDenseI64ArrayAttr(limitIndices),
+          rewriter.getDenseI64ArrayAttr(op.getStrides()));
+      if (newSharding)
+        sdy::setShardings(sliceOp, newSharding);
+
+      rewriter.replaceAllUsesWith(op.getResult(firstUsed), sliceOp.getResult());
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // Otherwise build the smaller multi_slice. Every new result reuses the
+    // type of the first old result (assumes all results share one type).
+    SmallVector<Type> resultTypes(newRightAmount + 1,
+                                  op.getResultTypes().front());
+    auto newOp = rewriter.create<enzymexla::MultiSliceOp>(
+        op.getLoc(), resultTypes, op.getOperand(), startIndices, limitIndices,
+        op.getStrides(), op.getDimension(), newLeftAmount, newRightAmount);
+    if (newSharding)
+      sdy::setShardings(newOp, newSharding);
+
+    // Old result `i` becomes new result `i - firstUsed`. Go through the
+    // rewriter so the pattern driver is notified of every changed use.
+    for (int i = firstUsed; i <= lastUsed; i++)
+      rewriter.replaceAllUsesWith(op.getResult(i),
+                                  newOp.getResult(i - firstUsed));
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct RecognizeMultiRotate
@@ -30956,5 +31061,12 @@ void mlir::transform::addExtendLICM(RewritePatternSet &patterns,
 }
 
+void mlir::transform::addMultiSliceLICM(RewritePatternSet &patterns,
+                                        bool single_user, MLIRContext &context,
+                                        PatternBenefit benefit) {
+  patterns.insert<LICM<enzymexla::MultiSliceOp>>(single_user, &context,
+                                                 benefit);
+}
+
 void mlir::transform::addMultiRotateLICM(RewritePatternSet &patterns,
                                          bool single_user, MLIRContext &context,
                                          PatternBenefit benefit) {
diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h b/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
index 5ebd178580..68c7aebb4a 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
@@ -134,6 +134,8 @@ void addSelfMulToConvolutionLike(RewritePatternSet &patterns,
                                  MLIRContext &context, PatternBenefit benefit);
 void addEnzymeHLOUnroll(RewritePatternSet &patterns, int64_t maxNumIterations,
                         MLIRContext &context, PatternBenefit benefit);
+void addMultiSliceLICM(RewritePatternSet &patterns, bool single_user,
+                       MLIRContext &context, PatternBenefit benefit);
 void addMultiRotateLICM(RewritePatternSet &patterns, bool single_user,
                         MLIRContext &context, PatternBenefit benefit);
 
diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td
index 090d091294..eba5fb7872 100644
--- a/src/enzyme_ad/jax/TransformOps/TransformOps.td
+++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td
@@ -2172,6 +2172,11 @@ def TransposeRotate : EnzymeHLOPatternOp<
   let patterns = ["TransposeRotate"];
 }
 
+def ReduceUnusedMultiSlice : EnzymeHLOPatternOp<
+    "reduce_unused_multislice"> {
+  let patterns = ["ReduceUnusedMultiSlice"];
+}
+
 def RecognizeMultiRotate : EnzymeHLOPatternOp<
     "recognize_multirotate"> {
   let patterns = ["RecognizeMultiRotate"];
diff --git a/test/lit_tests/reduce_unused_multi_slice.mlir b/test/lit_tests/reduce_unused_multi_slice.mlir
new file mode 100644
index 0000000000..3f75c0b474
--- /dev/null
+++ b/test/lit_tests/reduce_unused_multi_slice.mlir
@@ -0,0 +1,180 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reduce_unused_multislice" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
+
+// Test 1: Only center result used - should become a regular slice
+func.func @multi_slice_only_center(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
+  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
+    start_indices = array<i64: 1, 0, 3>,
+    limit_indices = array<i64: 2, 8, 75>,
+    strides = array<i64: 1, 1, 1>,
+    dimension = 2 : si32,
+    left_amount = 2 : si32,
+    right_amount = 3 : si32
+  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+  return %2 : tensor<1x8x72xf64>
+}
+
+// CHECK-LABEL:   func.func @multi_slice_only_center(
+// CHECK-SAME:      %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
+// CHECK:           %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 3:75] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
+// CHECK:           return %[[VAL_1]] : tensor<1x8x72xf64>
+// CHECK:         }
+
+
+// Test 2: Only left-most result used - should become a regular slice
+func.func @multi_slice_only_left(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
+  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
+    start_indices = array<i64: 1, 0, 3>,
+    limit_indices = array<i64: 2, 8, 75>,
+    strides = array<i64: 1, 1, 1>,
+    dimension = 2 : si32,
+    left_amount = 2 : si32,
+    right_amount = 3 : si32
+  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+  return %0 : tensor<1x8x72xf64>
+}
+
+// CHECK-LABEL:   func.func @multi_slice_only_left(
+// CHECK-SAME:      %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
+// CHECK:           %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 1:73] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
+// CHECK:           return %[[VAL_1]] : tensor<1x8x72xf64>
+// CHECK:         }
+
+
+// Test 3: Only right-most result used - should become a regular slice
+func.func @multi_slice_only_right(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
+  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
+    start_indices = array<i64: 1, 0, 3>,
+    limit_indices = array<i64: 2, 8, 75>,
+    strides = array<i64: 1, 1, 1>,
+    dimension = 2 : si32,
+    left_amount = 2 : si32,
+    right_amount = 3 : si32
+  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+  return %5 : tensor<1x8x72xf64>
+}
+
+// CHECK-LABEL:   func.func @multi_slice_only_right(
+// CHECK-SAME:      %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
+// CHECK:           %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 6:78] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
+// CHECK:           return %[[VAL_1]] : tensor<1x8x72xf64>
+// CHECK:         }
+
+
+// Test 4: Two consecutive results used - should become smaller multi_slice
+func.func @multi_slice_consecutive(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
+  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
+    start_indices = array<i64: 1, 0, 3>,
+    limit_indices = array<i64: 2, 8, 75>,
+    strides = array<i64: 1, 1, 1>,
+    dimension = 2 : si32,
+    left_amount = 2 : si32,
+    right_amount = 3 : si32
+  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+  return %2, %3 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
+}
+
+// CHECK-LABEL:   func.func @multi_slice_consecutive(
+// CHECK-SAME:      %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
+// CHECK:           %[[VAL_1:.*]]:2 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 1 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+// CHECK:           return %[[VAL_1]]#0, %[[VAL_1]]#1 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
+// CHECK:         }
+
+
+// Test 5: Non-contiguous results used - should keep range between first and last used
+func.func @multi_slice_non_contiguous(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
+  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
+    start_indices = array<i64: 1, 0, 3>,
+    limit_indices = array<i64: 2, 8, 75>,
+    strides = array<i64: 1, 1, 1>,
+    dimension = 2 : si32,
+    left_amount = 2 : si32,
+    right_amount = 3 : si32
+  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+  return %2, %5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
+}
+
+// CHECK-LABEL:   func.func @multi_slice_non_contiguous(
+// CHECK-SAME:      %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
+// CHECK:           %[[VAL_1:.*]]:4 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 3 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+// CHECK:           return %[[VAL_1]]#0, %[[VAL_1]]#3 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
+// CHECK:         }
+
+
+// Test 6: All results used - should not change
+func.func @multi_slice_all_used(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
+  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
+    start_indices = array<i64: 1, 0, 3>,
+    limit_indices = array<i64: 2, 8, 75>,
+    strides = array<i64: 1, 1, 1>,
+    dimension = 2 : si32,
+    left_amount = 2 : si32,
+    right_amount = 3 : si32
+  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+  return %0, %1, %2, %3, %4, %5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>
+}
+
+// CHECK-LABEL:   func.func @multi_slice_all_used(
+// CHECK-SAME:      %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
+// CHECK:           %[[VAL_1:.*]]:6 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 2 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 3 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+// CHECK:           return %[[VAL_1]]#0, %[[VAL_1]]#1, %[[VAL_1]]#2, %[[VAL_1]]#3, %[[VAL_1]]#4, %[[VAL_1]]#5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>
+// CHECK:         }
+
+
+// Test 7: Different dimension - test on dimension 0
+func.func @multi_slice_dim0(%arg0: tensor<20x24x80xf64>) -> tensor<4x24x80xf64> {
+  %0, %1, %2, %3, %4 = "enzymexla.multi_slice"(%arg0) <{
+    start_indices = array<i64: 8, 0, 0>,
+    limit_indices = array<i64: 12, 24, 80>,
+    strides = array<i64: 1, 1, 1>,
+    dimension = 0 : si32,
+    left_amount = 2 : si32,
+    right_amount = 2 : si32
+  }> : (tensor<20x24x80xf64>) -> (tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>)
+  return %2 : tensor<4x24x80xf64>
+}
+
+// CHECK-LABEL:   func.func @multi_slice_dim0(
+// CHECK-SAME:      %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<4x24x80xf64> {
+// CHECK:           %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [8:12, 0:24, 0:80] : (tensor<20x24x80xf64>) -> tensor<4x24x80xf64>
+// CHECK:           return %[[VAL_1]] : tensor<4x24x80xf64>
+// CHECK:         }
+
+
+// Test 8: Used range lies entirely left of the center - the new multi_slice must be shifted exactly once
+func.func @multi_slice_left_of_center(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
+  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
+    start_indices = array<i64: 1, 0, 3>,
+    limit_indices = array<i64: 2, 8, 75>,
+    strides = array<i64: 1, 1, 1>,
+    dimension = 2 : si32,
+    left_amount = 2 : si32,
+    right_amount = 3 : si32
+  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+  return %0, %1 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
+}
+
+// CHECK-LABEL:   func.func @multi_slice_left_of_center(
+// CHECK-SAME:      %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
+// CHECK:           %[[VAL_1:.*]]:2 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array<i64: 2, 8, 73>, right_amount = 1 : si32, start_indices = array<i64: 1, 0, 1>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+// CHECK:           return %[[VAL_1]]#0, %[[VAL_1]]#1 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
+// CHECK:         }
+
+
+// Test 9: Used range starts left of the center and ends right of it - results must keep their original offsets
+func.func @multi_slice_straddle_center(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
+  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
+    start_indices = array<i64: 1, 0, 3>,
+    limit_indices = array<i64: 2, 8, 75>,
+    strides = array<i64: 1, 1, 1>,
+    dimension = 2 : si32,
+    left_amount = 2 : si32,
+    right_amount = 3 : si32
+  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+  return %1, %3 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
+}
+
+// CHECK-LABEL:   func.func @multi_slice_straddle_center(
+// CHECK-SAME:      %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
+// CHECK:           %[[VAL_1:.*]]:3 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array<i64: 2, 8, 74>, right_amount = 2 : si32, start_indices = array<i64: 1, 0, 2>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
+// CHECK:           return %[[VAL_1]]#0, %[[VAL_1]]#2 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
+// CHECK:         }