Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30304,6 +30304,133 @@ struct ReduceWindowWrapSimplify final
}
};

// Pattern that shrinks a MultiSliceOp when some of its results are unused.
//
// The rewrite treats result `leftAmount` as the "center" slice described by
// start_indices / limit_indices, and result `i` as that same slice shifted by
// `i - leftAmount` along `dimension`.
//
// Outcomes:
//   * all results used            -> no match;
//   * no result used              -> the op is erased;
//   * exactly one result used     -> replaced by a single stablehlo.slice;
//   * unused results at the edges -> replaced by a narrower MultiSliceOp that
//                                    spans [firstUsed, lastUsed];
//   * unused results only in the interior -> no match (nothing can be trimmed).
struct ReduceUnusedMultiSlice final
    : CheckedOpRewritePattern<enzymexla::MultiSliceOp, ReduceUnusedMultiSlice> {
  using CheckedOpRewritePattern::CheckedOpRewritePattern;

  LogicalResult matchAndRewriteImpl(enzymexla::MultiSliceOp op,
                                    PatternRewriter &rewriter) const {
    int32_t leftAmount = op.getLeftAmount();
    int32_t rightAmount = op.getRightAmount();
    int32_t totalResults = leftAmount + rightAmount + 1;

    // Single scan: count the used results and remember the first/last one.
    int usedCount = 0;
    int firstUsed = -1;
    int lastUsed = -1;
    for (int i = 0; i < totalResults; i++) {
      if (op.getResult(i).use_empty())
        continue;
      if (firstUsed == -1)
        firstUsed = i;
      lastUsed = i;
      usedCount++;
    }

    // If all results are used, nothing to optimize.
    if (usedCount == totalResults)
      return failure();

    // No result is used at all: the op is dead, drop it.
    if (usedCount == 0) {
      rewriter.eraseOp(op);
      return success();
    }

    // Several results are used and both outermost results are among them:
    // only interior results are unused, so the op cannot be narrowed.
    if (usedCount > 1 && firstUsed == 0 && lastUsed == totalResults - 1)
      return failure();

    auto startIndices = SmallVector<int64_t>(op.getStartIndices());
    auto limitIndices = SmallVector<int64_t>(op.getLimitIndices());
    auto strides = SmallVector<int64_t>(op.getStrides());
    int32_t dim = op.getDimension();

    // Without a valid dimension the slice window cannot be shifted correctly;
    // bail out before touching the IR rather than emitting a wrong slice.
    const int64_t rank = static_cast<int64_t>(startIndices.size());
    if (dim < 0 || dim >= rank || limitIndices.size() != startIndices.size())
      return failure();

    // Pick the new center: keep the old one when it lies inside the used
    // range, otherwise move it to the nearest used end. The new left/right
    // amounts are measured from this new center, so they are never negative
    // and the window is shifted exactly once.
    int centerIdx = leftAmount;
    int newCenter = centerIdx;
    if (newCenter < firstUsed)
      newCenter = firstUsed;
    if (newCenter > lastUsed)
      newCenter = lastUsed;
    int newLeftAmount = newCenter - firstUsed;
    int newRightAmount = lastUsed - newCenter;

    int offset = newCenter - centerIdx; // How much to shift the center slice
    startIndices[dim] += offset;
    limitIndices[dim] += offset;

    // Review note carried over from the PR discussion: make sure to keep
    // sharding. The old op carries one sharding per result, so only the
    // entries for the surviving results [firstUsed, lastUsed] are forwarded;
    // attaching the full list to an op with fewer results would not match.
    sdy::TensorShardingPerValueAttr newShardings;
    if (auto shard = sdy::getShardingPerValue(op)) {
      auto allShardings = shard.getShardings();
      if (static_cast<int64_t>(allShardings.size()) == totalResults) {
        newShardings = sdy::TensorShardingPerValueAttr::get(
            op.getContext(),
            allShardings.slice(firstUsed, lastUsed - firstUsed + 1));
      }
    }

    // If only one result is used, replace with a single SliceOp.
    if (usedCount == 1) {
      auto sliceOp = rewriter.create<stablehlo::SliceOp>(
          op.getLoc(), op.getOperand(),
          rewriter.getDenseI64ArrayAttr(startIndices),
          rewriter.getDenseI64ArrayAttr(limitIndices),
          rewriter.getDenseI64ArrayAttr(strides));
      if (newShardings)
        sdy::setShardings(sliceOp, newShardings);

      rewriter.replaceAllUsesWith(op.getResult(firstUsed),
                                  sliceOp.getResult());
      rewriter.eraseOp(op);
      return success();
    }

    // Otherwise, create a smaller MultiSliceOp covering [firstUsed, lastUsed].
    // Every result type is taken from the first result of the original op.
    auto resultType = cast<RankedTensorType>(op.getResultTypes().front());
    SmallVector<Type> resultTypes(
        static_cast<size_t>(newLeftAmount + newRightAmount + 1),
        Type(resultType));

    auto newOp = rewriter.create<enzymexla::MultiSliceOp>(
        op.getLoc(), resultTypes, op.getOperand(), startIndices, limitIndices,
        op.getStrides(), op.getDimension(), newLeftAmount, newRightAmount);
    if (newShardings)
      sdy::setShardings(newOp, newShardings);

    // Old result `firstUsed + k` corresponds to new result `k`. Go through the
    // rewriter so the pattern driver is notified of every use change.
    for (int oldIdx = firstUsed; oldIdx <= lastUsed; oldIdx++) {
      Value oldResult = op.getResult(oldIdx);
      if (!oldResult.use_empty())
        rewriter.replaceAllUsesWith(oldResult,
                                    newOp.getResult(oldIdx - firstUsed));
    }

    rewriter.eraseOp(op);
    return success();
  }
};

