Skip to content

Commit a49eba5

Browse files
EganJwsmoses
authored andcommitted
Distributed dialect- change subbranches to regions
1 parent 12bca85 commit a49eba5

File tree

3 files changed

+64
-55
lines changed

3 files changed

+64
-55
lines changed

src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,44 @@ GroupSplitOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
111111
getDeviceGroupAttr());
112112
}
113113

114-
LogicalResult
115-
SplitBranchOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
116-
// Split branches have programs for individual devices or channels
117-
Operation *dev_or_chan =
118-
symbol_table.lookupNearestSymbolFrom(*this, getDeviceOrChannelAttr());
119-
if (!dev_or_chan || !(dev_or_chan->hasTrait<DeviceDefTrait>() ||
120-
dev_or_chan->hasTrait<ChannelDefTrait>())) {
121-
mlir::emitError(getLoc())
122-
<< "branches must reference a valid device or channel";
123-
return mlir::failure();
114+
// Printer/parser for GroupsplitOp branches
115+
mlir::ParseResult parseSplitBranches(
116+
OpAsmParser &parser, mlir::ArrayAttr &branchAssignments,
117+
llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> &branchesRegions) {
118+
// Expect 0 or more `branch` $symbol_name $symbol_region
119+
// While next token is `branch`:
120+
llvm::SmallVector<mlir::Attribute, 2> assignment_symbols;
121+
while (parser.parseOptionalKeyword("branch").succeeded()) {
122+
// Parse symbol name
123+
mlir::SymbolRefAttr sym;
124+
auto sym_parse_failed = parser.parseAttribute<mlir::SymbolRefAttr>(sym);
125+
if (sym_parse_failed)
126+
return mlir::failure();
127+
assignment_symbols.push_back(sym);
128+
129+
// Put placeholder region in list and parse into it
130+
branchesRegions.push_back(std::make_unique<mlir::Region>());
131+
auto parse_region_failed = parser.parseRegion(*branchesRegions.back());
132+
if (parse_region_failed)
133+
return mlir::failure();
124134
}
135+
136+
branchAssignments = mlir::ArrayAttr::get(parser.getBuilder().getContext(),
137+
assignment_symbols);
125138
return mlir::success();
126139
}
127140

141+
void printSplitBranches(OpAsmPrinter &printer, const GroupSplitOp &op,
142+
const mlir::ArrayAttr branchAssignments,
143+
const llvm::MutableArrayRef<mlir::Region> branches) {
144+
// Print each branch as `branch` $symbol_name $symbol_region
145+
for (size_t i = 0; i < branches.size(); i++) {
146+
printer << " branch ";
147+
printer.printAttribute(branchAssignments[i]);
148+
printer.printRegion(branches[i]);
149+
}
150+
}
151+
128152
LogicalResult
129153
DefineTokenOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
130154
// Tokens need to indicate which channel they communicate over

src/enzyme_ad/jax/Dialect/Distributed/Ops.td

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,25 +59,16 @@ def MeshForOp : DistributedOp<"MeshFor", [DeclareOpInterfaceMethods<SymbolUserOp
5959
let assemblyFormat = "$mesh $body attr-dict";
6060
}
6161

62-
def GroupSplitOp : DistributedOp<"GroupSplit", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, NoTerminator, SingleBlock]>{
62+
def GroupSplitOp : DistributedOp<"GroupSplit", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, NoTerminator]>{
6363
let arguments = (ins
64-
SymbolRefAttr:$device_group // TODO: verify it's a group
64+
SymbolRefAttr:$device_group,
65+
ArrayAttr:$branch_assignments // Symbols mapping to each branch region
6566
);
66-
let regions = (region SizedRegion<1>:$declarations); // Takes as args the devices and channels in the group
67+
// TODO check that declarations only declares tokens.
68+
let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
6769
let results = (outs ); // TODO
6870
// let hasVerifier = 1; // TODO
69-
// let hasCanonicalizer = 1; // TODO: token declarations up front, followed by device and channel branches in order of listing in the group
70-
let assemblyFormat = "$device_group $declarations attr-dict";
71-
}
72-
73-
def SplitBranchOp : DistributedOp<"SplitBranch", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, NoTerminator, SingleBlock]>{
74-
let arguments = (ins
75-
SymbolRefAttr:$device_or_channel // TODO: verify it's a device or channel
76-
);
77-
let regions = (region MaxSizedRegion<1>:$body); // Takes as args the device or channel
78-
let results = (outs ); // TODO
79-
// let hasVerifier = 1; // TODO: parent is a groupsplitop
80-
let assemblyFormat = "$device_or_channel $body attr-dict";
71+
let assemblyFormat = "$device_group custom<SplitBranches>($branch_assignments, $branches) attr-dict";
8172
}
8273

