@@ -111,20 +111,44 @@ GroupSplitOp::verifySymbolUses(::mlir::SymbolTableCollection &symbol_table) {
111111 getDeviceGroupAttr ());
112112}
113113
114- LogicalResult
115- SplitBranchOp::verifySymbolUses (::mlir::SymbolTableCollection &symbol_table) {
116- // Split branches have programs for individual devices or channels
117- Operation *dev_or_chan =
118- symbol_table.lookupNearestSymbolFrom (*this , getDeviceOrChannelAttr ());
119- if (!dev_or_chan || !(dev_or_chan->hasTrait <DeviceDefTrait>() ||
120- dev_or_chan->hasTrait <ChannelDefTrait>())) {
121- mlir::emitError (getLoc ())
122- << " branches must reference a valid device or channel" ;
123- return mlir::failure ();
114+ // Printer/parser for GroupsplitOp branches
115+ mlir::ParseResult parseSplitBranches (
116+ OpAsmParser &parser, mlir::ArrayAttr &branchAssignments,
117+ llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2 > &branchesRegions) {
118+ // Expect 0 or more `branch` $symbol_name $symbol_region
119+ // While next token is `branch`:
120+ llvm::SmallVector<mlir::Attribute, 2 > assignment_symbols;
121+ while (parser.parseOptionalKeyword (" branch" ).succeeded ()) {
122+ // Parse symbol name
123+ mlir::SymbolRefAttr sym;
124+ auto sym_parse_failed = parser.parseAttribute <mlir::SymbolRefAttr>(sym);
125+ if (sym_parse_failed)
126+ return mlir::failure ();
127+ assignment_symbols.push_back (sym);
128+
129+ // Put placeholder region in list and parse into it
130+ branchesRegions.push_back (std::make_unique<mlir::Region>());
131+ auto parse_region_failed = parser.parseRegion (*branchesRegions.back ());
132+ if (parse_region_failed)
133+ return mlir::failure ();
124134 }
135+
136+ branchAssignments = mlir::ArrayAttr::get (parser.getBuilder ().getContext (),
137+ assignment_symbols);
125138 return mlir::success ();
126139}
127140
141+ void printSplitBranches (OpAsmPrinter &printer, const GroupSplitOp &op,
142+ const mlir::ArrayAttr branchAssignments,
143+ const llvm::MutableArrayRef<mlir::Region> branches) {
144+ // Print each branch as `branch` $symbol_name $symbol_region
145+ for (size_t i = 0 ; i < branches.size (); i++) {
146+ printer << " branch " ;
147+ printer.printAttribute (branchAssignments[i]);
148+ printer.printRegion (branches[i]);
149+ }
150+ }
151+
128152LogicalResult
129153DefineTokenOp::verifySymbolUses (::mlir::SymbolTableCollection &symbol_table) {
130154 // Tokens need to indicate which channel they communicate over
0 commit comments