-
Notifications
You must be signed in to change notification settings - Fork 29
Add multislice simplification #1958
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
0a06aab
0b8a86d
43c04a9
51e8ca4
b2c180e
e59fee1
5894144
4924e9f
a4a2cdf
d99ab52
28882b6
0d0f264
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30293,6 +30293,133 @@ struct ReduceWindowWrapSimplify final | |
| } | ||
| }; | ||
|
|
||
| // Pattern to reduce MultiSliceOp when some results are unused | ||
| struct ReduceUnusedMultiSlice final | ||
| : CheckedOpRewritePattern<enzymexla::MultiSliceOp, ReduceUnusedMultiSlice> { | ||
| using CheckedOpRewritePattern::CheckedOpRewritePattern; | ||
|
|
||
| LogicalResult matchAndRewriteImpl(enzymexla::MultiSliceOp op, | ||
| PatternRewriter &rewriter) const { | ||
| int32_t leftAmount = op.getLeftAmount(); | ||
| int32_t rightAmount = op.getRightAmount(); | ||
| int32_t totalResults = leftAmount + rightAmount + 1; | ||
|
|
||
| // Check which results are actually used | ||
| SmallVector<bool> used(totalResults, false); | ||
| int usedCount = 0; | ||
| for (int i = 0; i < totalResults; i++) { | ||
| if (!op.getResult(i).use_empty()) { | ||
| used[i] = true; | ||
| usedCount++; | ||
| } | ||
| } | ||
|
|
||
| // If all results are used, nothing to optimize | ||
| if (usedCount == totalResults) | ||
| return failure(); | ||
|
|
||
| // If no results are used, this should be handled by dead code elimination | ||
| if (usedCount == 0) { | ||
| rewriter.eraseOp(op); | ||
| return success(); | ||
| } | ||
|
|
||
| // Find the range of used results | ||
| int firstUsed = -1, lastUsed = -1; | ||
| for (int i = 0; i < totalResults; i++) { | ||
| if (used[i]) { | ||
| if (firstUsed == -1) | ||
| firstUsed = i; | ||
| lastUsed = i; | ||
| } | ||
| } | ||
|
|
||
| // Calculate new left and right amounts | ||
| int centerIdx = leftAmount; | ||
| int newLeftAmount = centerIdx - firstUsed; | ||
| int newRightAmount = lastUsed - centerIdx; | ||
|
|
||
| // If only one result is used, replace with a single SliceOp | ||
| if (usedCount == 1) { | ||
| int usedIdx = firstUsed; | ||
| int offset = usedIdx - centerIdx; // How much to shift the slice | ||
|
|
||
| auto startIndices = SmallVector<int64_t>(op.getStartIndices()); | ||
| auto limitIndices = SmallVector<int64_t>(op.getLimitIndices()); | ||
| auto strides = SmallVector<int64_t>(op.getStrides()); | ||
| int32_t dim = op.getDimension(); | ||
|
|
||
| // Adjust start and limit indices for the offset | ||
| if (dim >= 0 && dim < (int64_t)startIndices.size()) { | ||
| startIndices[dim] += offset; | ||
| limitIndices[dim] += offset; | ||
| } | ||
|
|
||
| auto sliceOp = rewriter.create<stablehlo::SliceOp>( | ||
| op.getLoc(), op.getOperand(), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same replaceOpWithNewOp comment [with the sharding comment] from rotate |
||
| rewriter.getDenseI64ArrayAttr(startIndices), | ||
| rewriter.getDenseI64ArrayAttr(limitIndices), | ||
| rewriter.getDenseI64ArrayAttr(strides)); | ||
| // Propagate sharding if present | ||
| if (auto shard = sdy::getShardingPerValue(op)) { | ||
| sdy::setShardings(sliceOp, shard); | ||
| } | ||
|
|
||
| rewriter.replaceAllUsesWith(op.getResult(usedIdx), sliceOp.getResult()); | ||
| rewriter.eraseOp(op); | ||
| return success(); | ||
| } | ||
|
|
||
| // Otherwise, create a smaller MultiSliceOp | ||
| if (newLeftAmount != leftAmount || newRightAmount != rightAmount) { | ||
| // Adjust start indices for the new center | ||
| int offset = firstUsed - centerIdx; | ||
| auto startIndices = SmallVector<int64_t>(op.getStartIndices()); | ||
| auto limitIndices = SmallVector<int64_t>(op.getLimitIndices()); | ||
| int32_t dim = op.getDimension(); | ||
|
|
||
| if (dim >= 0 && dim < (int64_t)startIndices.size()) { | ||
| startIndices[dim] += offset; | ||
| limitIndices[dim] += offset; | ||
| } | ||
|
|
||
| // Determine result types | ||
| auto resultType = cast<RankedTensorType>(op.getResultTypes().front()); | ||
| SmallVector<Type> resultTypes; | ||
| for (int i = 0; i < newLeftAmount + newRightAmount + 1; i++) { | ||
| resultTypes.push_back(resultType); // Will be properly typed by the op | ||
| } | ||
|
|
||
| auto newOp = rewriter.create<enzymexla::MultiSliceOp>( | ||
| op.getLoc(), resultTypes, op.getOperand(), startIndices, limitIndices, | ||
| op.getStrides(), op.getDimension(), newLeftAmount, newRightAmount); | ||
| // Propagate sharding if present | ||
| if (auto shard = sdy::getShardingPerValue(op)) { | ||
| sdy::setShardings(newOp, shard); | ||
| } | ||
|
|
||
| // Map old results to new results | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. make sur eto keep sharding |
||
| SmallVector<Value> replacements(totalResults); | ||
| int newIdx = 0; | ||
| for (int oldIdx = firstUsed; oldIdx <= lastUsed; oldIdx++) { | ||
| replacements[oldIdx] = newOp.getResult(newIdx++); | ||
| } | ||
|
|
||
| // Replace uses | ||
| for (int i = 0; i < totalResults; i++) { | ||
| if (used[i]) { | ||
| op.getResult(i).replaceAllUsesWith(replacements[i]); | ||
| } | ||
| } | ||
|
|
||
| rewriter.eraseOp(op); | ||
| return success(); | ||
| } | ||
|
|
||
| return failure(); | ||
| } | ||
| }; | ||
|
|
||
| struct ScatterOpCanon final | ||
| : CheckedOpRewritePattern<stablehlo::ScatterOp, ScatterOpCanon> { | ||
| using CheckedOpRewritePattern::CheckedOpRewritePattern; | ||
|
|
@@ -30692,6 +30819,13 @@ void mlir::transform::addExtendLICM(RewritePatternSet &patterns, | |
| patterns.insert<LICM<enzymexla::ExtendOp>>(single_user, &context, benefit); | ||
| } | ||
|
|
||
| void mlir::transform::addMultiSliceLICM(RewritePatternSet &patterns, | ||
| bool single_user, MLIRContext &context, | ||
|
||
| PatternBenefit benefit) { | ||
| patterns.insert<LICM<enzymexla::MultiSliceOp>>(single_user, &context, | ||
| benefit); | ||
| } | ||
|
|
||
| void mlir::transform::addElementwiseLICM(RewritePatternSet &patterns, | ||
| bool single_user, MLIRContext &context, | ||
| PatternBenefit benefit) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reduce_unused_multislice" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s

// Tests for the reduce_unused_multislice pattern: an enzymexla.multi_slice
// whose results are only partially used should shrink to a smaller
// multi_slice covering the live range, or to a plain stablehlo.slice when a
// single result survives; when every result is used it must stay unchanged.

// Test 1: Only center result used - should become a regular slice
func.func @multi_slice_only_center(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
    start_indices = array<i64: 1, 0, 3>,
    limit_indices = array<i64: 2, 8, 75>,
    strides = array<i64: 1, 1, 1>,
    dimension = 2 : si32,
    left_amount = 2 : si32,
    right_amount = 3 : si32
  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
  return %2 : tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_only_center(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 3:75] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64>
// CHECK: }

// Test 2: Only left-most result used - should become a regular slice
// (left slices sit at lower start indices, so start/limit shift by -2).
func.func @multi_slice_only_left(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
    start_indices = array<i64: 1, 0, 3>,
    limit_indices = array<i64: 2, 8, 75>,
    strides = array<i64: 1, 1, 1>,
    dimension = 2 : si32,
    left_amount = 2 : si32,
    right_amount = 3 : si32
  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
  return %0 : tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_only_left(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 1:73] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64>
// CHECK: }

// Test 3: Only right-most result used - should become a regular slice
// (right slices sit at higher start indices, so start/limit shift by +3).
func.func @multi_slice_only_right(%arg0: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
    start_indices = array<i64: 1, 0, 3>,
    limit_indices = array<i64: 2, 8, 75>,
    strides = array<i64: 1, 1, 1>,
    dimension = 2 : si32,
    left_amount = 2 : si32,
    right_amount = 3 : si32
  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
  return %5 : tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_only_right(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<1x8x72xf64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [1:2, 0:8, 6:78] : (tensor<20x24x80xf64>) -> tensor<1x8x72xf64>
// CHECK: return %[[VAL_1]] : tensor<1x8x72xf64>
// CHECK: }

// Test 4: Two consecutive results used - should become smaller multi_slice
func.func @multi_slice_consecutive(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
    start_indices = array<i64: 1, 0, 3>,
    limit_indices = array<i64: 2, 8, 75>,
    strides = array<i64: 1, 1, 1>,
    dimension = 2 : si32,
    left_amount = 2 : si32,
    right_amount = 3 : si32
  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
  return %2, %3 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_consecutive(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
// CHECK: %[[VAL_1:.*]]:2 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 1 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>)
// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#1 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
// CHECK: }

// Test 5: Non-contiguous results used - should keep range between first and last used
func.func @multi_slice_non_contiguous(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
    start_indices = array<i64: 1, 0, 3>,
    limit_indices = array<i64: 2, 8, 75>,
    strides = array<i64: 1, 1, 1>,
    dimension = 2 : si32,
    left_amount = 2 : si32,
    right_amount = 3 : si32
  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
  return %2, %5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_non_contiguous(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
// CHECK: %[[VAL_1:.*]]:4 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 0 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 3 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#3 : tensor<1x8x72xf64>, tensor<1x8x72xf64>
// CHECK: }

// Test 6: All results used - should not change
func.func @multi_slice_all_used(%arg0: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
  %0, %1, %2, %3, %4, %5 = "enzymexla.multi_slice"(%arg0) <{
    start_indices = array<i64: 1, 0, 3>,
    limit_indices = array<i64: 2, 8, 75>,
    strides = array<i64: 1, 1, 1>,
    dimension = 2 : si32,
    left_amount = 2 : si32,
    right_amount = 3 : si32
  }> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
  return %0, %1, %2, %3, %4, %5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>
}

// CHECK-LABEL: func.func @multi_slice_all_used(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>) {
// CHECK: %[[VAL_1:.*]]:6 = "enzymexla.multi_slice"(%[[VAL_0]]) <{dimension = 2 : si32, left_amount = 2 : si32, limit_indices = array<i64: 2, 8, 75>, right_amount = 3 : si32, start_indices = array<i64: 1, 0, 3>, strides = array<i64: 1, 1, 1>}> : (tensor<20x24x80xf64>) -> (tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>)
// CHECK: return %[[VAL_1]]#0, %[[VAL_1]]#1, %[[VAL_1]]#2, %[[VAL_1]]#3, %[[VAL_1]]#4, %[[VAL_1]]#5 : tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>, tensor<1x8x72xf64>
// CHECK: }

// Test 7: Different dimension - test on dimension 0
func.func @multi_slice_dim0(%arg0: tensor<20x24x80xf64>) -> tensor<4x24x80xf64> {
  %0, %1, %2, %3, %4 = "enzymexla.multi_slice"(%arg0) <{
    start_indices = array<i64: 8, 0, 0>,
    limit_indices = array<i64: 12, 24, 80>,
    strides = array<i64: 1, 1, 1>,
    dimension = 0 : si32,
    left_amount = 2 : si32,
    right_amount = 2 : si32
  }> : (tensor<20x24x80xf64>) -> (tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>, tensor<4x24x80xf64>)
  return %2 : tensor<4x24x80xf64>
}

// CHECK-LABEL: func.func @multi_slice_dim0(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<20x24x80xf64>) -> tensor<4x24x80xf64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[VAL_0]] [8:12, 0:24, 0:80] : (tensor<20x24x80xf64>) -> tensor<4x24x80xf64>
// CHECK: return %[[VAL_1]] : tensor<4x24x80xf64>
// CHECK: }
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also might as well do deletion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same success comment