Skip to content

Commit f7ebe6c

Browse files
authored
Add RecognizeMultiSlice (#2023)
* add RecognizeMultiSlice * update for single amount * register transform pattern * test * try fix sharding handling * also update RecognizeMultiRotate
1 parent 2093e94 commit f7ebe6c

File tree

3 files changed

+532
-3
lines changed

3 files changed

+532
-3
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 253 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31851,6 +31851,231 @@ struct ReduceWindowWrapSimplify final
3185131851
}
3185231852
};
3185331853

31854+
// Pattern to recognize multiple SliceOps on the same input that differ only
31855+
// by offset along one dimension and combine them into a single MultiSliceOp
31856+
struct RecognizeMultiSlice
31857+
: public CheckedOpRewritePattern<stablehlo::SliceOp, RecognizeMultiSlice> {
31858+
using CheckedOpRewritePattern::CheckedOpRewritePattern;
31859+
31860+
LogicalResult matchAndRewriteImpl(stablehlo::SliceOp op,
31861+
PatternRewriter &rewriter) const {
31862+
Value input = op.getOperand();
31863+
auto startIndices = SmallVector<int64_t>(op.getStartIndices());
31864+
auto limitIndices = SmallVector<int64_t>(op.getLimitIndices());
31865+
auto strides = SmallVector<int64_t>(op.getStrides());
31866+
31867+
// Find all SliceOps on the same input
31868+
SmallVector<stablehlo::SliceOp> slices;
31869+
31870+
for (Operation *user : input.getUsers()) {
31871+
if (auto slice = dyn_cast<stablehlo::SliceOp>(user)) {
31872+
if (!slice.use_empty()) {
31873+
slices.push_back(slice);
31874+
}
31875+
}
31876+
}
31877+
31878+
// Need at least 2 slices to consider combining
31879+
if (slices.size() < 2)
31880+
return failure();
31881+
31882+
int64_t rank = startIndices.size();
31883+
31884+
// Try each dimension as the potential varying dimension
31885+
for (int64_t dim = 0; dim < rank; dim++) {
31886+
// Collect slices that match this op except for offset along `dim`
31887+
SmallVector<std::pair<stablehlo::SliceOp, int64_t>> matchingSlices;
31888+
DenseSet<int64_t> offsetSet;
31889+
31890+
for (auto slice : slices) {
31891+
auto sliceStart = SmallVector<int64_t>(slice.getStartIndices());
31892+
auto sliceLimit = SmallVector<int64_t>(slice.getLimitIndices());
31893+
auto sliceStrides = SmallVector<int64_t>(slice.getStrides());
31894+
31895+
// Check if strides match
31896+
if (sliceStrides != strides)
31897+
continue;
31898+
31899+
// Check if all dimensions except `dim` match exactly
31900+
bool matches = true;
31901+
for (int64_t d = 0; d < rank; d++) {
31902+
if (d == dim)
31903+
continue;
31904+
if (sliceStart[d] != startIndices[d] ||
31905+
sliceLimit[d] != limitIndices[d]) {
31906+
matches = false;
31907+
break;
31908+
}
31909+
}
31910+
if (!matches)
31911+
continue;
31912+
31913+
// Check if the slice shape along `dim` matches
31914+
int64_t thisSliceSize = limitIndices[dim] - startIndices[dim];
31915+
int64_t otherSliceSize = sliceLimit[dim] - sliceStart[dim];
31916+
if (thisSliceSize != otherSliceSize)
31917+
continue;
31918+
31919+
// Calculate offset relative to this slice (this slice has offset 0)
31920+
int64_t offset = sliceStart[dim] - startIndices[dim];
31921+
matchingSlices.push_back({slice, offset});
31922+
offsetSet.insert(offset);
31923+
}
31924+
31925+
// Need at least 2 matching slices along this dimension
31926+
if (matchingSlices.size() < 2)
31927+
continue;
31928+
31929+
// Sort offsets to find contiguous groups
31930+
SmallVector<int64_t> sortedOffsets(offsetSet.begin(), offsetSet.end());
31931+
llvm::sort(sortedOffsets);
31932+
31933+
// Find all contiguous groups (no gaps in offsets)
31934+
SmallVector<std::pair<int64_t, int64_t>>
31935+
contiguousGroups; // (start, end) inclusive
31936+
int64_t groupStart = sortedOffsets[0];
31937+
int64_t groupEnd = sortedOffsets[0];
31938+
31939+
for (size_t i = 1; i < sortedOffsets.size(); i++) {
31940+
if (sortedOffsets[i] == groupEnd + 1) {
31941+
// Extend current group
31942+
groupEnd = sortedOffsets[i];
31943+
} else {
31944+
// Save current group and start new one
31945+
contiguousGroups.push_back({groupStart, groupEnd});
31946+
groupStart = sortedOffsets[i];
31947+
groupEnd = sortedOffsets[i];
31948+
}
31949+
}
31950+
contiguousGroups.push_back({groupStart, groupEnd});
31951+
31952+
// Collect all offsets from groups that contain identity (0) or neighbor
31953+
// it
31954+
SmallVector<int64_t> qualifyingOffsets;
31955+
31956+
for (auto &[start, end] : contiguousGroups) {
31957+
bool containsIdentity = (start <= 0 && end >= 0);
31958+
bool neighborsIdentity = (end == -1 || start == 1);
31959+
31960+
if (containsIdentity || neighborsIdentity) {
31961+
for (int64_t o = start; o <= end; o++) {
31962+
if (offsetSet.contains(o)) {
31963+
qualifyingOffsets.push_back(o);
31964+
}
31965+
}
31966+
}
31967+
}
31968+
31969+
// No qualifying groups found
31970+
if (qualifyingOffsets.size() < 2)
31971+
continue;
31972+
31973+
// Determine the multiSlice range from qualifying offsets, extended to
31974+
// include 0
31975+
int64_t rangeStart =
31976+
*std::min_element(qualifyingOffsets.begin(), qualifyingOffsets.end());
31977+
int64_t rangeEnd =
31978+
*std::max_element(qualifyingOffsets.begin(), qualifyingOffsets.end());
31979+
31980+
// Extend to include identity (0) if not already included
31981+
if (rangeStart > 0)
31982+
rangeStart = 0;
31983+
if (rangeEnd < 0)
31984+
rangeEnd = 0;
31985+
31986+
// Check if this op (offset 0) is part of the selected range
31987+
if (0 < rangeStart || 0 > rangeEnd)
31988+
continue;
31989+
31990+
// Find the minimum offset among actual slices in the selected range
31991+
// to use as the canonical trigger (prevents processing same group
31992+
// multiple times)
31993+
int64_t minOffsetInRange = INT64_MAX;
31994+
int slicesToCombine = 0;
31995+
for (auto &[slice, offset] : matchingSlices) {
31996+
if (offset >= rangeStart && offset <= rangeEnd) {
31997+
minOffsetInRange = std::min(minOffsetInRange, offset);
31998+
slicesToCombine++;
31999+
}
32000+
}
32001+
32002+
// Only proceed if this op has the minimum offset in the range
32003+
if (0 != minOffsetInRange)
32004+
return failure();
32005+
32006+
// Need at least 2 slices in the selected range to combine
32007+
if (slicesToCombine < 2)
32008+
continue;
32009+
32010+
// Check that all slices in the range have the same sharding
32011+
std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
32012+
bool shardingMismatch = false;
32013+
for (auto &[slice, offset] : matchingSlices) {
32014+
if (offset >= rangeStart && offset <= rangeEnd) {
32015+
auto shardPerValue = sdy::getShardingPerValue(slice);
32016+
if (!commonSharding.has_value()) {
32017+
commonSharding = shardPerValue;
32018+
} else if (commonSharding.value() != shardPerValue) {
32019+
shardingMismatch = true;
32020+
break;
32021+
}
32022+
}
32023+
}
32024+
if (shardingMismatch)
32025+
return failure();
32026+
32027+
// Calculate the shift amount
32028+
// leftShift covers negative offsets (slices shifted left/earlier)
32029+
int32_t leftShift = rangeStart < 0 ? (int32_t)(-rangeStart) : 0;
32030+
int32_t rightExtent = rangeEnd > 0 ? (int32_t)rangeEnd : 0;
32031+
int32_t amount = leftShift + rightExtent;
32032+
int32_t totalResults = amount + 1;
32033+
32034+
// Adjust start and limit indices by shifting left along the dimension
32035+
SmallVector<int64_t> adjustedStartIndices = startIndices;
32036+
SmallVector<int64_t> adjustedLimitIndices = limitIndices;
32037+
adjustedStartIndices[dim] -= leftShift;
32038+
adjustedLimitIndices[dim] -= leftShift;
32039+
32040+
// Create the MultiSliceOp
32041+
auto resultType = op.getResult().getType();
32042+
SmallVector<Type> resultTypes(totalResults, resultType);
32043+
32044+
rewriter.setInsertionPointAfterValue(input);
32045+
auto newOp = rewriter.create<enzymexla::MultiSliceOp>(
32046+
op.getLoc(), resultTypes, input, adjustedStartIndices,
32047+
adjustedLimitIndices, strides, (int32_t)dim, amount);
32048+
32049+
// Propagate sharding if present (all slices have the same sharding)
32050+
if (commonSharding.has_value() && commonSharding.value()) {
32051+
auto shardings = commonSharding.value().getShardings();
32052+
if (!shardings.empty()) {
32053+
sdy::TensorShardingAttr singleShard = shardings[0];
32054+
SmallVector<sdy::TensorShardingAttr> newShardings(totalResults,
32055+
singleShard);
32056+
sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
32057+
op.getContext(), newShardings));
32058+
}
32059+
}
32060+
32061+
// Replace slices that fall within the selected range
32062+
for (auto &[slice, offset] : matchingSlices) {
32063+
if (offset >= rangeStart && offset <= rangeEnd) {
32064+
// Result index for offset 'o' is: leftShift + o
32065+
// (since we shifted start_indices left by leftShift)
32066+
int32_t resultIdx = leftShift + (int32_t)offset;
32067+
rewriter.replaceOp(slice, newOp.getResult(resultIdx));
32068+
}
32069+
// Slices outside the range are left unchanged
32070+
}
32071+
32072+
return success();
32073+
}
32074+
32075+
return failure();
32076+
}
32077+
};
32078+
3185432079
// Pattern to reduce MultiSliceOp when some results are unused
3185532080
struct ReduceUnusedMultiSlice final
3185632081
: CheckedOpRewritePattern<enzymexla::MultiSliceOp, ReduceUnusedMultiSlice> {
@@ -32129,6 +32354,24 @@ struct RecognizeMultiRotate
3212932354
if (rotatesToCombine < 2)
3213032355
return failure();
3213132356

32357+
// Check that all rotates in the range have the same sharding
32358+
std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
32359+
bool shardingMismatch = false;
32360+
for (auto rotate : rotates) {
32361+
int32_t amt = rotate.getAmount();
32362+
if (amt >= rangeStart && amt <= rangeEnd) {
32363+
auto shardPerValue = sdy::getShardingPerValue(rotate);
32364+
if (!commonSharding.has_value()) {
32365+
commonSharding = shardPerValue;
32366+
} else if (commonSharding.value() != shardPerValue) {
32367+
shardingMismatch = true;
32368+
break;
32369+
}
32370+
}
32371+
}
32372+
if (shardingMismatch)
32373+
return failure();
32374+
3213232375
// Calculate left and right amounts for MultiRotateOp
3213332376
// leftAmount covers positive rotations (rotate left)
3213432377
// rightAmount covers negative rotations (rotate right)
@@ -32143,9 +32386,16 @@ struct RecognizeMultiRotate
3214332386
op.getDimensionAttr(), rewriter.getSI32IntegerAttr(leftAmount),
3214432387
rewriter.getSI32IntegerAttr(rightAmount));
3214532388

