Skip to content

Commit bdc475a

Browse files
jumerckxwsmoses
andauthored
Add multislice simplification (#1958)
* ReduceUnusedMultiSlice * filecheck test * immediately erase * keep sharding * fixes Co-authored-by: William S. Moses <gh@wsmoses.com> * fix * revert use of replaceOpWithNewOp * remove addMultiSliceOpt * fix merge mistake --------- Co-authored-by: William Moses <wmoses@google.com> Co-authored-by: William S. Moses <gh@wsmoses.com>
1 parent bb071ab commit bdc475a

File tree

4 files changed

+281
-0
lines changed

4 files changed

+281
-0
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30304,6 +30304,133 @@ struct ReduceWindowWrapSimplify final
3030430304
}
3030530305
};
3030630306

30307+
// Pattern to reduce MultiSliceOp when some results are unused
30308+
struct ReduceUnusedMultiSlice final
30309+
: CheckedOpRewritePattern<enzymexla::MultiSliceOp, ReduceUnusedMultiSlice> {
30310+
using CheckedOpRewritePattern::CheckedOpRewritePattern;
30311+
30312+
LogicalResult matchAndRewriteImpl(enzymexla::MultiSliceOp op,
30313+
PatternRewriter &rewriter) const {
30314+
int32_t leftAmount = op.getLeftAmount();
30315+
int32_t rightAmount = op.getRightAmount();
30316+
int32_t totalResults = leftAmount + rightAmount + 1;
30317+
30318+
// Check which results are actually used
30319+
SmallVector<bool> used(totalResults, false);
30320+
int usedCount = 0;
30321+
for (int i = 0; i < totalResults; i++) {
30322+
if (!op.getResult(i).use_empty()) {
30323+
used[i] = true;
30324+
usedCount++;
30325+
}
30326+
}
30327+
30328+
// If all results are used, nothing to optimize
30329+
if (usedCount == totalResults)
30330+
return failure();
30331+
30332+
// If no results are used, this should be handled by dead code elimination
30333+
if (usedCount == 0) {
30334+
rewriter.eraseOp(op);
30335+
return success();
30336+
}
30337+
30338+
// Find the range of used results
30339+
int firstUsed = -1, lastUsed = -1;
30340+
for (int i = 0; i < totalResults; i++) {
30341+
if (used[i]) {
30342+
if (firstUsed == -1)
30343+
firstUsed = i;
30344+
lastUsed = i;
30345+
}
30346+
}
30347+
30348+
// Calculate new left and right amounts
30349+
int centerIdx = leftAmount;
30350+
int newLeftAmount = centerIdx - firstUsed;
30351+
int newRightAmount = lastUsed - centerIdx;
30352+
30353+
// If only one result is used, replace with a single SliceOp
30354+
if (usedCount == 1) {
30355+
int usedIdx = firstUsed;
30356+
int offset = usedIdx - centerIdx; // How much to shift the slice
30357+
30358+
auto startIndices = SmallVector<int64_t>(op.getStartIndices());
30359+
auto limitIndices = SmallVector<int64_t>(op.getLimitIndices());
30360+
auto strides = SmallVector<int64_t>(op.getStrides());
30361+
int32_t dim = op.getDimension();
30362+
30363+
// Adjust start and limit indices for the offset
30364+
if (dim >= 0 && dim < (int64_t)startIndices.size()) {
30365+
startIndices[dim] += offset;
30366+
limitIndices[dim] += offset;
30367+
}
30368+
30369+
auto sliceOp = rewriter.create<stablehlo::SliceOp>(
30370+
op.getLoc(), op.getOperand(),
30371+
rewriter.getDenseI64ArrayAttr(startIndices),
30372+
rewriter.getDenseI64ArrayAttr(limitIndices),
30373+
rewriter.getDenseI64ArrayAttr(strides));
30374+
// Propagate sharding if present
30375+
if (auto shard = sdy::getShardingPerValue(op)) {
30376+
sdy::setShardings(sliceOp, shard);
30377+
}
30378+
30379+
rewriter.replaceAllUsesWith(op.getResult(usedIdx), sliceOp.getResult());
30380+
rewriter.eraseOp(op);
30381+
return success();
30382+
}
30383+
30384+
// Otherwise, create a smaller MultiSliceOp
30385+
if (newLeftAmount != leftAmount || newRightAmount != rightAmount) {
30386+
// Adjust start indices for the new center
30387+
int offset = firstUsed - centerIdx;
30388+
auto startIndices = SmallVector<int64_t>(op.getStartIndices());
30389+
auto limitIndices = SmallVector<int64_t>(op.getLimitIndices());
30390+
int32_t dim = op.getDimension();
30391+
30392+
if (dim >= 0 && dim < (int64_t)startIndices.size()) {
30393+
startIndices[dim] += offset;
30394+
limitIndices[dim] += offset;
30395+
}
30396+
30397+
// Determine result types
30398+
auto resultType = cast<RankedTensorType>(op.getResultTypes().front());
30399+
SmallVector<Type> resultTypes;
30400+
for (int i = 0; i < newLeftAmount + newRightAmount + 1; i++) {
30401+
resultTypes.push_back(resultType); // Will be properly typed by the op
30402+
}
30403+
30404+
auto newOp = rewriter.create<enzymexla::MultiSliceOp>(
30405+
op.getLoc(), resultTypes, op.getOperand(), startIndices, limitIndices,
30406+
op.getStrides(), op.getDimension(), newLeftAmount, newRightAmount);
30407+
// Propagate sharding if present
30408+
if (auto shard = sdy::getShardingPerValue(op)) {
30409+
sdy::setShardings(newOp, shard);
30410+
}
30411+
30412+
// Map old results to new results
30413+
SmallVector<Value> replacements(totalResults);
30414+
int newIdx = 0;
30415+
for (int oldIdx = firstUsed; oldIdx <= lastUsed; oldIdx++) {
30416+
replacements[oldIdx] = newOp.getResult(newIdx++);
30417+
}
30418+
30419+
// Replace uses
30420+
for (int i = 0; i < totalResults; i++) {
30421+
if (used[i]) {
30422+
op.getResult(i).replaceAllUsesWith(replacements[i]);
30423+
}
30424+
}
30425+
30426+
rewriter.eraseOp(op);
30427+
return success();
30428+
}
30429+
30430+
return failure();
30431+
}
30432+
};
30433+
3030730434
struct RecognizeMultiRotate
3030830435
: public CheckedOpRewritePattern<enzymexla::RotateOp,
3030930436
RecognizeMultiRotate> {
@@ -30955,6 +31082,13 @@ void mlir::transform::addExtendLICM(RewritePatternSet &patterns,
3095531082
patterns.insert<LICM<enzymexla::ExtendOp>>(single_user, &context, benefit);
3095631083
}
3095731084

31085+
void mlir::transform::addMultiSliceLICM(RewritePatternSet &patterns,
31086+
bool single_user, MLIRContext &context,
31087+
PatternBenefit benefit) {
31088+
patterns.insert<LICM<enzymexla::MultiSliceOp>>(single_user, &context,
31089+
benefit);
31090+
}
31091+
3095831092
void mlir::transform::addMultiRotateLICM(RewritePatternSet &patterns,
3095931093
bool single_user, MLIRContext &context,
3096031094
PatternBenefit benefit) {

src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ void addSelfMulToConvolutionLike(RewritePatternSet &patterns,
134134
MLIRContext &context, PatternBenefit benefit);
135135
void addEnzymeHLOUnroll(RewritePatternSet &patterns, int64_t maxNumIterations,
136136
MLIRContext &context, PatternBenefit benefit);
137+
void addMultiSliceLICM(RewritePatternSet &patterns, bool single_user,
138+
MLIRContext &context, PatternBenefit benefit);
137139
void addMultiRotateLICM(RewritePatternSet &patterns, bool single_user,
138140
MLIRContext &context, PatternBenefit benefit);
139141

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2172,6 +2172,11 @@ def TransposeRotate : EnzymeHLOPatternOp<
21722172
let patterns = ["TransposeRotate"];
21732173
}
21742174

2175+
def ReduceUnusedMultiSlice : EnzymeHLOPatternOp<
2176+
"reduce_unused_multislice"> {
2177+
let patterns = ["ReduceUnusedMultiSlice"];
2178+
}
2179+
21752180
def RecognizeMultiRotate : EnzymeHLOPatternOp<
21762181
"recognize_multirotate"> {
21772182
let patterns = ["RecognizeMultiRotate"];
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reduce_unused_multislice" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
2+
3+
// Test 1: Only center result used - should become a regular slice
4+
func.func @multi_slice_only_center(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
5+
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
6+
start_indices = array<i64: 1, 0, 3>,
7+
limit_indices = array<i64: 2, 8, 75>,
8+
strides = array<i64: 1, 1, 1>,
9+
dimension = 2 : si32,
10+
left_amount = 2 : si32,
11+
right_amount = 3 : si32
12+
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
13+
return %2 : tensor<1x8x72xf64>
14+
}
15+
16+
// CHECK-LABEL: func.func @multi_slice_only_center(
17+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
18+
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 3:75] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
19+
// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64>
20+
// CHECK: }
21+
22+
23+
// Test 2: Only left-most result used - should become a regular slice
24+
func.func @multi_slice_only_left(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
25+
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
26+
start_indices = array<i64: 1, 0, 3>,
27+
limit_indices = array<i64: 2, 8, 75>,
28+
strides = array<i64: 1, 1, 1>,
29+
dimension = 2 : si32,
30+
left_amount = 2 : si32,
31+
right_amount = 3 : si32
32+
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
33+
return %0 : tensor<1x8x72xf64>
34+
}
35+
36+
// CHECK-LABEL: func.func @multi_slice_only_left(
37+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
38+
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 1:73] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
39+
// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64>
40+
// CHECK: }
41+
42+
43+
// Test 3: Only right-most result used - should become a regular slice
44+
func.func @multi_slice_only_right(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
45+
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
46+
start_indices = array<i64: 1, 0, 3>,
47+
limit_indices = array<i64: 2, 8, 75>,
48+
strides = array<i64: 1, 1, 1>,
49+
dimension = 2 : si32,
50+
left_amount = 2 : si32,
51+
right_amount = 3 : si32
52+
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
53+
return %5 : tensor<1x8x72xf64>
54+
}
55+
56+
// CHECK-LABEL: func.func @multi_slice_only_right(
57+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
58+
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 6:78] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
59+
// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64>
60+
// CHECK: }
61+
62+
63+
// Test 4: Two consecutive results used - should become smaller multi_slice
64+
func.func @multi_slice_consecutive(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
65+
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
66+
start_indices = array<i64: 1, 0, 3>,
67+
limit_indices = array<i64: 2, 8, 75>,
68+
strides = array<i64: 1, 1, 1>,
69+
dimension = 2 : si32,
70+
left_amount = 2 : si32,
71+
right_amount = 3 : si32
72+
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
73+
return %2, %3 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
74+
}
75+
76+
// CHECK-LABEL: func.func @multi_slice_consecutive(
77+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
78+
// CHECK: %[[VAL_1:.*]]:2 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 1 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>)
79+
// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#1 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
80+
// CHECK: }
81+
82+
83+
// Test 5: Non-contiguous results used - should keep range between first and last used
84+
func.func @multi_slice_non_contiguous(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
85+
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
86+
start_indices = array<i64: 1, 0, 3>,
87+
limit_indices = array<i64: 2, 8, 75>,
88+
strides = array<i64: 1, 1, 1>,
89+
dimension = 2 : si32,
90+
left_amount = 2 : si32,
91+
right_amount = 3 : si32
92+
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
93+
return %2, %5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
94+
}
95+
96+
// CHECK-LABEL: func.func @multi_slice_non_contiguous(
97+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
98+
// CHECK: %[[VAL_1:.*]]:4 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 3 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
99+
// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#3 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
100+
// CHECK: }
101+
102+
103+
// Test 6: All results used - should not change
104+
func.func @multi_slice_all_used(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
105+
%0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
106+
start_indices = array<i64: 1, 0, 3>,
107+
limit_indices = array<i64: 2, 8, 75>,
108+
strides = array<i64: 1, 1, 1>,
109+
dimension = 2 : si32,
110+
left_amount = 2 : si32,
111+
right_amount = 3 : si32
112+
}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
113+
return %0, %1, %2, %3, %4, %5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>
114+
}
115+
116+
// CHECK-LABEL: func.func @multi_slice_all_used(
117+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
118+
// CHECK: %[[VAL_1:.*]]:6 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 2 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 3 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
119+
// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#1, %[[VAL_1]]#2, %[[VAL_1]]#3, %[[VAL_1]]#4, %[[VAL_1]]#5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>
120+
// CHECK: }
121+
122+
123+
// Test 7: Different dimension - test on dimension 0
124+
func.func @multi_slice_dim0(%arg0: tensor<20x24x80xf64>) -> tensor<4x24x80xf64> {
125+
%0, %1, %2, %3, %4 = "enzymexla.multi_slice"(%arg0) <{
126+
start_indices = array<i64: 8, 0, 0>,
127+
limit_indices = array<i64: 12, 24, 80>,
128+
strides = array<i64: 1, 1, 1>,
129+
dimension = 0 : si32,
130+
left_amount = 2 : si32,
131+
right_amount = 2 : si32
132+
}> : (tensor<20x24x80xf64>) -> (tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>)
133+
return %2 : tensor<4x24x80xf64>
134+
}
135+
136+
// CHECK-LABEL: func.func @multi_slice_dim0(
137+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<4x24x80xf64> {
138+
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [8:12, 0:24, 0:80] : (tensor<20x24x80xf64>) -> tensor<4x24x80xf64>
139+
// CHECK: return %[[VAL_1]] : tensor<4x24x80xf64>
140+
// CHECK: }

0 commit comments

Comments
 (0)