Skip to content

Commit 1b726af

Browse files
EganJwsmoses
authored andcommitted
Single device parallel op with block arg tokens
1 parent ff303ea commit 1b726af

File tree

5 files changed

+216
-61
lines changed

5 files changed

+216
-61
lines changed

src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "llvm/ADT/TypeSwitch.h"
33

44
#include "Dialect.h"
5+
#include "Utils.h"
56

67
using mlir::OpTrait::enzyme::distributed::ChannelDefTrait;
78
using mlir::OpTrait::enzyme::distributed::DeviceDefTrait;
@@ -98,21 +99,60 @@ DeviceMeshOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
9899
getDeviceType());
99100
}
100101

101-
LogicalResult
102-
MeshForOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
103-
// Mesh for ops apply only to meshes
104-
return checkSymbolIsA<DeviceMeshOp>(symbol_table, *this, getMeshAttr());
102+
Operation *DeviceParallelOp::getEnclosingDeviceOp() {
103+
return mlir::SymbolTable::lookupNearestSymbolFrom(*this,
104+
getEnclosingDeviceAttr());
105105
}
106106

107-
LogicalResult
108-
GroupSplitOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
109-
// Group splits apply only to device groups
110-
return checkSymbolIsA<DeviceGroupOp>(symbol_table, *this,
111-
getDeviceGroupAttr());
107+
LogicalResult DeviceParallelOp::verifySymbolUses(
108+
::mlir::SymbolTableCollection &symbol_table) {
109+
Operation *device_op = this->getEnclosingDeviceOp();
110+
if (isa<DeviceGroupOp>(device_op) || isa<DeviceMeshOp>(device_op)) {
111+
return mlir::success();
112+
}
113+
return emitOpError()
114+
<< "enclosing device symbol must refer to a device group or mesh";
115+
}
116+
117+
LogicalResult DeviceParallelOp::verify() {
118+
// Check number of branches matches number of assignments
119+
120+
if (getNumRegions() != getBranchAssignments().size()) {
121+
return emitOpError()
122+
<< "number of regions must match number of branch assignments";
123+
}
124+
125+
// Look at device type to determine number of branches
126+
auto device_op = mlir::SymbolTable::lookupNearestSymbolFrom(
127+
*this, getEnclosingDeviceAttr());
128+
if (!device_op) {
129+
return emitOpError() << "could not find enclosing device symbol";
130+
}
131+
132+
if (DeviceGroupOp deviceGroup = dyn_cast<DeviceGroupOp>(device_op)) {
133+
// Device group: number of branches must match number of devices in group
134+
auto devices = deviceGroup.getDevices();
135+
auto channels = deviceGroup.getChannels();
136+
if (getNumRegions() != devices.size() + channels.size()) {
137+
return emitOpError() << "number of regions must match number of devices "
138+
"and channels in device group";
139+
}
140+
} else if (DeviceMeshOp mesh = dyn_cast<DeviceMeshOp>(device_op)) {
141+
// Exactly one branch for the mesth type
142+
if (getNumRegions() != 1) {
143+
return emitOpError()
144+
<< "device mesh must have exactly one region for its single type";
145+
}
146+
} else {
147+
return emitOpError()
148+
<< "enclosing device symbol must refer to a device group or mesh";
149+
}
150+
151+
return mlir::success();
112152
}
113153

