
Commit da4af7d
Demo pass: eliminate constant communication
1 parent e9c6e3c · commit da4af7d

File tree: 10 files changed, +240 −45 lines

src/enzyme_ad/jax/Dialect/Distributed/Dialect.h

Lines changed: 25 additions & 0 deletions
@@ -21,4 +21,29 @@
 #define GET_OP_CLASSES
 #include "src/enzyme_ad/jax/Dialect/Distributed/DistributedOps.h.inc"
 
+/**
+ * Convenience class to manage tokens, which are sometimes used as
+ * block args and other times as typed values.
+ */
+namespace mlir::enzyme::distributed {
+class Token {
+  mlir::TypedValue<TokenType> typedValue;
+  mlir::BlockArgument blockArg;
+
+public:
+  Token(mlir::BlockArgument arg) : blockArg(arg) {
+    typedValue = dyn_cast<mlir::TypedValue<TokenType>>(arg);
+    assert(typedValue && "Block arg is not a token");
+  }
+  Token(mlir::TypedValue<TokenType> val) : typedValue(val) {
+    assert(val && "Typed value is null");
+    blockArg = dyn_cast<mlir::BlockArgument>(val);
+    assert(blockArg && "Typed value is not a block argument");
+  }
+
+  const mlir::TypedValue<TokenType> asTypedValue() const { return typedValue; }
+  const mlir::BlockArgument asBlockArg() const { return blockArg; }
+};
+} // namespace mlir::enzyme::distributed
+
 #endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_DIALECT_H
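The Token wrapper asserts, in both directions, that a token is simultaneously a TypedValue<TokenType> and a BlockArgument. A minimal usage sketch (illustrative only, not part of the commit; it assumes a hypothetical `region` that is a branch of a `distributed.device_parallel` op whose argument 0 is a `!distributed.token`):

#include <cassert>
#include "src/enzyme_ad/jax/Dialect/Distributed/Dialect.h"

// Round-trip a token between its two views. Both constructors assert
// rather than failing gracefully, so callers must already hold a token.
void tokenRoundTrip(mlir::Region &region) {
  mlir::BlockArgument arg = region.getArgument(0); // token enters as block arg
  mlir::enzyme::distributed::Token fromArg(arg);   // checks arg has TokenType
  auto val = fromArg.asTypedValue();               // same SSA value, typed view
  mlir::enzyme::distributed::Token fromVal(val);   // checks val is a block arg
  assert(fromArg.asBlockArg() == fromVal.asBlockArg()); // same token either way
}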

src/enzyme_ad/jax/Dialect/Distributed/Ops.td

Lines changed: 0 additions & 7 deletions
@@ -101,11 +101,4 @@ def RecvOp : DistributedOp<"recv", [DeclareOpInterfaceMethods<TokenReaderOpInter
   let assemblyFormat = "$token type($value) attr-dict";
 }
 
-def NoopOp : DistributedOp<"noop", []>{
-  let description = [{
-    A placeholder no-op.
-  }];
-  let assemblyFormat = "attr-dict";
-}
-
 #endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_OPS_TD

src/enzyme_ad/jax/Dialect/Distributed/Utils.cpp

Lines changed: 48 additions & 11 deletions
@@ -1,11 +1,13 @@
 #include "Utils.h"
+#include "Dialect.h"
 namespace mlir::enzyme::distributed {
+
 Region *getEnclosingDeviceParallelBranch(DeviceParallelOp parent,
                                          Operation *op) {
   auto region = op->getParentRegion();
   while (region->getParentOp() != parent) {
     auto region_parent =
-        region->getParentOp(); // All regoins have parent ops...
+        region->getParentOp(); // All regions have parent ops...
     if (!region_parent->getParentRegion()) // But not all ops have parent
                                            // regions (e.g. top level ops)
       return nullptr;
@@ -38,25 +40,60 @@ mlir::Operation *getExecutingDevice(mlir::Operation *op) {
   return SymbolTable::lookupNearestSymbolFrom(parent, device_sym);
 }
 
-llvm::SmallVector<mlir::BlockArgument>
-getCorrespondingTokens(mlir::BlockArgument token) {
-  unsigned idx = token.getArgNumber();
-  auto op = token.getOwner()->getParentOp();
+llvm::SmallVector<Token> getCorrespondingTokens(Token token) {
+  unsigned idx = token.asBlockArg().getArgNumber();
+  auto op = token.asBlockArg().getOwner()->getParentOp();
   DeviceParallelOp parent = llvm::cast<DeviceParallelOp>(op);
-  llvm::SmallVector<mlir::BlockArgument> results;
+  llvm::SmallVector<Token> results;
   results.reserve(parent.getNumRegions());
   for (auto region : parent.getRegions()) {
-    results.push_back(region->getArgument(idx));
+    results.push_back(Token(region->getArgument(idx)));
   }
   return results;
 }
 
-llvm::SmallVector<mlir::Operation *> getTokenUsers(mlir::BlockArgument token) {
-  llvm::SmallVector<mlir::Operation *, 4> results;
-  for (auto user : token.getUsers()) {
-    results.push_back(user);
+llvm::SmallVector<mlir::Operation *> getTokenUsers(Token token) {
+  auto all_tokens = getCorrespondingTokens(token);
+  llvm::SmallVector<mlir::Operation *> results;
+  // Concatenate all users of all corresponding tokens.
+  // Due to scoping rules and since each token is a block arg to a
+  // different region, there should be no duplicates here.
+  for (auto t : all_tokens) {
+    for (auto user : t.asBlockArg().getUsers()) {
+      results.push_back(user);
+    }
   }
   return results;
 }
 
+bool isSoleSender(TokenWriterOpInterface writer) {
+  auto tokens = writer.getWriteTokens();
+  // Check for conflicts on all tokens
+  for (auto token : tokens) {
+    auto users = getTokenUsers(token);
+    if (!isSoleSender(writer, token, users)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool isSoleSender(TokenWriterOpInterface writer, Token token,
+                  llvm::ArrayRef<Operation *> others) {
+  for (auto user : others) {
+    TypedValue<TokenType> as_val = token.asTypedValue();
+    TokenWriterOpInterface other = dyn_cast<TokenWriterOpInterface>(user);
+    if (other && other != writer) {
+      // Found another writer using the same token. Check if it uses
+      // the token to write, or only for something else:
+      auto other_write_tokens = other.getWriteTokens();
+      for (auto t : other_write_tokens) {
+        if (t == as_val) {
+          return false; // Found another op writing to the same token
+        }
+      }
+    }
+  }
+  return true;
+}
 } // namespace mlir::enzyme::distributed
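Together these helpers answer the question the new pass needs: may a writer treat a channel as exclusively its own? A sketch of the intended call pattern (hypothetical caller, not part of the commit; `writer` is any op implementing TokenWriterOpInterface and `token` is one of its write tokens):

// getTokenUsers collects users across every branch of the parent
// device_parallel op, since each branch sees the same logical token as
// its own block argument; the three-argument isSoleSender then rejects
// any other op that also lists this token among its write tokens.
bool channelIsExclusive(
    mlir::enzyme::distributed::TokenWriterOpInterface writer,
    mlir::enzyme::distributed::Token token) {
  auto users = mlir::enzyme::distributed::getTokenUsers(token);
  return mlir::enzyme::distributed::isSoleSender(writer, token, users);
}

The one-argument overload is this same check applied to every token the writer writes.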

src/enzyme_ad/jax/Dialect/Distributed/Utils.h

Lines changed: 29 additions & 8 deletions
@@ -2,16 +2,19 @@
 #define ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H
 
 #include "Dialect.h"
+#include "Traits.h"
 
 namespace mlir::enzyme::distributed {
 
-/** Get the enclosing device parallel branch for a given operation, or nullptr
+/**
+ * Get the enclosing device parallel branch for a given operation, or nullptr
  * if the provided deviceParallelOp is not an ancestor of op.
  */
 Region *getEnclosingDeviceParallelBranch(DeviceParallelOp parent,
                                          Operation *op);
 
-/** Get the index of a device parallel branch within its parent operation.
+/**
+ * Get the index of a device parallel branch within its parent operation.
  * Parent op must be the direct parent of the branch region.
  */
 int getDeviceParallelBranchIndex(DeviceParallelOp parent, Region *branch);
@@ -24,13 +27,31 @@ int getDeviceParallelBranchIndex(DeviceParallelOp parent, Region *branch);
 mlir::Operation *getExecutingDevice(mlir::Operation *op);
 
 /**
- * Returns all block arguments in the same device parallel region corresponding
- * to the provided token, including the provided token itself. Will be provided
- * in the same order as the branch assignments of the parent device parallel op.
+ * Returns the counterpart tokens across all branches for the provided token.
+ * Each token here corresponds to the same logical token, but passed as a
+ * different block argument to each branch. Tokens are ordered in the same order
+ * as the branches of the parent DeviceParallelOp. Includes the token itself.
+ */
+llvm::SmallVector<Token> getCorrespondingTokens(Token token);
+
+/**
+ * Returns all users of the provided token or its counterparts across all
+ * branches, including readers, writers, and any other op that takes the token
+ * as an operand.
+ */
+llvm::SmallVector<mlir::Operation *> getTokenUsers(Token token);
+
+/**
+ * Returns true if no other ops ever write to any token written by the
+ * provided op.
+ */
+bool isSoleSender(TokenWriterOpInterface writer);
+
+/**
+ * Returns true if no other ops in the provided list send on the same channel.
  */
-llvm::SmallVector<mlir::BlockArgument>
-getCorrespondingTokens(mlir::BlockArgument token);
-llvm::SmallVector<mlir::Operation *> getTokenUsers(mlir::BlockArgument token);
+bool isSoleSender(TokenWriterOpInterface writer, Token token,
+                  llvm::ArrayRef<Operation *> others);
 } // namespace mlir::enzyme::distributed
 
 #endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+/**
+ * Replaces send(constant); recv(); with just constant.
+ */
+
+#include "Passes.h"
+#include "src/enzyme_ad/jax/Dialect/Distributed/Dialect.h"
+#include "src/enzyme_ad/jax/Dialect/Distributed/Utils.h"
+#include "stablehlo/dialect/StablehloOps.h"
+
+namespace mlir::enzyme::distributed {
+#define GEN_PASS_DEF_ELIMINATECONSTANTCOMMUNICATIONPASS
+#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc"
+
+bool isConstantOp(Operation *op) { return isa<stablehlo::ConstantOp>(op); }
+bool isConstant(Value val) {
+  if (auto op = val.getDefiningOp()) {
+    return isConstantOp(op);
+  }
+  return false;
+}
+
+struct EliminateConstantCommunicationPass
+    : public impl::EliminateConstantCommunicationPassBase<
+          EliminateConstantCommunicationPass> {
+  using EliminateConstantCommunicationPassBase::
+      EliminateConstantCommunicationPassBase;
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    // Post-order walk is allowed to erase the sends. Less sure if we
+    // are permitted to erase the recvs during the walk.
+    op->walk([&](enzyme::distributed::SendOp send) {
+      if (isConstant(send.getValue())) {
+        // Check that we are the only sender on this channel, and get
+        // the corresponding recvs.
+        auto users = getTokenUsers(send.getToken());
+        if (!isSoleSender(send, send.getToken(), users)) {
+          // If we're not the sole sender, we can't eliminate the communication.
+          return;
+        }
+        // If we are the sole sender, we can replace all recvs with a copy of
+        // the constant value. However, since the recv may be in a different
+        // scope, we need to replace it with a clone of the constant op.
+        for (auto user : users) {
+          if (auto recv = dyn_cast<enzyme::distributed::RecvOp>(user)) {
+            auto cloned_const = send.getValue().getDefiningOp()->clone();
+            // Insert the cloned constant right before the recv
+            recv->getBlock()->getOperations().insert(recv->getIterator(),
+                                                     cloned_const);
+            recv.getResult().replaceAllUsesWith(cloned_const->getResult(0));
+            recv.erase();
+          }
+        }
+        send.erase();
+      }
+    });
+  }
+};
+} // namespace mlir::enzyme::distributed
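The clone is inserted immediately before the recv it replaces, so the cloned constant's result dominates every former user of the recv. The explicit block-list insert above can also be written with OpBuilder; a sketch of the equivalent idiom (illustrative, not the committed code; `send` and `recv` as in the walk above):

// Constructing an OpBuilder from an operation sets the insertion point
// directly before that operation, then clone() places the copy there.
mlir::OpBuilder builder(recv);
mlir::Operation *clonedConst = builder.clone(*send.getValue().getDefiningOp());
recv.getResult().replaceAllUsesWith(clonedConst->getResult(0));
recv.erase();

Either way, the original constant in the sending branch is left in place; once the send is erased it may become dead and can be cleaned up by a later canonicalization or DCE run.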

src/enzyme_ad/jax/Passes/Distributed/Passes.h

Lines changed: 3 additions & 0 deletions
@@ -1,10 +1,13 @@
 #ifndef ENZYMEXLA_DISTRIBUTED_PASSES_H
 #define ENZYMEXLA_DISTRIBUTED_PASSES_H
 
+#include "mlir/Pass/Pass.h"
+
 namespace mlir::enzyme::distributed {
 
 #define GEN_PASS_DECL
 #include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc"
+
 #define GEN_PASS_REGISTRATION
 #include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc"
 
src/enzyme_ad/jax/Passes/Distributed/Passes.td

Lines changed: 4 additions & 1 deletion
@@ -9,7 +9,10 @@ def EliminateConstantCommunicationPass : Pass<"eliminate-constant-communication"
     This pass identifies send instructions with constant operands and replaces
     the corresponding receive instructions with local constants.
   }];
-  let dependentDialects = ["enzyme::distributed::DistributedDialect"];
+  let dependentDialects = [
+    "enzyme::distributed::DistributedDialect",
+    "stablehlo::StablehloDialect"
+  ];
 }
 
 #endif // ENZYMEXLA_DISTRIBUTED_PASSES

src/enzyme_ad/jax/RegistryUtils.cpp

Lines changed: 1 addition & 0 deletions
@@ -295,6 +295,7 @@ void registerInterfaces(mlir::DialectRegistry &registry) {
 void initializePasses() {
   registerenzymePasses();
   enzyme::registerenzymexlaPasses();
+  enzyme::distributed::registerdistributedPasses();
 
   // Register the standard passes we want.
   mlir::registerCSEPass();
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+// RUN: enzymexlamlir-opt --eliminate-constant-communication %s | FileCheck %s
+distributed.leaf_device @myGpu
+distributed.device_mesh @gpuMesh @myGpu [2, 2]
+distributed.leaf_device @myCpu
+distributed.channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu]
+distributed.device_group @gpusWithHost [@myGpu, @myCpu] [@chan1]
+
+func.func @foo() {
+  distributed.device_parallel @gpusWithHost {
+    branch @myGpu {
+    ^entry(%1: !distributed.token):
+      distributed.device_parallel @gpuMesh {
+        branch @myGpu {
+        ^entry():
+        }
+      }
+    }
+    branch @myCpu {
+    ^entry(%1: !distributed.token):
+      %output = stablehlo.constant() {
+        value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
+      } : () -> tensor<2x2xf32>
+      distributed.send %1 tensor<2x2xf32> %output
+
+    }
+    branch @chan1 {
+    ^entry(%1: !distributed.token):
+      %input = distributed.recv %1 tensor<2x2xf32>
+      %sum = stablehlo.add %input, %input : tensor<2x2xf32>
+    }
+  }
+
+  func.return
+}
+
+//CHECK: module {
+//CHECK-NEXT: distributed.leaf_device @myGpu
+//CHECK-NEXT: distributed.device_mesh @gpuMesh @myGpu [2, 2]
+//CHECK-NEXT: distributed.leaf_device @myCpu
+//CHECK-NEXT: distributed.channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu]
+//CHECK-NEXT: distributed.device_group @gpusWithHost [@myGpu, @myCpu] [@chan1]
+//CHECK-NEXT: func.func @foo() {
+//CHECK-NEXT: distributed.device_parallel @gpusWithHost{ branch @myGpu{
+//CHECK-NEXT: ^bb0(%arg0: !distributed.token):
+//CHECK-NEXT: distributed.device_parallel @gpuMesh{ branch @myGpu{
+//CHECK-NEXT: }}
+//CHECK-NEXT: } branch @myCpu{
+//CHECK-NEXT: ^bb0(%arg0: !distributed.token):
+//CHECK-NEXT{LITERAL}: %cst = stablehlo.constant dense<[[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>
+//CHECK-NEXT: } branch @chan1{
+//CHECK-NEXT: ^bb0(%arg0: !distributed.token):
+//CHECK-NEXT{LITERAL}: %cst = stablehlo.constant dense<[[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>
+//CHECK-NEXT: %0 = stablehlo.add %cst, %cst : tensor<2x2xf32>
+//CHECK-NEXT: }}
+//CHECK-NEXT: return
+//CHECK-NEXT: }
+//CHECK-NEXT:}

test/lit_tests/distributed/roundtrip.mlir

Lines changed: 15 additions & 18 deletions
@@ -12,36 +12,33 @@ func.func @foo() {
       distributed.device_parallel @gpuMesh {
         branch @myGpu {
         ^entry():
-          distributed.noop
         }
       }
     }
     branch @myCpu {
     ^entry():
-      distributed.noop
     }
     branch @chan1 {
     ^entry():
-      distributed.noop
     }
   }
 
   func.return
 }
 
 //CHECK: module {
-//CHECK-NEXT: distributed.LeafDevice @myGpu
-//CHECK-NEXT: distributed.DeviceMesh @gpuMesh @myGpu [2, 2]
-//CHECK-NEXT: distributed.LeafDevice @myCpu
-//CHECK-NEXT: distributed.Channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu]
-//CHECK-NEXT: distributed.DeviceGroup @gpusWithHost [@myGpu, @myCpu] [@chan1]
-//CHECK-NEXT: func.func @foo() {
-//CHECK-NEXT: distributed.GroupSplit @gpusWithHost branch @myGpu{
-//CHECK-NEXT: distributed.MeshFor @gpuMesh {
-//CHECK-NEXT: }
-//CHECK-NEXT: } branch @myCpu{
-//CHECK-NEXT: %0 = distributed.DefineToken @chan1
-//CHECK-NEXT: }
-//CHECK-NEXT: return
-//CHECK-NEXT: }
-//CHECK-NEXT: }
+//CHECK-NEXT: distributed.leaf_device @myGpu
+//CHECK-NEXT: distributed.device_mesh @gpuMesh @myGpu [2, 2]
+//CHECK-NEXT: distributed.leaf_device @myCpu
+//CHECK-NEXT: distributed.channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu]
+//CHECK-NEXT: distributed.device_group @gpusWithHost [@myGpu, @myCpu] [@chan1]
+//CHECK-NEXT: func.func @foo() {
+//CHECK-NEXT: distributed.device_parallel @gpusWithHost{ branch @myGpu{
+//CHECK-NEXT: distributed.device_parallel @gpuMesh{ branch @myGpu{
+//CHECK-NEXT: }}
+//CHECK-NEXT: } branch @myCpu{
+//CHECK-NEXT: } branch @chan1{
+//CHECK-NEXT: }}
+//CHECK-NEXT: return
+//CHECK-NEXT: }
+//CHECK-NEXT:}
