Skip to content

Commit c801fb3

Browse files
jumerckx and wsmoses authored
Only lower multislice to custom call if efficient lowering is possible (#2187)
* Add detectCrossShardPattern * cleanup * move detection to multislice lowering instead * update test * remove previous lowering test * relax shard size constraint * include backend_config in lit tests * _SPMDEnzymeInternalOp_MultiSlice --> _SPMDOp_MultiSlice * get rid of getNumShardsAlongDim * fixup * fix * fix * fix * fix api * Update override_commit to new SHA value * Add 'mspp' to reactant_commit in workflow * Comment out linux-x86-ct6e-180-4tpu in workflow Comment out one of the OS options in the workflow matrix. * Update workspace.bzl * Update workspace.bzl * create pre-slice operation if a sliced dimension lives on one device * test * rename variable * fix * fmt * use proper upstream commit * fix * fix? * add printer * more wrap print * tostring * fix * Update test-gb-25.yml * Change gb25_commit branch from 'main' to 'wsmoses-patch-6' * Update XLA_FLAGS to include HLO pass regex * now shardy patch * fix * change shardy * change xla * fix * fix * fix * fix * ms --------- Co-authored-by: William S. Moses <gh@wsmoses.com> Co-authored-by: William Moses <wmoses@google.com>
1 parent e5480b4 commit c801fb3

File tree

13 files changed

+226
-82
lines changed

13 files changed

+226
-82
lines changed

.github/workflows/test-gb-25.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ jobs:
201201
ALL_TO_ALL_THRESHOLD=0
202202
ALL_GATHER_THRESHOLD=0
203203
ALL_REDUCE_THRESHOLD=0
204-
COLLECTIVE_PERMUTE_THRESHOLD=339
204+
COLLECTIVE_PERMUTE_THRESHOLD=345
205205
elif [[ '${{ contains(matrix.os, 'tpu') }}' == 'true' ]]; then
206206
ALL_TO_ALL_THRESHOLD=0
207207
ALL_GATHER_THRESHOLD=0

patches/xla_spmd.patch

Lines changed: 0 additions & 15 deletions
This file was deleted.

src/enzyme_ad/jax/Dialect/Ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1551,7 +1551,7 @@ LogicalResult fixupGetFunc(LLVM::CallOp op, OpBuilder &rewriter,
15511551
struct NoopResource : public SideEffects::Resource::Base<NoopResource> {
15521552
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NoopResource)
15531553

1554-
StringRef getName() final { return "<NoopResource>"; }
1554+
StringRef getName() const final { return "<NoopResource>"; }
15551555
};
15561556

15571557
void NoopOp::build(OpBuilder &builder, OperationState &result,

src/enzyme_ad/jax/Implementations/TritonDerivatives.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def : TritonInactiveOp<"MakeRangeOp">;
2525
def : TritonInactiveOp<"PrintOp">;
2626

2727
def : ReadOnlyIdentityOp<"triton", "AddPtrOp", [0]>;
28-
def : ReadOnlyIdentityOp<"triton", "AdvanceOp", [0]>;
2928
def : ReadOnlyIdentityOp<"triton", "LoadOp", [0]>;
3029
def : ReadOnlyIdentityOp<"triton", "SplatOp", [0]>;
3130
def : MemoryIdentityOp<"triton", "StoreOp", [1], [0]>;

src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp

Lines changed: 155 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2261,6 +2261,62 @@ struct MultiRotateCustomCallOptimize
22612261
}
22622262
};
22632263

