@@ -32007,6 +32007,23 @@ struct RecognizeMultiSlice
3200732007 if (slicesToCombine < 2)
3200832008 continue;
3200932009
32010+ // Check that all slices in the range have the same sharding
32011+ std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
32012+ bool shardingMismatch = false;
32013+ for (auto &[slice, offset] : matchingSlices) {
32014+ if (offset >= rangeStart && offset <= rangeEnd) {
32015+ auto shardPerValue = sdy::getShardingPerValue(slice);
32016+ if (!commonSharding.has_value()) {
32017+ commonSharding = shardPerValue;
32018+ } else if (commonSharding.value() != shardPerValue) {
32019+ shardingMismatch = true;
32020+ break;
32021+ }
32022+ }
32023+ }
32024+ if (shardingMismatch)
32025+ return failure();
32026+
3201032027 // Calculate the shift amount
3201132028 // leftShift covers negative offsets (slices shifted left/earlier)
3201232029 int32_t leftShift = rangeStart < 0 ? (int32_t)(-rangeStart) : 0;
@@ -32029,9 +32046,16 @@ struct RecognizeMultiSlice
3202932046 op.getLoc(), resultTypes, input, adjustedStartIndices,
3203032047 adjustedLimitIndices, strides, (int32_t)dim, amount);
3203132048
32032- // Propagate sharding if present
32033- if (auto shard = sdy::getShardingPerValue(op)) {
32034- sdy::setShardings(newOp, shard);
32049+ // Propagate sharding if present (all slices have the same sharding)
32050+ if (commonSharding.has_value() && commonSharding.value()) {
32051+ auto shardings = commonSharding.value().getShardings();
32052+ if (!shardings.empty()) {
32053+ sdy::TensorShardingAttr singleShard = shardings[0];
32054+ SmallVector<sdy::TensorShardingAttr> newShardings(totalResults,
32055+ singleShard);
32056+ sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
32057+ op.getContext(), newShardings));
32058+ }
3203532059 }
3203632060
3203732061 // Replace slices that fall within the selected range
0 commit comments