Skip to content

Commit 6a1ec6d

Browse files
jumerckx authored and wsmoses committed
also update RecognizeMultiRotate
1 parent e60c8c2 commit 6a1ec6d

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 28 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -32354,6 +32354,24 @@ struct RecognizeMultiRotate
3235432354
if (rotatesToCombine < 2)
3235532355
return failure();
3235632356

32357+
// Check that all rotates in the range have the same sharding
32358+
std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
32359+
bool shardingMismatch = false;
32360+
for (auto rotate : rotates) {
32361+
int32_t amt = rotate.getAmount();
32362+
if (amt >= rangeStart && amt <= rangeEnd) {
32363+
auto shardPerValue = sdy::getShardingPerValue(rotate);
32364+
if (!commonSharding.has_value()) {
32365+
commonSharding = shardPerValue;
32366+
} else if (commonSharding.value() != shardPerValue) {
32367+
shardingMismatch = true;
32368+
break;
32369+
}
32370+
}
32371+
}
32372+
if (shardingMismatch)
32373+
return failure();
32374+
3235732375
// Calculate left and right amounts for MultiRotateOp
3235832376
// leftAmount covers positive rotations (rotate left)
3235932377
// rightAmount covers negative rotations (rotate right)
@@ -32368,9 +32386,16 @@ struct RecognizeMultiRotate
3236832386
op.getDimensionAttr(), rewriter.getSI32IntegerAttr(leftAmount),
3236932387
rewriter.getSI32IntegerAttr(rightAmount));
3237032388

32371-
// Propagate sharding if present
32372-
if (auto shard = sdy::getShardingPerValue(op)) {
32373-
sdy::setShardings(newOp, shard);
32389+
// Propagate sharding if present (all rotates have the same sharding)
32390+
if (commonSharding.has_value() && commonSharding.value()) {
32391+
auto shardings = commonSharding.value().getShardings();
32392+
if (!shardings.empty()) {
32393+
sdy::TensorShardingAttr singleShard = shardings[0];
32394+
SmallVector<sdy::TensorShardingAttr> newShardings(totalResults,
32395+
singleShard);
32396+
sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
32397+
op.getContext(), newShardings));
32398+
}
3237432399
}
3237532400

3237632401
// Replace rotations that fall within the selected range

0 commit comments

Comments
 (0)