2264+
/// Detect whether this MultiSliceOp matches the cross-shard pattern:
2265+
/// 1. All strides are 1.
2266+
/// 2. For every sharded dimension except the multi-slice dimension,
2267+
/// start/limit span the full tensor extent.
2268+
/// 3. Along the multi-slice dimension, every slice's start falls within
2269+
/// one shard and its end falls within a different shard.
2270+
bool detectCrossShardPattern(Value operand, Operation *op,
2271+
ArrayRef<int64_t> startIndices,
2272+
ArrayRef<int64_t> limitIndices,
2273+
ArrayRef<int64_t> strides, int32_t dim,
2274+
int32_t amount, bool &needsSlice) {
2275+
// --- Condition 1: unit strides everywhere ---
2276+
if (!llvm::all_of(strides, [](int64_t s) { return s == 1; }))
2277+
return false;
2278+
2279+
auto operandType = cast<RankedTensorType>(operand.getType());
2280+
auto operandSharding = mlir::sdy::getSharding(operand);
2281+
if (!operandSharding) {
2282+
return false;
2283+
}
2284+
ArrayRef<int64_t> shape = operandType.getShape();
2285+
int64_t rank = shape.size();
2286+
2287+
if (dim < 0 || dim >= rank)
2288+
return false;
2289+
2290+
// --- Condition 2: full span on every sharded dim except `dim` ---
2291+
for (int64_t d = 0; d < rank; ++d) {
2292+
if (d == dim)
2293+
continue;
2294+
int64_t numShards = getNumDevicesAlongDimension(operandSharding, d, op);
2295+
if (startIndices[d] != 0 || limitIndices[d] != shape[d]) {
2296+
needsSlice = true;
2297+
if (numShards > 1) {
2298+
return false;
2299+
}
2300+
}
2301+
}
2302+
2303+
// --- Condition 3: cross-shard slicing along `dim` ---
2304+
int64_t numShards = getNumDevicesAlongDimension(operandSharding, dim, op);
2305+
if (numShards <= 1)
2306+
return false; // Not sharded along the slice dimension.
2307+
2308+
int64_t dimSize = shape[dim];
2309+
int64_t shardSize = (dimSize + numShards - 1) / numShards;
2310+
2311+
if (startIndices[dim] > shardSize) {
2312+
return false;
2313+
}
2314+
if (shape[dim] - limitIndices[dim] > shardSize) {
2315+
return false;
2316+
}
2317+
return true;
2318+
}
2319+
22642320
struct MultiSliceCustomCallOptimize
22652321
: public OpRewritePattern<enzymexla::MultiSliceOp> {
22662322

@@ -2283,40 +2339,123 @@ struct MultiSliceCustomCallOptimize
22832339
if (slice->getParentOfType<sdy::ManualComputationOp>())
22842340
return failure();
22852341

2286-
auto rotateDimension = slice.getDimension();
2342+
auto sliceDimension = slice.getDimension();
22872343
auto shardings = mlir::sdy::getShardingPerValue(slice);
22882344
if (!shardings)
22892345
return rewriter.notifyMatchFailure(slice, "No sharding found.");
2290-
auto rotateSharding = shardings.getSharding(0);
2346+
auto sliceSharding = shardings.getSharding(0);
2347+
for (int64_t i = 1; i < slice.getNumResults(); ++i) {
2348+
if (shardings.getSharding(i) != sliceSharding)
2349+
return rewriter.notifyMatchFailure(
2350+
slice, "Not all results have the same sharding");
2351+
}
22912352

22922353
int64_t numDevicesAlongDimension =
2293-
getNumDevicesAlongDimension(rotateSharding, rotateDimension, slice);
2354+
getNumDevicesAlongDimension(sliceSharding, sliceDimension, slice);
22942355

22952356
if (numDevicesAlongDimension == 1) {
22962357
return rewriter.notifyMatchFailure(
22972358
slice,
22982359
"numDevicesAlongDimension == 1. Communication is already optimized.");
22992360
}
23002361

2301-
std::string start_indices =
2302-
serializeDenseI64ArrayAttr(slice.getStartIndices());
2303-
std::string limit_indices =
2304-
serializeDenseI64ArrayAttr(slice.getLimitIndices());
2305-
std::string strides = serializeDenseI64ArrayAttr(slice.getStrides());
2362+
Value customCallOperand = slice.getOperand();
2363+
auto operandSharding = mlir::sdy::getSharding(customCallOperand);
2364+
if (!operandSharding) {
2365+
return rewriter.notifyMatchFailure(slice, "No operand shardings");
2366+
}
2367+
if (sliceSharding != operandSharding) {
2368+
return rewriter.notifyMatchFailure(slice,
2369+
"Mismatched input/output sharding");
2370+
}
23062371

2307-
std::string opaque = "dimension=" + std::to_string(rotateDimension) +
2372+
// Only lower to custom call if the cross-shard pattern is detected.
2373+
auto startIndices = SmallVector<int64_t>(slice.getStartIndices());
2374+
auto limitIndices = SmallVector<int64_t>(slice.getLimitIndices());
2375+
auto strideVals = SmallVector<int64_t>(slice.getStrides());
2376+
bool needs_slice = false;
2377+
if (!detectCrossShardPattern(customCallOperand, slice, startIndices,
2378+
limitIndices, strideVals, sliceDimension,
2379+
slice.getAmount(), needs_slice))
2380+
return rewriter.notifyMatchFailure(
2381+
slice, "MultiSlice does not match cross-shard pattern.");
2382+
2383+
// --- Replace the needs_slice bail-out and custom-call emission with this:
2384+
// ---
2385+
2386+
SmallVector<int64_t> finalStartIndices(startIndices);
2387+
SmallVector<int64_t> finalLimitIndices(limitIndices);
2388+
SmallVector<int64_t> finalStrides(strideVals);
2389+
2390+
if (needs_slice) {
2391+
// Emit a preliminary stablehlo::SliceOp that trims replicated
2392+
// (unsharded) dimensions down to the requested range, so that
2393+
// the MultiSlice custom call afterwards spans the full axis on
2394+
// every dimension except `dim`.
2395+
auto operandType = cast<RankedTensorType>(customCallOperand.getType());
2396+
ArrayRef<int64_t> shape = operandType.getShape();
2397+
int64_t rank = shape.size();
2398+
2399+
auto operandSharding = sdy::getSharding(slice.getOperand());
2400+
2401+
SmallVector<int64_t> preStart(rank);
2402+
SmallVector<int64_t> preLimit(rank);
2403+
SmallVector<int64_t> preStrides(rank, 1);
2404+
2405+
for (int64_t d = 0; d < rank; ++d) {
2406+
if (d == sliceDimension) {
2407+
// Keep the full extent along the multi-slice dimension;
2408+
// the custom call handles cross-shard slicing there.
2409+
preStart[d] = 0;
2410+
preLimit[d] = shape[d];
2411+
} else {
2412+
int64_t numShards =
2413+
getNumDevicesAlongDimension(operandSharding, d, slice);
2414+
if (numShards <= 1 &&
2415+
(startIndices[d] != 0 || limitIndices[d] != shape[d])) {
2416+
// Replicated dim that doesn't span the full tensor —
2417+
// slice it now so the custom call can assume full extent.
2418+
preStart[d] = startIndices[d];
2419+
preLimit[d] = limitIndices[d];
2420+
// After pre-slicing, the custom call sees [0, newSize).
2421+
finalStartIndices[d] = 0;
2422+
finalLimitIndices[d] = limitIndices[d] - startIndices[d];
2423+
} else {
2424+
preStart[d] = 0;
2425+
preLimit[d] = shape[d];
2426+
}
2427+
}
2428+
}
2429+
2430+
auto preSliceOp = rewriter.create<stablehlo::SliceOp>(
2431+
slice.getLoc(), customCallOperand, preStart, preLimit, preStrides);
2432+
2433+
SmallVector<TensorShardingAttr> opShardings(1, sliceSharding);
2434+
sdy::setShardings(preSliceOp, TensorShardingPerValueAttr::get(
2435+
rewriter.getContext(), opShardings));
2436+
2437+
customCallOperand = preSliceOp.getResult();
2438+
}
2439+
2440+
std::string start_indices_str =
2441+
serializeDenseI64ArrayAttr(finalStartIndices);
2442+
std::string limit_indices_str =
2443+
serializeDenseI64ArrayAttr(finalLimitIndices);
2444+
std::string strides_str = serializeDenseI64ArrayAttr(finalStrides);
2445+
2446+
std::string opaque = "dimension=" + std::to_string(sliceDimension) +
23082447
",amount=" + std::to_string(slice.getAmount()) +
2309-
",start_indices=" + start_indices +
2310-
",limit_indices=" + limit_indices +
2311-
",strides=" + strides;
2448+
",start_indices=" + start_indices_str +
2449+
",limit_indices=" + limit_indices_str +
2450+
",strides=" + strides_str;
23122451

2313-
auto fnSym = rewriter.getStringAttr("_SPMDEnzymeInternalOp_MultiSlice");
2452+
auto fnSym = rewriter.getStringAttr("_SPMDInternalOp_MultiSlice");
23142453

2315-
SmallVector<TensorShardingAttr> opShardings(slice.getNumResults(),
2316-
rotateSharding);
2454+
SmallVector<TensorShardingAttr> opShardings(slice.getAmount() + 1,
2455+
sliceSharding);
23172456

23182457
auto ccall = rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
2319-
slice, slice->getResultTypes(), slice->getOperands(), fnSym,
2458+
slice, slice->getResultTypes(), ValueRange{customCallOperand}, fnSym,
23202459
/*has_side_effect=*/rewriter.getBoolAttr(false),
23212460
/*backend_config=*/rewriter.getStringAttr(opaque),
23222461
/*api_version=*/nullptr,

src/enzyme_ad/jax/clang_compile.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,6 @@ struct tensor<T, n0, N...>
505505

506506
DiagsBuffer->FlushDiagnostics(Clang->getDiagnostics());
507507
if (!Success) {
508-
Clang->getDiagnosticClient().finish();
509508
llvm::errs() << " failed diag\n";
510509
return {};
511510
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: enzymexlamlir-opt %s --optimize-communication="multislice_custom_call=1" | FileCheck %s
2+
3+
module {
4+
sdy.mesh @mesh = <["a"=2]>
5+
func.func public @main(%arg0: tensor<10xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}]>}) -> (tensor<7xf32>, tensor<7xf32>, tensor<7xf32>) {
6+
%1, %2, %3 = "enzymexla.multi_slice"(%arg0) {
7+
dimension = 0 : i32,
8+
amount = 2 : i32,
9+
start_indices = array<i64: 0>,
10+
limit_indices = array<i64: 7>,
11+
strides = array<i64: 1>,
12+
sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", ?}]>, <@mesh, [{"a", ?}]>, <@mesh, [{"a", ?}]>]>
13+
} : (tensor<10xf32>) -> (tensor<7xf32>, tensor<7xf32>, tensor<7xf32>)
14+
return %1, %2, %3 : tensor<7xf32>, tensor<7xf32>, tensor<7xf32>
15+
}
16+
}
17+
18+
// CHECK: func.func public @main(%arg0: tensor<10xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}]>}) -> (tensor<7xf32>, tensor<7xf32>, tensor<7xf32>) {
19+
// CHECK-NEXT: %0:3 = stablehlo.custom_call @_SPMDInternalOp_MultiSlice(%arg0) {backend_config = "dimension=0,amount=2,start_indices=[0],limit_indices=[7],strides=[1]", sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", ?}]>, <@mesh, [{"a", ?}]>, <@mesh, [{"a", ?}]>]>} : (tensor<10xf32>) -> (tensor<7xf32>, tensor<7xf32>, tensor<7xf32>)
20+
// CHECK-NEXT: return %0#0, %0#1, %0#2 : tensor<7xf32>, tensor<7xf32>, tensor<7xf32>
21+
// CHECK-NEXT: }
22+
23+
24+
module {
25+
sdy.mesh @mesh = <["x"=2, "y"=2]>
26+
func.func public @main(%arg0: tensor<20x1536x3056xf64> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {"x"}]>}) -> (tensor<4x1520x3056xf64>, tensor<4x1520x3056xf64>) {
27+
%0:2 = "enzymexla.multi_slice"(%arg0) {
28+
amount = 1 : i32,
29+
dimension = 1 : i32,
30+
limit_indices = array<i64: 12, 1529, 3056>,
31+
start_indices = array<i64: 8, 9, 0>,
32+
strides = array<i64: 1, 1, 1>,
33+
sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}, {"x"}]>, <@mesh, [{}, {"y"}, {"x"}]>]>
34+
} : (tensor<20x1536x3056xf64>) -> (tensor<4x1520x3056xf64>, tensor<4x1520x3056xf64>)
35+
return %0#0, %0#1 : tensor<4x1520x3056xf64>, tensor<4x1520x3056xf64>
36+
}
37+
}
38+
39+
// CHECK: func.func public @main(%arg0: tensor<20x1536x3056xf64> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {"x"}]>}) -> (tensor<4x1520x3056xf64>, tensor<4x1520x3056xf64>) {
40+
// CHECK-NEXT: %[[SLICE:.*]] = stablehlo.slice %arg0 [8:12, 0:1536, 0:3056] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}, {"x"}]>]>} : (tensor<20x1536x3056xf64>) -> tensor<4x1536x3056xf64>
41+
// CHECK-NEXT: %[[CC:.*]]:2 = stablehlo.custom_call @_SPMDInternalOp_MultiSlice(%[[SLICE]]) {backend_config = "dimension=1,amount=1,start_indices=[0, 9, 0],limit_indices=[4, 1529, 3056],strides=[1, 1, 1]", sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}, {"x"}]>, <@mesh, [{}, {"y"}, {"x"}]>]>} : (tensor<4x1536x3056xf64>) -> (tensor<4x1520x3056xf64>, tensor<4x1520x3056xf64>)
42+
// CHECK-NEXT: return %[[CC]]#0, %[[CC]]#1 : tensor<4x1520x3056xf64>, tensor<4x1520x3056xf64>
43+
// CHECK-NEXT: }