114-
// Printer/parser for GroupsplitOp branches
115-
mlir::ParseResult parseSplitBranches(
154+
// Printer/parser for subdevice branches
155+
mlir::ParseResult parseDeviceBranches(
116156
OpAsmParser &parser, mlir::ArrayAttr &branchAssignments,
117157
llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> &branchesRegions) {
118158
// Expect 0 or more `branch` $symbol_name $symbol_region
@@ -138,9 +178,9 @@ mlir::ParseResult parseSplitBranches(
138178
return mlir::success();
139179
}
140180

141-
void printSplitBranches(OpAsmPrinter &printer, const GroupSplitOp &op,
142-
const mlir::ArrayAttr branchAssignments,
143-
const llvm::MutableArrayRef<mlir::Region> branches) {
181+
void printDeviceBranches(OpAsmPrinter &printer, const DeviceParallelOp &op,
182+
const mlir::ArrayAttr branchAssignments,
183+
const llvm::MutableArrayRef<mlir::Region> branches) {
144184
// Print each branch as `branch` $symbol_name $symbol_region
145185
for (size_t i = 0; i < branches.size(); i++) {
146186
printer << " branch ";
@@ -149,13 +189,6 @@ void printSplitBranches(OpAsmPrinter &printer, const GroupSplitOp &op,
149189
}
150190
}
151191

152-
LogicalResult
153-
DefineTokenOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
154-
// Tokens need to indicate which channel they communicate over
155-
return checkSymbolHasTrait<ChannelDefTrait>(symbol_table, *this,
156-
getChannelAttr());
157-
}
158-
159192
llvm::ArrayRef<mlir::TypedValue<TokenType>> SendOp::getWriteTokens() {
160193
return llvm::SmallVector<mlir::TypedValue<TokenType>, 1>{getToken()};
161194
}

src/enzyme_ad/jax/Dialect/Distributed/Ops.td

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ include "Interfaces.td"
1010

1111
// Device definition ops
1212

13-
def ChannelOp : DistributedOp<"Channel", [Symbol, ChannelDefTrait, DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
13+
def ChannelOp : DistributedOp<"channel", [Symbol, ChannelDefTrait, DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
1414
let arguments = (ins
1515
SymbolNameAttr:$sym_name,
1616
// a variadic list of devices connected by this channel
@@ -21,15 +21,15 @@ def ChannelOp : DistributedOp<"Channel", [Symbol, ChannelDefTrait, DeclareOpInte
2121
let assemblyFormat = "$sym_name $sending_devices $receiving_devices attr-dict";
2222
}
2323

24-
def LeafDeviceOp : DistributedOp<"LeafDevice", [Symbol, DeviceDefTrait]>{
24+
def LeafDeviceOp : DistributedOp<"leaf_device", [Symbol, DeviceDefTrait]>{
2525
let arguments = (ins
2626
SymbolNameAttr:$sym_name
2727
// TODO: device type, e.g. TPU, GPU, CPU, and other attributes
2828
);
2929
let assemblyFormat = "$sym_name attr-dict";
3030
}
3131

32-
def DeviceGroupOp : DistributedOp<"DeviceGroup", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
32+
def DeviceGroupOp : DistributedOp<"device_group", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
3333
let arguments = (ins
3434
SymbolNameAttr:$sym_name,
3535
// a variadic list of devices in the group
@@ -39,7 +39,7 @@ def DeviceGroupOp : DistributedOp<"DeviceGroup", [Symbol, DeviceDefTrait, Declar
3939
);
4040
let assemblyFormat = "$sym_name $devices $channels attr-dict";
4141
}
42-
def DeviceMeshOp : DistributedOp<"DeviceMesh", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
42+
def DeviceMeshOp : DistributedOp<"device_mesh", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
4343
let arguments = (ins
4444
SymbolNameAttr:$sym_name,
4545
SymbolRefAttr:$device_type,
@@ -49,50 +49,63 @@ def DeviceMeshOp : DistributedOp<"DeviceMesh", [Symbol, DeviceDefTrait, DeclareO
4949
let assemblyFormat = "$sym_name $device_type $shape attr-dict";
5050
}
5151

52-
// Ops for breaking down computation across the device hierarchy
52+
// def ContinueOp : DistributedOp<"continue", [Terminator]> {
53+
// let description = [{
54+
// A terminator for DeviceParallelOp regions. Takes as arguments the tokens to be passed to the
55+
// continuation of the DeviceParallelOp. These values can then be used in a subsequent DeviceParallelOp
56+
// that is a sibling to the original DeviceParallelOp by referencing the returned tokens.
57+
// }];
58+
// let arguments = (ins Variadic<AnyType>:$operands);
59+
// let results = (outs ); // No outputs for terminators, the token is output by the parent DeviceParallelOp.
60+
// let assemblyFormat = "$operands type($operands) attr-dict";
61+
// }
5362

54-
def MeshForOp : DistributedOp<"MeshFor", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, NoTerminator, SingleBlock]>{
55-
let arguments = (ins SymbolRefAttr:$mesh); // TODO: verify it's a mesh
56-
let regions = (region MaxSizedRegion<1>:$body); // TODO: body's block args are device type and mesh index
57-
let results = (outs ); // TODO
58-
// let hasVerifier = 1; // TODO: verify body's block args take mesh index
59-
let assemblyFormat = "$mesh $body attr-dict";
60-
}
63+
def DeviceParallelOp : DistributedOp<"device_parallel", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, NoTerminator]>{
64+
let description = [{
65+
An op for mapping computations to subdevices. Serves both for homogenous device meshes as well
66+
as explicitly enumerated device groups. In the case of device meshes, this op should contain
67+
a single region to be executed in parallel on each device. In the case of device groups, this
68+
op should contain one region per device and channel in the group.
6169

62-
def GroupSplitOp : DistributedOp<"GroupSplit", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, NoTerminator]>{
70+
In either case, regions must take as argument one device index within the parent device followed
71+
by a number of token arguments. Tokens are matched by positionally between different branches,
72+
and all branches must have the same number and type of token arguments (though they may be unused).
73+
}];
74+
6375
let arguments = (ins
64-
SymbolRefAttr:$device_group,
65-
ArrayAttr:$branch_assignments // Symbols mapping to each branch region
76+
SymbolRefAttr:$enclosing_device,
77+
ArrayAttr:$branch_assignments // the device components for each region (device-specific branch)
6678
);
67-
// TODO check that declarations only declares tokens.
6879
let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
69-
let results = (outs ); // TODO
70-
// let hasVerifier = 1; // TODO
71-
let assemblyFormat = "$device_group custom<SplitBranches>($branch_assignments, $branches) attr-dict";
72-
}
73-
74-
def DefineTokenOp : DistributedOp<"DefineToken", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
75-
let arguments = (ins
76-
SymbolRefAttr:$channel
77-
);
78-
let results = (outs TokenType:$token);
79-
// let hasVerifier = 1; // TODO: verify writers and readers are connected to the channel
80-
let assemblyFormat = "$channel attr-dict";
80+
// let results = (outs Variadic<TokenType>:$continuation_tokens);
81+
let results = (outs );
82+
let hasVerifier = 1; // TODO
83+
let assemblyFormat = "$enclosing_device `{` custom<DeviceBranches>($branch_assignments, $branches) `}` attr-dict";
84+
let extraClassDeclaration = [{
85+
Operation* getEnclosingDeviceOp();
86+
}];
8187
}
8288

83-
def SendOp : DistributedOp<"Send", [DeclareOpInterfaceMethods<TokenWriterOpInterface>]>{
89+
def SendOp : DistributedOp<"send", [DeclareOpInterfaceMethods<TokenWriterOpInterface>]>{
8490
let arguments = (ins
85-
WriteTokenType:$token,
91+
TokenType:$token,
8692
// value to send
8793
AnyType:$value);
8894
let assemblyFormat = "$token type($value) $value attr-dict";
8995
}
9096

91-
def RecvOp : DistributedOp<"Recv", [DeclareOpInterfaceMethods<TokenReaderOpInterface>]>{
97+
def RecvOp : DistributedOp<"recv", [DeclareOpInterfaceMethods<TokenReaderOpInterface>]>{
9298
let arguments = (ins
9399
TokenType:$token);
94100
let results = (outs AnyType:$value);
95101
let assemblyFormat = "$token type($value) attr-dict";
96102
}
97103

104+
def NoopOp : DistributedOp<"noop", []>{
105+
let description = [{
106+
A placeholder no-op.
107+
}];
108+
let assemblyFormat = "attr-dict";
109+
}
110+
98111
#endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_OPS_TD
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#include "Utils.h"
2+
namespace mlir::enzyme::distributed {
3+
Region *getEnclosingDeviceParallelBranch(DeviceParallelOp parent,
4+
Operation *op) {
5+
auto region = op->getParentRegion();
6+
while (region->getParentOp() != parent) {
7+
auto region_parent =
8+
region->getParentOp(); // All regoins have parent ops...
9+
if (!region_parent->getParentRegion()) // But not all ops have parent
10+
// regions (e.g. top level ops)
11+
return nullptr;
12+
region = region_parent->getParentRegion();
13+
}
14+
return region;
15+
}
16+
17+
int getDeviceParallelBranchIndex(DeviceParallelOp parent, Region *branch) {
18+
assert(branch->getParentOp() == parent && "branch is not a region of parent");
19+
for (int i = 0; i < parent.getNumRegions(); i++) {
20+
if (&parent.getRegion(i) == branch)
21+
return i;
22+
}
23+
llvm_unreachable("branch not found in parent regions");
24+
return -1;
25+
}
26+
27+
mlir::Operation *getExecutingDevice(mlir::Operation *op) {
28+
// Find current branch
29+
auto parent = op->getParentOfType<DeviceParallelOp>();
30+
auto branch = getEnclosingDeviceParallelBranch(parent, op);
31+
if (!branch)
32+
return nullptr;
33+
// Find index of branch and cross-reference to parent device symbol
34+
int branch_idx = getDeviceParallelBranchIndex(parent, branch);
35+
auto device_sym = llvm::cast<mlir::SymbolRefAttr>(
36+
parent.getBranchAssignments()[branch_idx]);
37+
38+
return SymbolTable::lookupNearestSymbolFrom(parent, device_sym);
39+
}
40+
41+
llvm::SmallVector<mlir::BlockArgument>
42+
getCorrespondingTokens(mlir::BlockArgument token) {
43+
unsigned idx = token.getArgNumber();
44+
auto op = token.getOwner()->getParentOp();
45+
DeviceParallelOp parent = llvm::cast<DeviceParallelOp>(op);
46+
llvm::SmallVector<mlir::BlockArgument> results;
47+
results.reserve(parent.getNumRegions());
48+
for (auto region : parent.getRegions()) {
49+
results.push_back(region->getArgument(idx));
50+
}
51+
return results;
52+
}
53+
54+
llvm::SmallVector<mlir::Operation *> getTokenUsers(mlir::BlockArgument token) {
55+
llvm::SmallVector<mlir::Operation *, 4> results;
56+
for (auto user : token.getUsers()) {
57+
results.push_back(user);
58+
}
59+
return results;
60+
}
61+
62+
} // namespace mlir::enzyme::distributed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#ifndef ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H
2+
#define ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H
3+
4+
#include "Dialect.h"
5+
6+
namespace mlir::enzyme::distributed {
7+
8+
/** Get the enclosing device parallel branch for a given operation, or nullptr
9+
* if the provided deviceParallelOp is not an ancestor of op.
10+
*/
11+
Region *getEnclosingDeviceParallelBranch(DeviceParallelOp parent,
12+
Operation *op);
13+
14+
/** Get the index of a device parallel branch within its parent operation.
15+
* Parent op must be the direct parent of the branch region.
16+
*/
17+
int getDeviceParallelBranchIndex(DeviceParallelOp parent, Region *branch);
18+
19+
/**
20+
* Returns the defining op of the enclosing device of a given computational op
21+
* (e.g. not the parent of a device defintion op). Returns nullptr if no such
22+
* device can be found (not inside a device parallel region).
23+
*/
24+
mlir::Operation *getExecutingDevice(mlir::Operation *op);
25+
26+
/**
27+
* Returns all block arguments in the same device parallel region corresponding
28+
* to the provided token, including the provided token itself. Will be provided
29+
* in the same order as the branch assignments of the parent device parallel op.
30+
*/
31+
llvm::SmallVector<mlir::BlockArgument>
32+
getCorrespondingTokens(mlir::BlockArgument token);
33+
llvm::SmallVector<mlir::Operation *> getTokenUsers(mlir::BlockArgument token);
34+
} // namespace mlir::enzyme::distributed
35+
36+
#endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_UTILS_H

test/lit_tests/distributed/roundtrip.mlir

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,30 @@
11
// RUN: enzymexlamlir-opt %s | FileCheck %s
2-
distributed.LeafDevice @myGpu
3-
distributed.DeviceMesh @gpuMesh @myGpu [2, 2]
4-
distributed.LeafDevice @myCpu
5-
distributed.Channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu]
6-
distributed.DeviceGroup @gpusWithHost [@myGpu, @myCpu] [@chan1]
2+
distributed.leaf_device @myGpu
3+
distributed.device_mesh @gpuMesh @myGpu [2, 2]
4+
distributed.leaf_device @myCpu
5+
distributed.channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu]
6+
distributed.device_group @gpusWithHost [@myGpu, @myCpu] [@chan1]
77

88
func.func @foo() {
9-
distributed.GroupSplit @gpusWithHost
9+
distributed.device_parallel @gpusWithHost {
1010
branch @myGpu {
11-
distributed.MeshFor @gpuMesh {
11+
^entry():
12+
distributed.device_parallel @gpuMesh {
13+
branch @myGpu {
14+
^entry():
15+
distributed.noop
16+
}
17+
}
1218
}
13-
}
1419
branch @myCpu {
15-
distributed.DefineToken @chan1
20+
^entry():
21+
distributed.noop
22+
}
23+
branch @chan1 {
24+
^entry():
25+
distributed.noop
1626
}
27+
}
1728

1829
func.return
1930
}

0 commit comments

Comments
 (0)