256 changes: 253 additions & 3 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -31851,6 +31851,231 @@ struct ReduceWindowWrapSimplify final
}
};

// Pattern to recognize multiple SliceOps on the same input that differ only
// by offset along one dimension and combine them into a single MultiSliceOp
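// Illustrative sketch (hypothetical IR, not part of this change): given three
// slices of %x : tensor<8x16xf32> that differ only in the start offset along
// dimension 1,
//   %a = stablehlo.slice %x [0:8, 0:4]
//   %b = stablehlo.slice %x [0:8, 1:5]
//   %c = stablehlo.slice %x [0:8, 2:6]
// the pattern replaces them with a single enzymexla.multi_slice (shown here
// schematically) whose three results stand in for %a, %b, and %c.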
struct RecognizeMultiSlice
: public CheckedOpRewritePattern<stablehlo::SliceOp, RecognizeMultiSlice> {
using CheckedOpRewritePattern::CheckedOpRewritePattern;

LogicalResult matchAndRewriteImpl(stablehlo::SliceOp op,
PatternRewriter &rewriter) const {
Value input = op.getOperand();
auto startIndices = SmallVector<int64_t>(op.getStartIndices());
auto limitIndices = SmallVector<int64_t>(op.getLimitIndices());
auto strides = SmallVector<int64_t>(op.getStrides());

// Find all SliceOps on the same input
SmallVector<stablehlo::SliceOp> slices;

for (Operation *user : input.getUsers()) {
if (auto slice = dyn_cast<stablehlo::SliceOp>(user)) {
if (!slice.use_empty()) {
slices.push_back(slice);
}
}
}

// Need at least 2 slices to consider combining
if (slices.size() < 2)
return failure();

int64_t rank = startIndices.size();

// Try each dimension as the potential varying dimension
for (int64_t dim = 0; dim < rank; dim++) {
// Collect slices that match this op except for offset along `dim`
SmallVector<std::pair<stablehlo::SliceOp, int64_t>> matchingSlices;
DenseSet<int64_t> offsetSet;

for (auto slice : slices) {
auto sliceStart = SmallVector<int64_t>(slice.getStartIndices());
auto sliceLimit = SmallVector<int64_t>(slice.getLimitIndices());
auto sliceStrides = SmallVector<int64_t>(slice.getStrides());

// Check if strides match
if (sliceStrides != strides)
continue;

// Check if all dimensions except `dim` match exactly
bool matches = true;
for (int64_t d = 0; d < rank; d++) {
if (d == dim)
continue;
if (sliceStart[d] != startIndices[d] ||
sliceLimit[d] != limitIndices[d]) {
matches = false;
break;
}
}
if (!matches)
continue;

// Check if the slice shape along `dim` matches
int64_t thisSliceSize = limitIndices[dim] - startIndices[dim];
int64_t otherSliceSize = sliceLimit[dim] - sliceStart[dim];
if (thisSliceSize != otherSliceSize)
continue;

// Calculate offset relative to this slice (this slice has offset 0)
int64_t offset = sliceStart[dim] - startIndices[dim];
matchingSlices.push_back({slice, offset});
offsetSet.insert(offset);
}

// Need at least 2 matching slices along this dimension
if (matchingSlices.size() < 2)
continue;

// Sort offsets to find contiguous groups
SmallVector<int64_t> sortedOffsets(offsetSet.begin(), offsetSet.end());
llvm::sort(sortedOffsets);

// Find all contiguous groups (no gaps in offsets)
SmallVector<std::pair<int64_t, int64_t>>
contiguousGroups; // (start, end) inclusive
int64_t groupStart = sortedOffsets[0];
int64_t groupEnd = sortedOffsets[0];

for (size_t i = 1; i < sortedOffsets.size(); i++) {
if (sortedOffsets[i] == groupEnd + 1) {
// Extend current group
groupEnd = sortedOffsets[i];
} else {
// Save current group and start new one
contiguousGroups.push_back({groupStart, groupEnd});
groupStart = sortedOffsets[i];
groupEnd = sortedOffsets[i];
}
}
contiguousGroups.push_back({groupStart, groupEnd});

// Collect all offsets from groups that contain identity (0) or neighbor
// it
SmallVector<int64_t> qualifyingOffsets;

for (auto &[start, end] : contiguousGroups) {
bool containsIdentity = (start <= 0 && end >= 0);
bool neighborsIdentity = (end == -1 || start == 1);

if (containsIdentity || neighborsIdentity) {
for (int64_t o = start; o <= end; o++) {
if (offsetSet.contains(o)) {
qualifyingOffsets.push_back(o);
}
}
}
}
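// Worked example (hypothetical offsets, not from this code): offsets
// {-2, -1, 0, 3, 4} form the contiguous groups [-2, 0] and [3, 4]; only
// [-2, 0] contains the identity offset 0 (and [3, 4] neither contains nor
// neighbors it), so qualifyingOffsets = {-2, -1, 0}.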

// Need at least two qualifying offsets to combine
if (qualifyingOffsets.size() < 2)
continue;

// Determine the multiSlice range from qualifying offsets, extended to
// include 0
int64_t rangeStart =
*std::min_element(qualifyingOffsets.begin(), qualifyingOffsets.end());
int64_t rangeEnd =
*std::max_element(qualifyingOffsets.begin(), qualifyingOffsets.end());

// Extend to include identity (0) if not already included
if (rangeStart > 0)
rangeStart = 0;
if (rangeEnd < 0)
rangeEnd = 0;

// Check if this op (offset 0) is part of the selected range
if (0 < rangeStart || 0 > rangeEnd)
continue;

// Find the minimum offset among actual slices in the selected range
// to use as the canonical trigger (prevents processing same group
// multiple times)
int64_t minOffsetInRange = INT64_MAX;
int slicesToCombine = 0;
for (auto &[slice, offset] : matchingSlices) {
if (offset >= rangeStart && offset <= rangeEnd) {
minOffsetInRange = std::min(minOffsetInRange, offset);
slicesToCombine++;
}
}

// Only proceed if this op has the minimum offset in the range
if (0 != minOffsetInRange)
return failure();

// Need at least 2 slices in the selected range to combine
if (slicesToCombine < 2)
continue;

// Check that all slices in the range have the same sharding
std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
bool shardingMismatch = false;
for (auto &[slice, offset] : matchingSlices) {
if (offset >= rangeStart && offset <= rangeEnd) {
auto shardPerValue = sdy::getShardingPerValue(slice);
if (!commonSharding.has_value()) {
commonSharding = shardPerValue;
} else if (commonSharding.value() != shardPerValue) {
shardingMismatch = true;
break;
}
}
}
if (shardingMismatch)
return failure();

// Calculate the shift amount
// leftShift covers negative offsets (slices shifted left/earlier)
int32_t leftShift = rangeStart < 0 ? (int32_t)(-rangeStart) : 0;
int32_t rightExtent = rangeEnd > 0 ? (int32_t)rangeEnd : 0;
int32_t amount = leftShift + rightExtent;
int32_t totalResults = amount + 1;

// Adjust start and limit indices by shifting left along the dimension
SmallVector<int64_t> adjustedStartIndices = startIndices;
SmallVector<int64_t> adjustedLimitIndices = limitIndices;
adjustedStartIndices[dim] -= leftShift;
adjustedLimitIndices[dim] -= leftShift;
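// Hypothetical numbers: with relative offsets {0, 1, 2} selected, rangeStart
// is 0 and rangeEnd is 2, so leftShift = 0, rightExtent = 2, amount = 2, and
// totalResults = 3; result i of the MultiSliceOp then corresponds to relative
// offset i - leftShift (here simply i). Negative offsets would instead be
// absorbed by leftShift and the left-shifted start/limit indices above.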

// Create the MultiSliceOp
auto resultType = op.getResult().getType();
SmallVector<Type> resultTypes(totalResults, resultType);

rewriter.setInsertionPointAfterValue(input);
auto newOp = rewriter.create<enzymexla::MultiSliceOp>(
op.getLoc(), resultTypes, input, adjustedStartIndices,
adjustedLimitIndices, strides, (int32_t)dim, amount);

// Propagate sharding if present (all slices have the same sharding)
if (commonSharding.has_value() && commonSharding.value()) {
auto shardings = commonSharding.value().getShardings();
if (!shardings.empty()) {
sdy::TensorShardingAttr singleShard = shardings[0];
SmallVector<sdy::TensorShardingAttr> newShardings(totalResults,
singleShard);
sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
op.getContext(), newShardings));

Review comment (Member): Looks good, @jumerckx can you also port this to recognizemultirotate?
Reply (Collaborator Author): done

}
}

// Replace slices that fall within the selected range
for (auto &[slice, offset] : matchingSlices) {
if (offset >= rangeStart && offset <= rangeEnd) {
// Result index for offset 'o' is: leftShift + o
// (since we shifted start_indices left by leftShift)
int32_t resultIdx = leftShift + (int32_t)offset;
rewriter.replaceOp(slice, newOp.getResult(resultIdx));
}
// Slices outside the range are left unchanged
}

return success();
}

return failure();
}
};

// Pattern to reduce MultiSliceOp when some results are unused
struct ReduceUnusedMultiSlice final
: CheckedOpRewritePattern<enzymexla::MultiSliceOp, ReduceUnusedMultiSlice> {
@@ -32129,6 +32354,24 @@ struct RecognizeMultiRotate
if (rotatesToCombine < 2)
return failure();

// Check that all rotates in the range have the same sharding
std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
bool shardingMismatch = false;
for (auto rotate : rotates) {
int32_t amt = rotate.getAmount();
if (amt >= rangeStart && amt <= rangeEnd) {
auto shardPerValue = sdy::getShardingPerValue(rotate);
if (!commonSharding.has_value()) {
commonSharding = shardPerValue;
} else if (commonSharding.value() != shardPerValue) {
shardingMismatch = true;
break;
}
}
}
if (shardingMismatch)
return failure();

// Calculate left and right amounts for MultiRotateOp
// leftAmount covers positive rotations (rotate left)
// rightAmount covers negative rotations (rotate right)
@@ -32143,9 +32386,16 @@ struct RecognizeMultiRotate
op.getDimensionAttr(), rewriter.getSI32IntegerAttr(leftAmount),
rewriter.getSI32IntegerAttr(rightAmount));

// Propagate sharding if present
if (auto shard = sdy::getShardingPerValue(op)) {
sdy::setShardings(newOp, shard);
// Propagate sharding if present (all rotates have the same sharding)
if (commonSharding.has_value() && commonSharding.value()) {
auto shardings = commonSharding.value().getShardings();
if (!shardings.empty()) {
sdy::TensorShardingAttr singleShard = shardings[0];
SmallVector<sdy::TensorShardingAttr> newShardings(totalResults,
singleShard);
sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
op.getContext(), newShardings));
}
}

// Replace rotations that fall within the selected range
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
@@ -2224,6 +2224,11 @@ def TransposeRotate : EnzymeHLOPatternOp<
let patterns = ["TransposeRotate"];
}

def RecognizeMultiSlice : EnzymeHLOPatternOp<
"recognize_multislice"> {
let patterns = ["RecognizeMultiSlice"];
}

def ReduceUnusedMultiSlice : EnzymeHLOPatternOp<
"reduce_unused_multislice"> {
let patterns = ["ReduceUnusedMultiSlice"];