Skip to content

Commit e60c8c2

Browse files
jumerckxwsmoses
authored andcommitted
try fix sharding handling
1 parent 21d3b2d commit e60c8c2

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32007,6 +32007,23 @@ struct RecognizeMultiSlice
3200732007
if (slicesToCombine < 2)
3200832008
continue;
3200932009

32010+
// Check that all slices in the range have the same sharding
32011+
std::optional<sdy::TensorShardingPerValueAttr> commonSharding;
32012+
bool shardingMismatch = false;
32013+
for (auto &[slice, offset] : matchingSlices) {
32014+
if (offset >= rangeStart && offset <= rangeEnd) {
32015+
auto shardPerValue = sdy::getShardingPerValue(slice);
32016+
if (!commonSharding.has_value()) {
32017+
commonSharding = shardPerValue;
32018+
} else if (commonSharding.value() != shardPerValue) {
32019+
shardingMismatch = true;
32020+
break;
32021+
}
32022+
}
32023+
}
32024+
if (shardingMismatch)
32025+
return failure();
32026+
3201032027
// Calculate the shift amount
3201132028
// leftShift covers negative offsets (slices shifted left/earlier)
3201232029
int32_t leftShift = rangeStart < 0 ? (int32_t)(-rangeStart) : 0;
@@ -32029,9 +32046,16 @@ struct RecognizeMultiSlice
3202932046
op.getLoc(), resultTypes, input, adjustedStartIndices,
3203032047
adjustedLimitIndices, strides, (int32_t)dim, amount);
3203132048

32032-
// Propagate sharding if present
32033-
if (auto shard = sdy::getShardingPerValue(op)) {
32034-
sdy::setShardings(newOp, shard);
32049+
// Propagate sharding if present (all slices have the same sharding)
32050+
if (commonSharding.has_value() && commonSharding.value()) {
32051+
auto shardings = commonSharding.value().getShardings();
32052+
if (!shardings.empty()) {
32053+
sdy::TensorShardingAttr singleShard = shardings[0];
32054+
SmallVector<sdy::TensorShardingAttr> newShardings(totalResults,
32055+
singleShard);
32056+
sdy::setShardings(newOp, sdy::TensorShardingPerValueAttr::get(
32057+
op.getContext(), newShardings));
32058+
}
3203532059
}
3203632060

3203732061
// Replace slices that fall within the selected range

0 commit comments

Comments
 (0)