struct RecognizeMultiRotate
: public CheckedOpRewritePattern<enzymexla::RotateOp,
RecognizeMultiRotate> {
Expand Down Expand Up @@ -30955,6 +31082,13 @@ void mlir::transform::addExtendLICM(RewritePatternSet &patterns,
patterns.insert<LICM<enzymexla::ExtendOp>>(single_user, &context, benefit);
}

/// Adds the `LICM` pattern instantiated for `enzymexla::MultiSliceOp` to
/// `patterns`, forwarding `single_user`, the context and `benefit` to the
/// pattern's constructor.
void mlir::transform::addMultiSliceLICM(RewritePatternSet &patterns,
                                        bool single_user, MLIRContext &context,
                                        PatternBenefit benefit) {
  using MultiSliceLICM = LICM<enzymexla::MultiSliceOp>;
  MLIRContext *ctx = &context;
  patterns.insert<MultiSliceLICM>(single_user, ctx, benefit);
}

void mlir::transform::addMultiRotateLICM(RewritePatternSet &patterns,
bool single_user, MLIRContext &context,
PatternBenefit benefit) {
Expand Down
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ void addSelfMulToConvolutionLike(RewritePatternSet &patterns,
MLIRContext &context, PatternBenefit benefit);
void addEnzymeHLOUnroll(RewritePatternSet &patterns, int64_t maxNumIterations,
MLIRContext &context, PatternBenefit benefit);
/// Adds the LICM pattern for enzymexla::MultiSliceOp to `patterns`;
/// `single_user`, `context` and `benefit` are passed on to that pattern.
void addMultiSliceLICM(RewritePatternSet &patterns, bool single_user,
MLIRContext &context, PatternBenefit benefit);
void addMultiRotateLICM(RewritePatternSet &patterns, bool single_user,
MLIRContext &context, PatternBenefit benefit);

Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2172,6 +2172,11 @@ def TransposeRotate : EnzymeHLOPatternOp<
let patterns = ["TransposeRotate"];
}

// Registers the C++ rewrite pattern `ReduceUnusedMultiSlice` under the name
// `reduce_unused_multislice` (the name the lit test passes via `patterns=`).
def ReduceUnusedMultiSlice : EnzymeHLOPatternOp<
"reduce_unused_multislice"> {
let patterns = ["ReduceUnusedMultiSlice"];
}

def RecognizeMultiRotate : EnzymeHLOPatternOp<
"recognize_multirotate"> {
let patterns = ["RecognizeMultiRotate"];
Expand Down
140 changes: 140 additions & 0 deletions test/lit_tests/reduce_unused_multi_slice.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reduce_unused_multislice" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s

// Test 1: Only the center result is used (%2, because left_amount = 2).
// The multi_slice should become a regular stablehlo.slice that keeps the
// original start/limit indices unchanged.
func.func @multi_slice_only_center(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 1, 0, 3>,
limit_indices = array<i64: 2, 8, 75>,
strides = array<i64: 1, 1, 1>,
dimension = 2 : si32,
left_amount = 2 : si32,
right_amount = 3 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
return %2 : tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_only_center(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 3:75] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64>
// CHECK: }


// Test 2: Only the left-most result (%0) is used. It sits two positions to
// the left of the center (%2), so the expected regular slice is shifted by -2
// along dimension 2: 3:75 becomes 1:73.
func.func @multi_slice_only_left(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 1, 0, 3>,
limit_indices = array<i64: 2, 8, 75>,
strides = array<i64: 1, 1, 1>,
dimension = 2 : si32,
left_amount = 2 : si32,
right_amount = 3 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
return %0 : tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_only_left(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 1:73] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64>
// CHECK: }


// Test 3: Only the right-most result (%5) is used. It sits three positions
// to the right of the center (%2), so the expected regular slice is shifted
// by +3 along dimension 2: 3:75 becomes 6:78.
func.func @multi_slice_only_right(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 1, 0, 3>,
limit_indices = array<i64: 2, 8, 75>,
strides = array<i64: 1, 1, 1>,
dimension = 2 : si32,
left_amount = 2 : si32,
right_amount = 3 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
return %5 : tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_only_right(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 6:78] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64>
// CHECK: }


// Test 4: Two consecutive results are used (%2 = center and %3 = center + 1).
// The op should shrink to a two-result multi_slice with left_amount = 0 and
// right_amount = 1; the start/limit indices stay the same because the center
// is still part of the used range.
func.func @multi_slice_consecutive(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 1, 0, 3>,
limit_indices = array<i64: 2, 8, 75>,
strides = array<i64: 1, 1, 1>,
dimension = 2 : si32,
left_amount = 2 : si32,
right_amount = 3 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
return %2, %3 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_consecutive(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
// CHECK: %[[VAL_1:.*]]:2 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 1 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>)
// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#1 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
// CHECK: }


// Test 5: Non-contiguous results are used (%2 and %5). The whole range
// between the first and last used result is kept, so the new multi_slice has
// four results (left_amount = 0, right_amount = 3) even though the two
// results in between (%3, %4) are unused.
func.func @multi_slice_non_contiguous(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 1, 0, 3>,
limit_indices = array<i64: 2, 8, 75>,
strides = array<i64: 1, 1, 1>,
dimension = 2 : si32,
left_amount = 2 : si32,
right_amount = 3 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
return %2, %5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_non_contiguous(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
// CHECK: %[[VAL_1:.*]]:4 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 3 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#3 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
// CHECK: }


// Test 6: All six results are used, so the pattern must not fire and the
// multi_slice is expected to be left exactly as written.
func.func @multi_slice_all_used(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 1, 0, 3>,
limit_indices = array<i64: 2, 8, 75>,
strides = array<i64: 1, 1, 1>,
dimension = 2 : si32,
left_amount = 2 : si32,
right_amount = 3 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
return %0, %1, %2, %3, %4, %5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_all_used(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
// CHECK: %[[VAL_1:.*]]:6 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 2 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 3 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#1, %[[VAL_1]]#2, %[[VAL_1]]#3, %[[VAL_1]]#4, %[[VAL_1]]#5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>
// CHECK: }


// Test 7: Same as the center-only case, but slicing along dimension 0 with
// five results (left_amount = 2, right_amount = 2). Only the center (%2) is
// used, so a regular slice with the original bounds 8:12 is expected.
func.func @multi_slice_dim0(%arg0: tensor<20x24x80xf64>) -> tensor<4x24x80xf64> {
%0, %1, %2, %3, %4 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 8, 0, 0>,
limit_indices = array<i64: 12, 24, 80>,
strides = array<i64: 1, 1, 1>,
dimension = 0 : si32,
left_amount = 2 : si32,
right_amount = 2 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>)
return %2 : tensor<4x24x80xf64>
}

// CHECK-LABEL: func.func @multi_slice_dim0(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<4x24x80xf64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [8:12, 0:24, 0:80] : (tensor<20x24x80xf64>) -> tensor<4x24x80xf64>
// CHECK: return %[[VAL_1]] : tensor<4x24x80xf64>
// CHECK: }
Loading