@@ -31851,6 +31851,231 @@ struct ReduceWindowWrapSimplify final
3185131851 }
3185231852};
3185331853
31854+ // Pattern to recognize multiple SliceOps on the same input that differ only
31855+ // by offset along one dimension and combine them into a single MultiSliceOp.
       //
       // The pattern is anchored on one stablehlo::SliceOp ("this op") and works
       // as follows:
       //   1. Collect every SliceOp user of the same operand that has at least
       //      one use.
       //   2. For each dimension `dim`, keep the slices whose strides are
       //      identical, whose start/limit agree with this op on every other
       //      dimension, and whose extent (limit - start) along `dim` is the same.
       //      Each kept slice gets an offset = its start[dim] - this op's
       //      start[dim], so this op itself is offset 0.
       //   3. Group the distinct offsets into runs of consecutive integers and
       //      keep the runs that contain 0 or touch it (-1 / +1).
       //   4. Only the slice with the smallest offset in the selected range is
       //      allowed to perform the rewrite, so a group is rewritten once.
       //   5. All slices in the range must report the same sharding.
       //   6. Create one enzymexla::MultiSliceOp right after the input value and
       //      replace every in-range slice with the matching result.
       //
       // Returns success() after the first dimension that produces a rewrite and
       // failure() when nothing was rewritten.
       //
       // NOTE(review): the meaning of the MultiSliceOp operands (dimension,
       // amount) is not visible in this block; presumably result i is the slice
       // shifted by i positions along `dim` -- confirm against the op definition.
