Skip to content

Commit ff303ea

Browse files
EganJwsmoses
authored andcommitted
Working commit: single token type
1 parent 1d8e1fa commit ff303ea

File tree

4 files changed

+9
-10
lines changed

4 files changed

+9
-10
lines changed

src/enzyme_ad/jax/Dialect/Distributed/Interfaces.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def TokenReaderOpInterface : OpInterface<"TokenReaderOpInterface"> {
1313
Ops may read from multiple channels.
1414
}];
1515
let methods = [
16-
InterfaceMethod<"Returns the SSA values of tokens read from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::ReadTokenType>>", "getReadTokens">,
16+
InterfaceMethod<"Returns the SSA values of tokens read from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::TokenType>>", "getReadTokens">,
1717
InterfaceMethod<"Returns the types of tokens read from this op, parallel to getReadTokens.", "::llvm::ArrayRef<::mlir::Type>", "getReadTokenTypes">
1818
];
1919
}
@@ -25,7 +25,7 @@ def TokenWriterOpInterface : OpInterface<"TokenWriterOpInterface"> {
2525
Ops may write to multiple channels.
2626
}];
2727
let methods = [
28-
InterfaceMethod<"Returns the SSA values of tokens written from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::WriteTokenType>>", "getWriteTokens">,
28+
InterfaceMethod<"Returns the SSA values of tokens written from this op.", "::llvm::ArrayRef<::mlir::TypedValue<::mlir::enzyme::distributed::TokenType>>", "getWriteTokens">,
2929
InterfaceMethod<"Returns the types of tokens written from this op, parallel to getWriteTokens.", "::llvm::ArrayRef<::mlir::Type>", "getWriteTokenTypes">
3030
];
3131
}

src/enzyme_ad/jax/Dialect/Distributed/Ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,15 +156,15 @@ DefineTokenOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
156156
getChannelAttr());
157157
}
158158

159-
llvm::ArrayRef<mlir::TypedValue<WriteTokenType>> SendOp::getWriteTokens() {
160-
return llvm::SmallVector<mlir::TypedValue<WriteTokenType>, 1>{getToken()};
159+
llvm::ArrayRef<mlir::TypedValue<TokenType>> SendOp::getWriteTokens() {
160+
return llvm::SmallVector<mlir::TypedValue<TokenType>, 1>{getToken()};
161161
}
162162
llvm::ArrayRef<mlir::Type> SendOp::getWriteTokenTypes() {
163163
return llvm::SmallVector<mlir::Type, 1>{getValue().getType()};
164164
}
165165

166-
llvm::ArrayRef<mlir::TypedValue<ReadTokenType>> RecvOp::getReadTokens() {
167-
return llvm::SmallVector<mlir::TypedValue<ReadTokenType>, 1>{getToken()};
166+
llvm::ArrayRef<mlir::TypedValue<TokenType>> RecvOp::getReadTokens() {
167+
return llvm::SmallVector<mlir::TypedValue<TokenType>, 1>{getToken()};
168168
}
169169
llvm::ArrayRef<mlir::Type> RecvOp::getReadTokenTypes() {
170170
return llvm::SmallVector<mlir::Type, 1>{getValue().getType()};

src/enzyme_ad/jax/Dialect/Distributed/Ops.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def DefineTokenOp : DistributedOp<"DefineToken", [DeclareOpInterfaceMethods<Symb
7575
let arguments = (ins
7676
SymbolRefAttr:$channel
7777
);
78-
let results = (outs ReadTokenType:$read_token, WriteTokenType:$write_token);
78+
let results = (outs TokenType:$token);
7979
// let hasVerifier = 1; // TODO: verify writers and readers are connected to the channel
8080
let assemblyFormat = "$channel attr-dict";
8181
}
@@ -90,7 +90,7 @@ def SendOp : DistributedOp<"Send", [DeclareOpInterfaceMethods<TokenWriterOpInter
9090

9191
def RecvOp : DistributedOp<"Recv", [DeclareOpInterfaceMethods<TokenReaderOpInterface>]>{
9292
let arguments = (ins
93-
ReadTokenType:$token);
93+
TokenType:$token);
9494
let results = (outs AnyType:$value);
9595
let assemblyFormat = "$token type($value) attr-dict";
9696
}

src/enzyme_ad/jax/Dialect/Distributed/Types.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
include "Dialect.td"
55

6-
def ReadTokenType : DistributedType<"ReadToken", "read_token">;
7-
def WriteTokenType : DistributedType<"WriteToken", "write_token">;
6+
def TokenType : DistributedType<"Token", "token">;
87

98
#endif // ENZYME_DISTRIBUTED_DIALECT_TYPES_H

0 commit comments

Comments
 (0)