32146-
// Propagate sharding if present
32147-
if (auto shard = sdy::getShardingPerValue(op)) {
32148-
sdy::setShardings(newOp, shard);
32389+
// Propagate sharding if present (all rotates have the same sharding)
32390+
if (commonSharding.has_value() && commonSharding.value()) {
32391+
auto shardings = commonSharding.value().getShardings();
32392+
if (!shardings.empty()) {
32393+
sdy::TensorShardingAttr singleShard = shardings[0];
32394+
SmallVector<sdy::TensorShardingAttr> newShardings(totalResults,
32395+
singleShard);
32396+
sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
32397+
op.getContext(), newShardings));
32398+
}
3214932399
}
3215032400

3215132401
// Replace rotations that fall within the selected range

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2224,6 +2224,11 @@ def TransposeRotate : EnzymeHLOPatternOp<
22242224
let patterns = ["TransposeRotate"];
22252225
}
22262226

2227+
def RecognizeMultiSlice : EnzymeHLOPatternOp<
2228+
"recognize_multislice"> {
2229+
let patterns = ["RecognizeMultiSlice"];
2230+
}
2231+
22272232
def ReduceUnusedMultiSlice : EnzymeHLOPatternOp<
22282233
"reduce_unused_multislice"> {
22292234
let patterns = ["ReduceUnusedMultiSlice"];

0 commit comments

Comments
 (0)