31856+ struct RecognizeMultiSlice
31857+ : public CheckedOpRewritePattern<stablehlo::SliceOp, RecognizeMultiSlice> {
31858+ using CheckedOpRewritePattern::CheckedOpRewritePattern;
31859+
       // Tries to fold `op` and its sibling slices into a single MultiSliceOp.
       // `op` is the anchor slice (offset 0); `rewriter` is used to create the
       // new op and to replace the combined slices.
31860+ LogicalResult matchAndRewriteImpl(stablehlo::SliceOp op,
31861+ PatternRewriter &rewriter) const {
31862+ Value input = op.getOperand();
       // Local copies of the anchor's slice parameters; they are compared
       // against every candidate and later used to build the new op.
31863+ auto startIndices = SmallVector<int64_t>(op.getStartIndices());
31864+ auto limitIndices = SmallVector<int64_t>(op.getLimitIndices());
31865+ auto strides = SmallVector<int64_t>(op.getStrides());
31866+
31867+ // Find all SliceOps on the same input. Slices without any use are skipped,
       // so the anchor itself is only part of `slices` when it has a use.
31868+ SmallVector<stablehlo::SliceOp> slices;
31869+
31870+ for (Operation *user : input.getUsers()) {
31871+ if (auto slice = dyn_cast<stablehlo::SliceOp>(user)) {
31872+ if (!slice.use_empty()) {
31873+ slices.push_back(slice);
31874+ }
31875+ }
31876+ }
31877+
31878+ // Need at least 2 slices to consider combining
31879+ if (slices.size() < 2)
31880+ return failure();
31881+
31882+ int64_t rank = startIndices.size();
31883+
31884+ // Try each dimension as the potential varying dimension
31885+ for (int64_t dim = 0; dim < rank; dim++) {
31886+ // Collect slices that match this op except for offset along `dim`.
       // `matchingSlices` keeps (slice, offset) pairs, possibly with repeated
       // offsets; `offsetSet` keeps the distinct offsets only.
31887+ SmallVector<std::pair<stablehlo::SliceOp, int64_t>> matchingSlices;
31888+ DenseSet<int64_t> offsetSet;
31889+
31890+ for (auto slice : slices) {
31891+ auto sliceStart = SmallVector<int64_t>(slice.getStartIndices());
31892+ auto sliceLimit = SmallVector<int64_t>(slice.getLimitIndices());
31893+ auto sliceStrides = SmallVector<int64_t>(slice.getStrides());
31894+
31895+ // Check if strides match (on every dimension, including `dim`)
31896+ if (sliceStrides != strides)
31897+ continue;
31898+
31899+ // Check if all dimensions except `dim` match exactly
31900+ bool matches = true;
31901+ for (int64_t d = 0; d < rank; d++) {
31902+ if (d == dim)
31903+ continue;
31904+ if (sliceStart[d] != startIndices[d] ||
31905+ sliceLimit[d] != limitIndices[d]) {
31906+ matches = false;
31907+ break;
31908+ }
31909+ }
31910+ if (!matches)
31911+ continue;
31912+
31913+ // Check if the slice extent (limit - start) along `dim` matches
31914+ int64_t thisSliceSize = limitIndices[dim] - startIndices[dim];
31915+ int64_t otherSliceSize = sliceLimit[dim] - sliceStart[dim];
31916+ if (thisSliceSize != otherSliceSize)
31917+ continue;
31918+
31919+ // Calculate offset relative to this slice (this slice has offset 0)
31920+ int64_t offset = sliceStart[dim] - startIndices[dim];
31921+ matchingSlices.push_back({slice, offset});
31922+ offsetSet.insert(offset);
31923+ }
31924+
31925+ // Need at least 2 matching slices along this dimension
31926+ if (matchingSlices.size() < 2)
31927+ continue;
31928+
31929+ // Sort the distinct offsets to find contiguous groups
31930+ SmallVector<int64_t> sortedOffsets(offsetSet.begin(), offsetSet.end());
31931+ llvm::sort(sortedOffsets);
31932+
31933+ // Find all contiguous groups (no gaps in offsets), i.e. maximal runs in
       // which each offset is exactly one larger than the previous one.
31934+ SmallVector<std::pair<int64_t, int64_t>>
31935+ contiguousGroups; // (start, end) inclusive
31936+ int64_t groupStart = sortedOffsets[0];
31937+ int64_t groupEnd = sortedOffsets[0];
31938+
31939+ for (size_t i = 1; i < sortedOffsets.size(); i++) {
31940+ if (sortedOffsets[i] == groupEnd + 1) {
31941+ // Extend current group
31942+ groupEnd = sortedOffsets[i];
31943+ } else {
31944+ // Save current group and start new one
31945+ contiguousGroups.push_back({groupStart, groupEnd});
31946+ groupStart = sortedOffsets[i];
31947+ groupEnd = sortedOffsets[i];
31948+ }
31949+ }
       // The last open group is never closed inside the loop; flush it here.
31950+ contiguousGroups.push_back({groupStart, groupEnd});
31951+
31952+ // Collect all offsets from groups that contain identity (0) or neighbor
31953+ // it (a group ending at -1 or starting at +1)
31954+ SmallVector<int64_t> qualifyingOffsets;
31955+
31956+ for (auto &[start, end] : contiguousGroups) {
31957+ bool containsIdentity = (start <= 0 && end >= 0);
31958+ bool neighborsIdentity = (end == -1 || start == 1);
31959+
31960+ if (containsIdentity || neighborsIdentity) {
31961+ for (int64_t o = start; o <= end; o++) {
31962+ if (offsetSet.contains(o)) {
31963+ qualifyingOffsets.push_back(o);
31964+ }
31965+ }
31966+ }
31967+ }
31968+
31969+ // Fewer than two distinct qualifying offsets: nothing to combine on `dim`
31970+ if (qualifyingOffsets.size() < 2)
31971+ continue;
31972+
31973+ // Determine the multiSlice range from qualifying offsets, extended to
31974+ // include 0
31975+ int64_t rangeStart =
31976+ *std::min_element(qualifyingOffsets.begin(), qualifyingOffsets.end());
31977+ int64_t rangeEnd =
31978+ *std::max_element(qualifyingOffsets.begin(), qualifyingOffsets.end());
31979+
31980+ // Extend to include identity (0) if not already included
31981+ if (rangeStart > 0)
31982+ rangeStart = 0;
31983+ if (rangeEnd < 0)
31984+ rangeEnd = 0;
31985+
31986+ // Check if this op (offset 0) is part of the selected range.
       // NOTE(review): after the clamping above rangeStart <= 0 <= rangeEnd
       // always holds, so this guard cannot trigger as written.
31987+ if (0 < rangeStart || 0 > rangeEnd)
31988+ continue;
31989+
31990+ // Find the minimum offset among actual slices in the selected range
31991+ // to use as the canonical trigger (prevents processing same group
31992+ // multiple times). Also count how many slices would be replaced.
31993+ int64_t minOffsetInRange = INT64_MAX;
31994+ int slicesToCombine = 0;
31995+ for (auto &[slice, offset] : matchingSlices) {
31996+ if (offset >= rangeStart && offset <= rangeEnd) {
31997+ minOffsetInRange = std::min(minOffsetInRange, offset);
31998+ slicesToCombine++;
31999+ }
32000+ }
32001+
32002+ // Only proceed if this op has the minimum offset in the range. This
       // returns failure() for the whole pattern, so the remaining dimensions
       // are not tried for this op either.
32003+ if (0 != minOffsetInRange)
32004+ return failure();
32005+
32006+ // Need at least 2 slices in the selected range to combine
32007+ if (slicesToCombine < 2)
32008+ continue;
32009+
32010+ // Check that all slices in the range have the same sharding. The first
       // in-range slice seeds `commonSharding`; any later difference aborts
       // the whole pattern (not just this dimension).
32011+ std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
32012+ bool shardingMismatch = false;
32013+ for (auto &[slice, offset] : matchingSlices) {
32014+ if (offset >= rangeStart && offset <= rangeEnd) {
32015+ auto shardPerValue = sdy::getShardingPerValue(slice);
32016+ if (!commonSharding.has_value()) {
32017+ commonSharding = shardPerValue;
32018+ } else if (commonSharding.value() != shardPerValue) {
32019+ shardingMismatch = true;
32020+ break;
32021+ }
32022+ }
32023+ }
32024+ if (shardingMismatch)
32025+ return failure();
32026+
32027+ // Calculate the shift amount
32028+ // leftShift covers negative offsets (slices shifted left/earlier)
       // rightExtent covers positive offsets; amount is the total span and the
       // new op gets amount + 1 results.
       // NOTE(review): the canonical-trigger check above only lets the op with
       // the smallest in-range offset continue, which appears to imply
       // rangeStart == 0 (and therefore leftShift == 0) here -- confirm.
32029+ int32_t leftShift = rangeStart < 0 ? (int32_t)(-rangeStart) : 0;
32030+ int32_t rightExtent = rangeEnd > 0 ? (int32_t)rangeEnd : 0;
32031+ int32_t amount = leftShift + rightExtent;
32032+ int32_t totalResults = amount + 1;
32033+
32034+ // Adjust start and limit indices by shifting left along the dimension
32035+ SmallVector<int64_t> adjustedStartIndices = startIndices;
32036+ SmallVector<int64_t> adjustedLimitIndices = limitIndices;
32037+ adjustedStartIndices[dim] -= leftShift;
32038+ adjustedLimitIndices[dim] -= leftShift;
32039+
32040+ // Create the MultiSliceOp. Every result reuses this op's result type; the
       // earlier extent/stride checks only admit slices that agree with the
       // anchor on those parameters.
32041+ auto resultType = op.getResult().getType();
32042+ SmallVector<Type> resultTypes(totalResults, resultType);
32043+
       // Insert right after the definition of `input` rather than at `op`,
       // since slices other than `op` are replaced with the new results too.
32044+ rewriter.setInsertionPointAfterValue(input);
32045+ auto newOp = rewriter.create<enzymexla::MultiSliceOp>(
32046+ op.getLoc(), resultTypes, input, adjustedStartIndices,
32047+ adjustedLimitIndices, strides, (int32_t)dim, amount);
32048+
32049+ // Propagate sharding if present (all slices have the same sharding): the
       // first per-value sharding entry is replicated once per result of newOp.
32050+ if (commonSharding.has_value() && commonSharding.value()) {
32051+ auto shardings = commonSharding.value().getShardings();
32052+ if (!shardings.empty()) {
32053+ sdy::TensorShardingAttr singleShard = shardings[0];
32054+ SmallVector<sdy::TensorShardingAttr> newShardings(totalResults,
32055+ singleShard);
32056+ sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
32057+ op.getContext(), newShardings));
32058+ }
32059+ }
32060+
32061+ // Replace slices that fall within the selected range (this includes `op`
       // itself, which sits at offset 0)
32062+ for (auto &[slice, offset] : matchingSlices) {
32063+ if (offset >= rangeStart && offset <= rangeEnd) {
32064+ // Result index for offset 'o' is: leftShift + o
32065+ // (since we shifted start_indices left by leftShift)
32066+ int32_t resultIdx = leftShift + (int32_t)offset;
32067+ rewriter.replaceOp(slice, newOp.getResult(resultIdx));
32068+ }
32069+ // Slices outside the range are left unchanged
32070+ }
32071+
       // One rewrite per invocation: stop after the first dimension that
       // produced a MultiSliceOp.
32072+ return success();
32073+ }
32074+
       // No dimension yielded a combinable group.
