Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30293,6 +30293,133 @@ struct ReduceWindowWrapSimplify final
}
};

// Pattern to reduce MultiSliceOp when some results are unused
// Shrink or eliminate an enzymexla::MultiSliceOp whose results are not all
// used.
//
// MultiSliceOp yields `left_amount + right_amount + 1` slices: the "center"
// slice described by start/limit/strides plus copies shifted by
// -left_amount .. +right_amount along `dimension`. Depending on which results
// have uses we:
//   * erase the op entirely (no uses),
//   * replace it with a single stablehlo.slice (exactly one use), or
//   * emit a narrower MultiSliceOp covering just [firstUsed, lastUsed].
struct ReduceUnusedMultiSlice final
    : CheckedOpRewritePattern<enzymexla::MultiSliceOp, ReduceUnusedMultiSlice> {
  using CheckedOpRewritePattern::CheckedOpRewritePattern;

  LogicalResult matchAndRewriteImpl(enzymexla::MultiSliceOp op,
                                    PatternRewriter &rewriter) const {
    int32_t leftAmount = op.getLeftAmount();
    int32_t rightAmount = op.getRightAmount();
    int32_t totalResults = leftAmount + rightAmount + 1;

    // Record which results actually have uses.
    SmallVector<bool> used(totalResults, false);
    int usedCount = 0;
    for (int i = 0; i < totalResults; i++) {
      if (!op.getResult(i).use_empty()) {
        used[i] = true;
        usedCount++;
      }
    }

    // All results used: nothing to shrink.
    if (usedCount == totalResults)
      return failure();

    // No results used: the op is dead; delete it here rather than waiting for
    // DCE.
    if (usedCount == 0) {
      rewriter.eraseOp(op);
      return success();
    }

    // Range [firstUsed, lastUsed] of results that must be preserved.
    int firstUsed = -1, lastUsed = -1;
    for (int i = 0; i < totalResults; i++) {
      if (used[i]) {
        if (firstUsed == -1)
          firstUsed = i;
        lastUsed = i;
      }
    }

    // Result index of the center slice, i.e. the slice at offset 0 relative
    // to start_indices/limit_indices.
    int centerIdx = leftAmount;

    int32_t dim = op.getDimension();
    auto startIndices = SmallVector<int64_t>(op.getStartIndices());
    auto limitIndices = SmallVector<int64_t>(op.getLimitIndices());
    // A valid sliced dimension is required to shift the window; if it is out
    // of range the op is malformed, so leave it for the verifier.
    if (dim < 0 || dim >= (int32_t)startIndices.size())
      return failure();

    // Exactly one result used: a plain slice at the right offset suffices.
    if (usedCount == 1) {
      // How far the surviving slice is shifted from the center slice.
      int offset = firstUsed - centerIdx;
      startIndices[dim] += offset;
      limitIndices[dim] += offset;

      auto sliceOp = rewriter.create<stablehlo::SliceOp>(
          op.getLoc(), op.getOperand(),
          rewriter.getDenseI64ArrayAttr(startIndices),
          rewriter.getDenseI64ArrayAttr(limitIndices),
          rewriter.getDenseI64ArrayAttr(
              SmallVector<int64_t>(op.getStrides())));
      // Propagate sharding if present.
      if (auto shard = sdy::getShardingPerValue(op)) {
        sdy::setShardings(sliceOp, shard);
      }

      rewriter.replaceAllUsesWith(op.getResult(firstUsed),
                                  sliceOp.getResult());
      rewriter.eraseOp(op);
      return success();
    }

    // Multiple results used. If the used range spans all results (only
    // interior results are dead) there is nothing to shrink.
    int newTotal = lastUsed - firstUsed + 1;
    if (newTotal == totalResults)
      return failure();

    // Re-anchor the center on the first used slice so the new amounts are
    // guaranteed non-negative: left_amount = 0, right_amount = newTotal - 1.
    // (Keeping the old center while also shifting start_indices would
    // double-count the offset and can produce negative amounts.)
    int offset = firstUsed - centerIdx;
    startIndices[dim] += offset;
    limitIndices[dim] += offset;

    // Every result of a MultiSliceOp has the same type.
    SmallVector<Type> resultTypes(newTotal, op.getResult(firstUsed).getType());
    auto newOp = rewriter.create<enzymexla::MultiSliceOp>(
        op.getLoc(), resultTypes, op.getOperand(), startIndices, limitIndices,
        op.getStrides(), op.getDimension(), /*leftAmount=*/0,
        /*rightAmount=*/newTotal - 1);
    // Propagate sharding if present.
    if (auto shard = sdy::getShardingPerValue(op)) {
      sdy::setShardings(newOp, shard);
    }

    // Rewire surviving results through the rewriter so the driver tracks the
    // replacements.
    for (int i = firstUsed; i <= lastUsed; i++) {
      if (used[i])
        rewriter.replaceAllUsesWith(op.getResult(i),
                                    newOp.getResult(i - firstUsed));
    }
    rewriter.eraseOp(op);
    return success();
  }
};

