Skip to content

Commit 1d8e1fa

Browse files
EganJwsmoses
authored andcommitted
Send, Recv ops and interfaces
1 parent a49eba5 commit 1d8e1fa

File tree

7 files changed

+67
-9
lines changed

7 files changed

+67
-9
lines changed

src/enzyme_ad/jax/BUILD

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,11 +569,11 @@ gentbl_cc_library(
569569
name = "DistributedInterfacesIncGen",
570570
tbl_outs = [
571571
(
572-
["--gen-interface-decls"],
572+
["--gen-op-interface-decls"],
573573
"Dialect/Distributed/DistributedInterfaces.h.inc",
574574
),
575575
(
576-
["--gen-interface-defs"],
576+
["--gen-op-interface-defs"],
577577
"Dialect/Distributed/DistributedInterfaces.cpp.inc",
578578
),
579579
],
@@ -744,6 +744,7 @@ cc_library(
744744
deps = [
745745
":CheckedRewrite",
746746
":DistributedDialectIncGen",
747+
":DistributedInterfacesIncGen",
747748
":DistributedOpsIncGen",
748749
":DistributedTypesIncGen",
749750
":EnzymeHLOPatternsIncGen",

src/enzyme_ad/jax/Dialect/Distributed/Dialect.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#define GET_TYPEDEF_CLASSES
99
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedTypes.cpp.inc"
1010

11+
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedInterfaces.cpp.inc"
12+
1113
// Initialize the dialect
1214
void mlir::enzyme::distributed::DistributedDialect::initialize() {
1315
addTypes<

src/enzyme_ad/jax/Dialect/Distributed/Dialect.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
#include "mlir/IR/SymbolTable.h"
1111
#include "mlir/IR/Types.h"
1212

13-
// Include the dialect
14-
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedDialect.h.inc"
15-
// Traits and interfaces
1613
#include "Traits.h"
17-
// Types
14+
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedDialect.h.inc"
15+
1816
#define GET_TYPEDEF_CLASSES
1917
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedTypes.h.inc"
20-
// Operations
18+
19+
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedInterfaces.h.inc"
20+
2121
#define GET_OP_CLASSES
2222
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedOps.h.inc"
2323

src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,29 @@ include "mlir/IR/OpBase.td"
66
def DeviceDefTrait : NativeOpTrait<"enzyme::distributed::DeviceDefTrait">;
77
def ChannelDefTrait : NativeOpTrait<"enzyme::distributed::ChannelDefTrait">;
88

9+
def TokenReaderOpInterface : OpInterface<"TokenReaderOpInterface"> {
10+
let cppNamespace = "::mlir::enzyme::distributed";
11+
let description = [{
12+
An interface to determine which ops can read from a channel and what type they expect.
13+
Ops may read from multiple channels.
14+
}];
15+
let methods = [
16+
InterfaceMethod<"Returns the SSA values of tokens read from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::ReadTokenType>>", "getReadTokens">,
17+
InterfaceMethod<"Returns the types of tokens read from this op, parallel to getReadTokens.", "::llvm::ArrayRef<::mlir::Type>", "getReadTokenTypes">
18+
];
19+
}
20+
21+
def TokenWriterOpInterface : OpInterface<"TokenWriterOpInterface"> {
22+
let cppNamespace = "::mlir::enzyme::distributed";
23+
let description = [{
24+
An interface to determine which ops can write to a channel and what type they provide.
25+
Ops may write to multiple channels.
26+
}];
27+
let methods = [
28+
InterfaceMethod<"Returns the SSA values of tokens written from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::WriteTokenType>>", "getWriteTokens">,
29+
InterfaceMethod<"Returns the types of tokens written from this op, parallel to getWriteTokens.", "::llvm::ArrayRef<::mlir::Type>", "getWriteTokenTypes">
30+
];
31+
}
32+
33+
934
#endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_INTERFACES

src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,20 @@ DefineTokenOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
156156
getChannelAttr());
157157
}
158158

159+
llvm::ArrayRef<mlir::TypedValue<WriteTokenType>> SendOp::getWriteTokens() {
160+
return llvm::SmallVector<mlir::TypedValue<WriteTokenType>, 1>{getToken()};
161+
}
162+
llvm::ArrayRef<mlir::Type> SendOp::getWriteTokenTypes() {
163+
return llvm::SmallVector<mlir::Type, 1>{getValue().getType()};
164+
}
165+
166+
llvm::ArrayRef<mlir::TypedValue<ReadTokenType>> RecvOp::getReadTokens() {
167+
return llvm::SmallVector<mlir::TypedValue<ReadTokenType>, 1>{getToken()};
168+
}
169+
llvm::ArrayRef<mlir::Type> RecvOp::getReadTokenTypes() {
170+
return llvm::SmallVector<mlir::Type, 1>{getValue().getType()};
171+
}
172+
159173
} // namespace mlir::enzyme::distributed
160174
#define GET_OP_CLASSES
161175
#include "src/enzyme_ad/jax/Dialect/Distributed/DistributedOps.cpp.inc"

src/enzyme_ad/jax/Dialect/Distributed/Ops.td

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,24 @@ def DefineTokenOp : DistributedOp<"DefineToken", [DeclareOpInterfaceMethods<Symb
7575
let arguments = (ins
7676
SymbolRefAttr:$channel
7777
);
78-
let results = (outs TokenType:$token_out);
78+
let results = (outs ReadTokenType:$read_token, WriteTokenType:$write_token);
7979
// let hasVerifier = 1; // TODO: verify writers and readers are connected to the channel
8080
let assemblyFormat = "$channel attr-dict";
8181
}
8282

83+
def SendOp : DistributedOp<"Send", [DeclareOpInterfaceMethods<TokenWriterOpInterface>]>{
84+
let arguments = (ins
85+
WriteTokenType:$token,
86+
// value to send
87+
AnyType:$value);
88+
let assemblyFormat = "$token type($value) $value attr-dict";
89+
}
90+
91+
def RecvOp : DistributedOp<"Recv", [DeclareOpInterfaceMethods<TokenReaderOpInterface>]>{
92+
let arguments = (ins
93+
ReadTokenType:$token);
94+
let results = (outs AnyType:$value);
95+
let assemblyFormat = "$token type($value) attr-dict";
96+
}
97+
8398
#endif // ENZYME_AD_JAX_DIALECT_DISTRIBUTED_OPS_TD

src/enzyme_ad/jax/Dialect/Distributed/Types.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
include "Dialect.td"
55

6-
def TokenType : DistributedType<"Token", "token">;
6+
def ReadTokenType : DistributedType<"ReadToken", "read_token">;
7+
def WriteTokenType : DistributedType<"WriteToken", "write_token">;
78

89
#endif // ENZYME_DISTRIBUTED_DIALECT_TYPES_H

0 commit comments

Comments
 (0)