8374
def DefineTokenOp : DistributedOp<"DefineToken", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{

test/lit_tests/distributed/roundtrip.mlir

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,37 +6,31 @@ distributed.Channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu]
66
distributed.DeviceGroup @gpusWithHost [@myGpu, @myCpu] [@chan1]
77

88
func.func @foo() {
9-
distributed.GroupSplit @gpusWithHost {
10-
%tok = distributed.DefineToken @chan1
11-
distributed.SplitBranch @chan1 { }
12-
distributed.SplitBranch @myCpu {}
13-
distributed.SplitBranch @gpuMesh {
14-
distributed.MeshFor @gpuMesh {
9+
distributed.GroupSplit @gpusWithHost
10+
branch @myGpu {
11+
distributed.MeshFor @gpuMesh {
12+
}
13+
}
14+
branch @myCpu {
15+
distributed.DefineToken @chan1
16+
}
1517

16-
}
17-
}
18-
}
1918
func.return
2019
}
2120

22-
// CHECK: module {
23-
// CHECK-NEXT: distributed.LeafDevice @myGpu
24-
// CHECK-NEXT: distributed.DeviceMesh @gpuMesh @myGpu [2, 2]
25-
// CHECK-NEXT: distributed.LeafDevice @myCpu
26-
// CHECK-NEXT: distributed.Channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu]
27-
// CHECK-NEXT: distributed.DeviceGroup @gpusWithHost [@myGpu, @myCpu] [@chan1]
28-
// CHECK-NEXT: func.func @foo() {
29-
// CHECK-NEXT: distributed.GroupSplit @gpusWithHost {
30-
// CHECK-NEXT: %0 = distributed.DefineToken @chan1
31-
// CHECK-NEXT: distributed.SplitBranch @chan1 {
32-
// CHECK-NEXT: }
33-
// CHECK-NEXT: distributed.SplitBranch @myCpu {
34-
// CHECK-NEXT: }
35-
// CHECK-NEXT: distributed.SplitBranch @gpuMesh {
36-
// CHECK-NEXT: distributed.MeshFor @gpuMesh {
37-
// CHECK-NEXT: }
38-
// CHECK-NEXT: }
39-
// CHECK-NEXT: }
40-
// CHECK-NEXT: return
41-
// CHECK-NEXT: }
42-
// CHECK-NEXT: }
21+
//CHECK: module {
22+
//CHECK-NEXT: distributed.LeafDevice @myGpu
23+
//CHECK-NEXT: distributed.DeviceMesh @gpuMesh @myGpu [2, 2]
24+
//CHECK-NEXT: distributed.LeafDevice @myCpu
25+
//CHECK-NEXT: distributed.Channel @chan1 [@myCpu, @gpuMesh] [@gpuMesh, @myCpu]
26+
//CHECK-NEXT: distributed.DeviceGroup @gpusWithHost [@myGpu, @myCpu] [@chan1]
27+
//CHECK-NEXT: func.func @foo() {
28+
//CHECK-NEXT: distributed.GroupSplit @gpusWithHost branch @myGpu{
29+
//CHECK-NEXT: distributed.MeshFor @gpuMesh {
30+
//CHECK-NEXT: }
31+
//CHECK-NEXT: } branch @myCpu{
32+
//CHECK-NEXT: %0 = distributed.DefineToken @chan1
33+
//CHECK-NEXT: }
34+
//CHECK-NEXT: return
35+
//CHECK-NEXT: }
36+
//CHECK-NEXT: }

0 commit comments

Comments
 (0)