struct ScatterOpCanon final
: CheckedOpRewritePattern<stablehlo::ScatterOp, ScatterOpCanon> {
using CheckedOpRewritePattern::CheckedOpRewritePattern;
Expand Down Expand Up @@ -30692,6 +30819,13 @@ void mlir::transform::addExtendLICM(RewritePatternSet &patterns,
patterns.insert<LICM<enzymexla::ExtendOp>>(single_user, &context, benefit);
}

void mlir::transform::addMultiSliceLICM(RewritePatternSet &patterns,
bool single_user, MLIRContext &context,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here on do we need this defn?

PatternBenefit benefit) {
patterns.insert<LICM<enzymexla::MultiSliceOp>>(single_user, &context,
benefit);
}

void mlir::transform::addElementwiseLICM(RewritePatternSet &patterns,
bool single_user, MLIRContext &context,
PatternBenefit benefit) {
Expand Down
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,7 @@ void addSelfMulToConvolutionLike(RewritePatternSet &patterns,
MLIRContext &context, PatternBenefit benefit);
void addEnzymeHLOUnroll(RewritePatternSet &patterns, int64_t maxNumIterations,
MLIRContext &context, PatternBenefit benefit);
void addMultiSliceLICM(RewritePatternSet &patterns, bool single_user,
MLIRContext &context, PatternBenefit benefit);

} // namespace mlir::transform
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2172,6 +2172,11 @@ def TransposeRotate : EnzymeHLOPatternOp<
let patterns = ["TransposeRotate"];
}

// Exposes the ReduceUnusedMultiSlice rewrite (drop/shrink multi_slice ops
// with unused results) under the transform name "reduce_unused_multislice".
def ReduceUnusedMultiSlice : EnzymeHLOPatternOp<
"reduce_unused_multislice"> {
let patterns = ["ReduceUnusedMultiSlice"];
}

