@@ -30794,6 +30794,23 @@ struct RecognizeMultiSlice
3079430794 if (slicesToCombine < 2)
3079530795 continue;
3079630796
30797+ // Check that all slices in the range have the same sharding
30798+ std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
30799+ bool shardingMismatch = false;
30800+ for (auto &[slice, offset] : matchingSlices) {
30801+ if (offset >= rangeStart && offset <= rangeEnd) {
30802+ auto shardPerValue = sdy::getShardingPerValue(slice);
30803+ if (!commonSharding.has_value()) {
30804+ commonSharding = shardPerValue;
30805+ } else if (commonSharding.value() != shardPerValue) {
30806+ shardingMismatch = true;
30807+ break;
30808+ }
30809+ }
30810+ }
30811+ if (shardingMismatch)
30812+ return failure();
30813+
3079730814 // Calculate the shift amount
3079830815 // leftShift covers negative offsets (slices shifted left/earlier)
3079930816 int32_t leftShift = rangeStart < 0 ? (int32_t)(-rangeStart) : 0;
@@ -30816,9 +30833,16 @@ struct RecognizeMultiSlice
3081630833 op.getLoc(), resultTypes, input, adjustedStartIndices,
3081730834 adjustedLimitIndices, strides, (int32_t)dim, amount);
3081830835
30819- // Propagate sharding if present
30820- if (auto shard = sdy::getShardingPerValue(op)) {
30821- sdy::setShardings(newOp, shard);
30836+ // Propagate sharding if present (all slices have the same sharding)
30837+ if (commonSharding.has_value() && commonSharding.value()) {
30838+ auto shardings = commonSharding.value().getShardings();
30839+ if (!shardings.empty()) {
30840+ sdy::TensorShardingAttr singleShard = shardings[0];
30841+ SmallVector<sdy::TensorShardingAttr> newShardings(totalResults,
30842+ singleShard);
30843+ sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
30844+ op.getContext(), newShardings));
30845+ }
3082230846 }
3082330847
3082430848 // Replace slices that fall within the selected range
0 commit comments