@@ -30793,6 +30793,24 @@ struct RecognizeMultiSlice
3079330793 // Need at least 2 slices in the selected range to combine
3079430794 if (slicesToCombine < 2)
3079530795 continue;
30796+
30797+ // Check that all slices in the range have the same sharding
30798+ std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
30799+ bool shardingMismatch = false;
30800+ for (auto &[slice, offset] : matchingSlices) {
30801+ if (offset >= rangeStart && offset <= rangeEnd) {
30802+ auto shardPerValue = sdy::getShardingPerValue(slice);
30803+ if (!commonSharding.has_value()) {
30804+ commonSharding = shardPerValue;
30805+ } else if (commonSharding.value() != shardPerValue) {
30806+ shardingMismatch = true;
30807+ break;
30808+ }
30809+ }
30810+ }
30811+ if (shardingMismatch)
30812+ return failure();
30813+
3079630814
3079730815 // Calculate the shift amount
3079830816 // leftShift covers negative offsets (slices shifted left/earlier)
@@ -30816,9 +30834,15 @@ struct RecognizeMultiSlice
3081630834 op.getLoc(), resultTypes, input, adjustedStartIndices,
3081730835 adjustedLimitIndices, strides, (int32_t)dim, amount);
3081830836
30819- // Propagate sharding if present
30820- if (auto shard = sdy::getShardingPerValue(op)) {
30821- sdy::setShardings(newOp, shard);
30837+ // Propagate sharding if present (all slices have the same sharding)
30838+ if (commonSharding.has_value() && commonSharding.value()) {
30839+ auto shardings = commonSharding.value().getShardings();
30840+ if (!shardings.empty()) {
30841+ sdy::TensorShardingAttr singleShard = shardings[0];
30842+ SmallVector<sdy::TensorShardingAttr> newShardings(totalResults, singleShard);
30843+ sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
30844+ op.getContext(), newShardings));
30845+ }
3082230846 }
3082330847
3082430848 // Replace slices that fall within the selected range
0 commit comments