Skip to content

Commit a02f53d

Browse files
committed
try fix sharding handling
1 parent bfcfc0d commit a02f53d

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 27 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30794,6 +30794,23 @@ struct RecognizeMultiSlice
30794 30794
if (slicesToCombine < 2)
30795 30795
continue;
30796 30796

30797+
// Check that all slices in the range have the same sharding
30798+
std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
30799+
bool shardingMismatch = false;
30800+
for (auto &[slice, offset] : matchingSlices) {
30801+
if (offset >= rangeStart && offset <= rangeEnd) {
30802+
auto shardPerValue = sdy::getShardingPerValue(slice);
30803+
if (!commonSharding.has_value()) {
30804+
commonSharding = shardPerValue;
30805+
} else if (commonSharding.value() != shardPerValue) {
30806+
shardingMismatch = true;
30807+
break;
30808+
}
30809+
}
30810+
}
30811+
if (shardingMismatch)
30812+
return failure();
30813+
30797 30814
// Calculate the shift amount
30798 30815
// leftShift covers negative offsets (slices shifted left/earlier)
30799 30816
int32_t leftShift = rangeStart < 0 ? (int32_t)(-rangeStart) : 0;
@@ -30816,9 +30833,16 @@ struct RecognizeMultiSlice
30816 30833
op.getLoc(), resultTypes, input, adjustedStartIndices,
30817 30834
adjustedLimitIndices, strides, (int32_t)dim, amount);
30818 30835

30819-
// Propagate sharding if present
30820-
if (auto shard = sdy::getShardingPerValue(op)) {
30821-
sdy::setShardings(newOp, shard);
30836+
// Propagate sharding if present (all slices have the same sharding)
30837+
if (commonSharding.has_value() && commonSharding.value()) {
30838+
auto shardings = commonSharding.value().getShardings();
30839+
if (!shardings.empty()) {
30840+
sdy::TensorShardingAttr singleShard = shardings[0];
30841+
SmallVector<sdy::TensorShardingAttr> newShardings(totalResults,
30842+
singleShard);
30843+
sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
30844+
op.getContext(), newShardings));
30845+
}
30822 30846
}
30823 30847

30824 30848
// Replace slices that fall within the selected range

0 commit comments

Comments
 (0)