test/lit_tests/lower_multislice_custom_call.mlir

Lines changed: 0 additions & 18 deletions
This file was deleted.

test/lit_tests/parallel-lower-inline.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ module {
2323
// CHECK: scf.parallel
2424
// CHECK: memref.alloca_scope {
2525
// CHECK: scf.execute_region {
26-
// CHECK-DAG: %[[a1:.*]] = llvm.alloca %0 x !llvm.struct<(i8)> {alignment = 1 : i64} : (i64) -> !llvm.ptr
2726
// CHECK-DAG: %[[a2:.*]] = llvm.alloca %0 x !llvm.struct<(i8)> {alignment = 1 : i64} : (i64) -> !llvm.ptr
2827
// CHECK: llvm.store %[[ld]], %[[a2]] : !llvm.struct<(i8)>, !llvm.ptr
2928
// CHECK: memref.alloca_scope {
29+
// CHECK-DAG: %[[a1:.*]] = llvm.alloca %0 x !llvm.struct<(i8)> {alignment = 1 : i64} : (i64) -> !llvm.ptr
3030
// CHECK: scf.execute_region {
31-
// CHECK: "llvm.intr.memcpy"(%[[a1]], %[[a2]], %0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
32-
// CHECK: %[[rld:.*]] = llvm.load %4 : !llvm.ptr -> !llvm.struct<(i8)>
33-
// CHECK: llvm.store %6, %arg5 : !llvm.struct<(i8)>, !llvm.ptr
31+
// CHECK: "llvm.intr.memcpy"(%[[a1]], %[[a2]], %0) <{arg_attrs = [{llvm.align = 1 : i64}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
32+
// CHECK: %[[rld:.*]] = llvm.load %[[a1]] : !llvm.ptr -> !llvm.struct<(i8)>
33+
// CHECK: llvm.store %[[rld]], %arg5 : !llvm.struct<(i8)>, !llvm.ptr
3434
// CHECK: scf.yield
3535
// CHECK: }
3636
// CHECK: }

0 commit comments

Comments (0)