Skip to content

Commit 6cce4ac

Browse files
committed
try fix sharding handling
1 parent bfcfc0d commit 6cce4ac

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 27 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30793,6 +30793,24 @@ struct RecognizeMultiSlice
3079330793
// Need at least 2 slices in the selected range to combine
3079430794
if (slicesToCombine < 2)
3079530795
continue;
30796+
30797+
// Check that all slices in the range have the same sharding
30798+
std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
30799+
bool shardingMismatch = false;
30800+
for (auto &[slice, offset] : matchingSlices) {
30801+
if (offset >= rangeStart && offset <= rangeEnd) {
30802+
auto shardPerValue = sdy::getShardingPerValue(slice);
30803+
if (!commonSharding.has_value()) {
30804+
commonSharding = shardPerValue;
30805+
} else if (commonSharding.value() != shardPerValue) {
30806+
shardingMismatch = true;
30807+
break;
30808+
}
30809+
}
30810+
}
30811+
if (shardingMismatch)
30812+
return failure();
30813+
3079630814

3079730815
// Calculate the shift amount
3079830816
// leftShift covers negative offsets (slices shifted left/earlier)
@@ -30816,9 +30834,15 @@ struct RecognizeMultiSlice
3081630834
op.getLoc(), resultTypes, input, adjustedStartIndices,
3081730835
adjustedLimitIndices, strides, (int32_t)dim, amount);
3081830836

30819-
// Propagate sharding if present
30820-
if (auto shard = sdy::getShardingPerValue(op)) {
30821-
sdy::setShardings(newOp, shard);
30837+
// Propagate sharding if present (all slices have the same sharding)
30838+
if (commonSharding.has_value() && commonSharding.value()) {
30839+
auto shardings = commonSharding.value().getShardings();
30840+
if (!shardings.empty()) {
30841+
sdy::TensorShardingAttr singleShard = shardings[0];
30842+
SmallVector<sdy::TensorShardingAttr> newShardings(totalResults, singleShard);
30843+
sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
30844+
op.getContext(), newShardings));
30845+
}
3082230846
}
3082330847

3082430848
// Replace slices that fall within the selected range

0 commit comments

Comments
 (0)