def SelectPad : EnzymeHLOPatternOp<
"select_pad"> {
let patterns = ["SelectPad"];
Expand Down
140 changes: 140 additions & 0 deletions test/lit_tests/reduce_unused_multi_slice.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reduce_unused_multislice" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s

// Test 1: Only center result used - should become a regular slice
func.func @multi_slice_only_center(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
// Six results (left_amount=2 + center + right_amount=3); only the center
// result %2 escapes, so the op should fold to one stablehlo.slice with
// unshifted indices.
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 1, 0, 3>,
limit_indices = array<i64: 2, 8, 75>,
strides = array<i64: 1, 1, 1>,
dimension = 2 : si32,
left_amount = 2 : si32,
right_amount = 3 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
return %2 : tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_only_center(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 3:75] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64>
// CHECK: }


// Test 2: Only left-most result used - should become a regular slice
func.func @multi_slice_only_left(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
// Only the left-most result %0 (offset -2 from the center) is used, so the
// op should fold to one stablehlo.slice shifted by -2 along dim 2
// (start 3 -> 1, limit 75 -> 73).
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 1, 0, 3>,
limit_indices = array<i64: 2, 8, 75>,
strides = array<i64: 1, 1, 1>,
dimension = 2 : si32,
left_amount = 2 : si32,
right_amount = 3 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
return %0 : tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_only_left(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 1:73] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64>
// CHECK: }


// Test 3: Only right-most result used - should become a regular slice
func.func @multi_slice_only_right(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
// Only the right-most result %5 (offset +3 from the center) is used, so the
// op should fold to one stablehlo.slice shifted by +3 along dim 2
// (start 3 -> 6, limit 75 -> 78).
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 1, 0, 3>,
limit_indices = array<i64: 2, 8, 75>,
strides = array<i64: 1, 1, 1>,
dimension = 2 : si32,
left_amount = 2 : si32,
right_amount = 3 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
return %5 : tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_only_right(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 6:78] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64>
// CHECK: }


// Test 4: Two consecutive results used - should become smaller multi_slice
func.func @multi_slice_consecutive(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
// Center %2 and its right neighbor %3 are used; expect a shrunken
// multi_slice with left_amount=0 and right_amount=1 producing two results.
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 1, 0, 3>,
limit_indices = array<i64: 2, 8, 75>,
strides = array<i64: 1, 1, 1>,
dimension = 2 : si32,
left_amount = 2 : si32,
right_amount = 3 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
return %2, %3 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_consecutive(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
// CHECK: %[[VAL_1:.*]]:2 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 1 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>)
// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#1 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
// CHECK: }


// Test 5: Non-contiguous results used - should keep range between first and last used
func.func @multi_slice_non_contiguous(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
// %2 and %5 are used with dead results in between; the pattern keeps the
// whole [first,last] used range, yielding a 4-result multi_slice
// (left_amount=0, right_amount=3) whose interior results stay dead.
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 1, 0, 3>,
limit_indices = array<i64: 2, 8, 75>,
strides = array<i64: 1, 1, 1>,
dimension = 2 : si32,
left_amount = 2 : si32,
right_amount = 3 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
return %2, %5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_non_contiguous(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
// CHECK: %[[VAL_1:.*]]:4 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 3 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#3 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
// CHECK: }


// Test 6: All results used - should not change
func.func @multi_slice_all_used(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
// Every result is returned, so the pattern must not fire and the op is
// expected to survive unchanged.
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 1, 0, 3>,
limit_indices = array<i64: 2, 8, 75>,
strides = array<i64: 1, 1, 1>,
dimension = 2 : si32,
left_amount = 2 : si32,
right_amount = 3 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
return %0, %1, %2, %3, %4, %5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_all_used(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
// CHECK: %[[VAL_1:.*]]:6 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 2 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 3 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#1, %[[VAL_1]]#2, %[[VAL_1]]#3, %[[VAL_1]]#4, %[[VAL_1]]#5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>
// CHECK: }


// Test 7: Different dimension - test on dimension 0
func.func @multi_slice_dim0(%arg0: tensor<20x24x80xf64>) -> tensor<4x24x80xf64> {
// Same single-use-center case as test 1 but slicing along dimension 0;
// expect a plain stablehlo.slice with the original [8:12] window.
%0, %1, %2, %3, %4 = "enzymexla.multi_slice"(%arg0) <{
start_indices = array<i64: 8, 0, 0>,
limit_indices = array<i64: 12, 24, 80>,
strides = array<i64: 1, 1, 1>,
dimension = 0 : si32,
left_amount = 2 : si32,
right_amount = 2 : si32
}> : (tensor<20x24x80xf64>) -> (tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>)
return %2 : tensor<4x24x80xf64>
}

// CHECK-LABEL: func.func @multi_slice_dim0(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<4x24x80xf64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [8:12, 0:24, 0:80] : (tensor<20x24x80xf64>) -> tensor<4x24x80xf64>
// CHECK: return %[[VAL_1]] : tensor<4x24x80xf64>
// CHECK: }