32075+ return failure();
32076+ }
32077+ };
32078+
3185432079// Pattern to reduce MultiSliceOp when some results are unused
3185532080struct ReduceUnusedMultiSlice final
3185632081 : CheckedOpRewritePattern<enzymexla::MultiSliceOp, ReduceUnusedMultiSlice> {
@@ -32129,6 +32354,24 @@ struct RecognizeMultiRotate
3212932354 if (rotatesToCombine < 2)
3213032355 return failure();
3213132356
32357+ // Check that all rotates in the range have the same sharding
32358+ std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
32359+ bool shardingMismatch = false;
32360+ for (auto rotate : rotates) {
32361+ int32_t amt = rotate.getAmount();
32362+ if (amt >= rangeStart && amt <= rangeEnd) {
32363+ auto shardPerValue = sdy::getShardingPerValue(rotate);
32364+ if (!commonSharding.has_value()) {
32365+ commonSharding = shardPerValue;
32366+ } else if (commonSharding.value() != shardPerValue) {
32367+ shardingMismatch = true;
32368+ break;
32369+ }
32370+ }
32371+ }
32372+ if (shardingMismatch)
32373+ return failure();
32374+
3213232375 // Calculate left and right amounts for MultiRotateOp
3213332376 // leftAmount covers positive rotations (rotate left)
3213432377 // rightAmount covers negative rotations (rotate right)
@@ -32143,9 +32386,16 @@ struct RecognizeMultiRotate
3214332386 op.getDimensionAttr(), rewriter.getSI32IntegerAttr(leftAmount),
3214432387 rewriter.getSI32IntegerAttr(rightAmount));
3214532388
32146- // Propagate sharding if present
32147- if (auto shard = sdy::getShardingPerValue(op)) {
32148- sdy::setShardings(newOp, shard);
32389+ // Propagate sharding if present (all rotates have the same sharding)
32390+ if (commonSharding.has_value() && commonSharding.value()) {
32391+ auto shardings = commonSharding.value().getShardings();
32392+ if (!shardings.empty()) {
32393+ sdy::TensorShardingAttr singleShard = shardings[0];
32394+ SmallVector<sdy::TensorShardingAttr> newShardings(totalResults,
32395+ singleShard);
32396+ sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
32397+ op.getContext(), newShardings));
32398+ }
3214932399 }
3215032400
3215132401 // Replace rotations that fall within the selected range
0 commit comments