@@ -32354,6 +32354,24 @@ struct RecognizeMultiRotate
3235432354 if (rotatesToCombine < 2)
3235532355 return failure();
3235632356
32357+ // Check that all rotates in the range have the same sharding
32358+ std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
32359+ bool shardingMismatch = false;
32360+ for (auto rotate : rotates) {
32361+ int32_t amt = rotate.getAmount();
32362+ if (amt >= rangeStart && amt <= rangeEnd) {
32363+ auto shardPerValue = sdy::getShardingPerValue(rotate);
32364+ if (!commonSharding.has_value()) {
32365+ commonSharding = shardPerValue;
32366+ } else if (commonSharding.value() != shardPerValue) {
32367+ shardingMismatch = true;
32368+ break;
32369+ }
32370+ }
32371+ }
32372+ if (shardingMismatch)
32373+ return failure();
32374+
3235732375 // Calculate left and right amounts for MultiRotateOp
3235832376 // leftAmount covers positive rotations (rotate left)
3235932377 // rightAmount covers negative rotations (rotate right)
@@ -32368,9 +32386,16 @@ struct RecognizeMultiRotate
3236832386 op.getDimensionAttr(), rewriter.getSI32IntegerAttr(leftAmount),
3236932387 rewriter.getSI32IntegerAttr(rightAmount));
3237032388
32371- // Propagate sharding if present
32372- if (auto shard = sdy::getShardingPerValue(op)) {
32373- sdy::setShardings(newOp, shard);
32389+ // Propagate sharding if present (all rotates have the same sharding)
32390+ if (commonSharding.has_value() && commonSharding.value()) {
32391+ auto shardings = commonSharding.value().getShardings();
32392+ if (!shardings.empty()) {
32393+ sdy::TensorShardingAttr singleShard = shardings[0];
32394+ SmallVector<sdy::TensorShardingAttr> newShardings(totalResults,
32395+ singleShard);
32396+ sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
32397+ op.getContext(), newShardings));
32398+ }
3237432399 }
3237532400
3237632401 // Replace rotations that fall within the selected range
0 commit comments