@@ -10,7 +10,7 @@ include "Interfaces.td"
1010
1111// Device definition ops
1212
13- def ChannelOp : DistributedOp<"Channel ", [Symbol, ChannelDefTrait, DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
13+ def ChannelOp : DistributedOp<"channel ", [Symbol, ChannelDefTrait, DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
1414 let arguments = (ins
1515 SymbolNameAttr:$sym_name,
1616 // a variadic list of devices connected by this channel
@@ -21,15 +21,15 @@ def ChannelOp : DistributedOp<"Channel", [Symbol, ChannelDefTrait, DeclareOpInte
2121 let assemblyFormat = "$sym_name $sending_devices $receiving_devices attr-dict";
2222}
2323
24- def LeafDeviceOp : DistributedOp<"LeafDevice ", [Symbol, DeviceDefTrait]>{
24+ def LeafDeviceOp : DistributedOp<"leaf_device ", [Symbol, DeviceDefTrait]>{
2525 let arguments = (ins
2626 SymbolNameAttr:$sym_name
2727 // TODO: device type, e.g. TPU, GPU, CPU, and other attributes
2828 );
2929 let assemblyFormat = "$sym_name attr-dict";
3030}
3131
32- def DeviceGroupOp : DistributedOp<"DeviceGroup ", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
32+ def DeviceGroupOp : DistributedOp<"device_group ", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
3333 let arguments = (ins
3434 SymbolNameAttr:$sym_name,
3535 // a variadic list of devices in the group
@@ -39,7 +39,7 @@ def DeviceGroupOp : DistributedOp<"DeviceGroup", [Symbol, DeviceDefTrait, Declar
3939 );
4040 let assemblyFormat = "$sym_name $devices $channels attr-dict";
4141}
42- def DeviceMeshOp : DistributedOp<"DeviceMesh ", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
42+ def DeviceMeshOp : DistributedOp<"device_mesh ", [Symbol, DeviceDefTrait, DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
4343 let arguments = (ins
4444 SymbolNameAttr:$sym_name,
4545 SymbolRefAttr:$device_type,
@@ -49,50 +49,63 @@ def DeviceMeshOp : DistributedOp<"DeviceMesh", [Symbol, DeviceDefTrait, DeclareO
4949 let assemblyFormat = "$sym_name $device_type $shape attr-dict";
5050}
5151
52- // Ops for breaking down computation across the device hierarchy
52+ // def ContinueOp : DistributedOp<"continue", [Terminator]> {
53+ // let description = [{
54+ // A terminator for DeviceParallelOp regions. Takes as arguments the tokens to be passed to the
55+ // continuation of the DeviceParallelOp. These values can then be used in a subsequent DeviceParallelOp
56+ // that is a sibling to the original DeviceParallelOp by referencing the returned tokens.
57+ // }];
58+ // let arguments = (ins Variadic<AnyType>:$operands);
59+ // let results = (outs ); // No outputs for terminators, the token is output by the parent DeviceParallelOp.
60+ // let assemblyFormat = "$operands type($operands) attr-dict";
61+ // }
5362
54- def MeshForOp : DistributedOp<"MeshFor", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, NoTerminator, SingleBlock]>{
55- let arguments = (ins SymbolRefAttr:$mesh); // TODO: verify it's a mesh
56- let regions = (region MaxSizedRegion<1>:$body); // TODO: body's block args are device type and mesh index
57- let results = (outs ); // TODO
58- // let hasVerifier = 1; // TODO: verify body's block args take mesh index
59- let assemblyFormat = "$mesh $body attr-dict";
60- }
63+ def DeviceParallelOp : DistributedOp<"device_parallel", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, NoTerminator]>{
64+ let description = [{
65+ An op for mapping computations to subdevices. Serves both for homogenous device meshes as well
66+ as explicitly enumerated device groups. In the case of device meshes, this op should contain
67+ a single region to be executed in parallel on each device. In the case of device groups, this
68+ op should contain one region per device and channel in the group.
6169
62- def GroupSplitOp : DistributedOp<"GroupSplit", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, NoTerminator]>{
70+ In either case, regions must take as argument one device index within the parent device followed
71+ by a number of token arguments. Tokens are matched by positionally between different branches,
72+ and all branches must have the same number and type of token arguments (though they may be unused).
73+ }];
74+
6375 let arguments = (ins
64- SymbolRefAttr:$device_group ,
65- ArrayAttr:$branch_assignments // Symbols mapping to each branch region
76+ SymbolRefAttr:$enclosing_device ,
77+ ArrayAttr:$branch_assignments // the device components for each region (device-specific branch)
6678 );
67- // TODO check that declarations only declares tokens.
6879 let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
69- let results = (outs ); // TODO
70- // let hasVerifier = 1; // TODO
71- let assemblyFormat = "$device_group custom<SplitBranches>($branch_assignments, $branches) attr-dict";
72- }
73-
74- def DefineTokenOp : DistributedOp<"DefineToken", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]>{
75- let arguments = (ins
76- SymbolRefAttr:$channel
77- );
78- let results = (outs TokenType:$token);
79- // let hasVerifier = 1; // TODO: verify writers and readers are connected to the channel
80- let assemblyFormat = "$channel attr-dict";
80+ // let results = (outs Variadic<TokenType>:$continuation_tokens);
81+ let results = (outs );
82+ let hasVerifier = 1; // TODO
83+ let assemblyFormat = "$enclosing_device `{` custom<DeviceBranches>($branch_assignments, $branches) `}` attr-dict";
84+ let extraClassDeclaration = [{
85+ Operation* getEnclosingDeviceOp();
86+ }];
8187}
8288
83- def SendOp : DistributedOp<"Send ", [DeclareOpInterfaceMethods<TokenWriterOpInterface>]>{
89+ def SendOp : DistributedOp<"send ", [DeclareOpInterfaceMethods<TokenWriterOpInterface>]>{
8490 let arguments = (ins
85- WriteTokenType :$token,
91+ TokenType :$token,
8692 // value to send
8793 AnyType:$value);
8894 let assemblyFormat = "$token type($value) $value attr-dict";
8995}
9096
91- def RecvOp : DistributedOp<"Recv ", [DeclareOpInterfaceMethods<TokenReaderOpInterface>]>{
97+ def RecvOp : DistributedOp<"recv ", [DeclareOpInterfaceMethods<TokenReaderOpInterface>]>{
9298 let arguments = (ins
9399 TokenType:$token);
94100 let results = (outs AnyType:$value);
95101 let assemblyFormat = "$token type($value) attr-dict";
96102}
97103
104+ def NoopOp : DistributedOp<"noop", []>{
105+ let description = [{
106+ A placeholder no-op.
107+ }];
108+ let assemblyFormat = "attr-dict";
109+ }
110+
98111#endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_OPS_TD
0 commit comments