11#include " Utils.h"
2+ #include " Dialect.h"
23namespace mlir ::enzyme::distributed {
4+
35Region *getEnclosingDeviceParallelBranch (DeviceParallelOp parent,
46 Operation *op) {
57 auto region = op->getParentRegion ();
68 while (region->getParentOp () != parent) {
79 auto region_parent =
8- region->getParentOp (); // All regoins have parent ops...
10+ region->getParentOp (); // All regions have parent ops...
911 if (!region_parent->getParentRegion ()) // But not all ops have parent
1012 // regions (e.g. top level ops)
1113 return nullptr ;
@@ -38,25 +40,60 @@ mlir::Operation *getExecutingDevice(mlir::Operation *op) {
3840 return SymbolTable::lookupNearestSymbolFrom (parent, device_sym);
3941}
4042
41- llvm::SmallVector<mlir::BlockArgument>
42- getCorrespondingTokens (mlir::BlockArgument token) {
43- unsigned idx = token.getArgNumber ();
44- auto op = token.getOwner ()->getParentOp ();
43+ llvm::SmallVector<Token> getCorrespondingTokens (Token token) {
44+ unsigned idx = token.asBlockArg ().getArgNumber ();
45+ auto op = token.asBlockArg ().getOwner ()->getParentOp ();
4546 DeviceParallelOp parent = llvm::cast<DeviceParallelOp>(op);
46- llvm::SmallVector<mlir::BlockArgument > results;
47+ llvm::SmallVector<Token > results;
4748 results.reserve (parent.getNumRegions ());
4849 for (auto region : parent.getRegions ()) {
49- results.push_back (region->getArgument (idx));
50+ results.push_back (Token ( region->getArgument (idx) ));
5051 }
5152 return results;
5253}
5354
54- llvm::SmallVector<mlir::Operation *> getTokenUsers (mlir::BlockArgument token) {
55- llvm::SmallVector<mlir::Operation *, 4 > results;
56- for (auto user : token.getUsers ()) {
57- results.push_back (user);
55+ llvm::SmallVector<mlir::Operation *> getTokenUsers (Token token) {
56+ auto all_tokens = getCorrespondingTokens (token);
57+ llvm::SmallVector<mlir::Operation *> results;
58+ // Concatenate all users of all corresponding tokens.
59+ // Due to scoping rules and since each token is a block arg to a
60+ // different region, there should be no duplicates here.
61+ for (auto t : all_tokens) {
62+ for (auto user : t.asBlockArg ().getUsers ()) {
63+ results.push_back (user);
64+ }
5865 }
5966 return results;
6067}
6168
69+ bool isSoleSender (TokenWriterOpInterface writer) {
70+ auto tokens = writer.getWriteTokens ();
71+ // Check for conflicts on all tokens
72+ for (auto token : tokens) {
73+ auto users = getTokenUsers (token);
74+ if (!isSoleSender (writer, token, users)) {
75+ return false ;
76+ }
77+ }
78+ return true ;
79+ }
80+
81+ bool isSoleSender (TokenWriterOpInterface writer, Token token,
82+ llvm::ArrayRef<Operation *> others) {
83+ for (auto user : others) {
84+ TypedValue<TokenType> as_val = token.asTypedValue ();
85+ TokenWriterOpInterface other = dyn_cast<TokenWriterOpInterface>(user);
86+ if (other && other != writer) {
87+ // Found another writer using the same token. Check if it uses
88+ // the token to write, or only for something else:
89+ auto other_write_tokens = other.getWriteTokens ();
90+ for (auto t : other_write_tokens) {
91+ if (t == as_val) {
92+ return false ; // Found another op writing to the same token
93+ }
94+ }
95+ }
96+ }
97+ return true ;
98+ }
6299} // namespace mlir::enzyme::distributed
0 commit comments