@@ -1040,44 +1040,69 @@ class CIRGenFunction : public CIRGenTypeCache {
10401040 // ---
10411041
10421042 private:
1043- // `returnBlock`, `returnLoc`, and all the functions that deal with them
1044- // will change and become more complicated when `switch` statements are
1045- // upstreamed. `case` statements within the `switch` are in the same scope
1046- // but have their own regions. Therefore the LexicalScope will need to
1047- // keep track of multiple return blocks.
1048- mlir::Block *returnBlock = nullptr ;
1049- std::optional<mlir::Location> returnLoc;
1050-
1051- // See the comment on `getOrCreateRetBlock`.
1043+ // On switches we need one return block per region, since cases don't
1044+ // have their own scopes but are distinct regions nonetheless.
1045+
1046+ // TODO: This implementation should change once we have support for early
1047+ // exits in MLIR structured control flow (llvm-project#161575)
1048+ llvm::SmallVector<mlir::Block *> retBlocks;
1049+ llvm::DenseMap<mlir::Block *, mlir::Location> retLocs;
1050+ llvm::DenseMap<cir::CaseOp, unsigned > retBlockInCaseIndex;
1051+ std::optional<unsigned > normalRetBlockIndex;
1052+
1053+ // There's usually only one ret block per scope, but this needs to be
1054+ // get or create because of potential unreachable return statements, note
1055+ // that for those, all source location maps to the first one found.
10521056 mlir::Block *createRetBlock (CIRGenFunction &cgf, mlir::Location loc) {
1053- assert (returnBlock == nullptr && " only one return block per scope" );
1054- // Create the cleanup block but don't hook it up just yet.
1057+ assert ((isa_and_nonnull<cir::CaseOp>(
1058+ cgf.builder .getBlock ()->getParentOp ()) ||
1059+ retBlocks.size () == 0 ) &&
1060+ " only switches can hold more than one ret block" );
1061+
1062+ // Create the return block but don't hook it up just yet.
10551063 mlir::OpBuilder::InsertionGuard guard (cgf.builder );
1056- returnBlock =
1057- cgf. builder . createBlock (cgf. builder . getBlock ()-> getParent () );
1058- updateRetLoc (returnBlock , loc);
1059- return returnBlock ;
1064+ auto *b = cgf. builder . createBlock (cgf. builder . getBlock ()-> getParent ());
1065+ retBlocks. push_back (b );
1066+ updateRetLoc (b , loc);
1067+ return b ;
10601068 }
10611069
10621070 cir::ReturnOp emitReturn (mlir::Location loc);
10631071 void emitImplicitReturn ();
10641072
10651073 public:
1066- mlir::Block *getRetBlock () { return returnBlock; }
1067- mlir::Location getRetLoc (mlir::Block *b) { return *returnLoc; }
1068- void updateRetLoc (mlir::Block *b, mlir::Location loc) { returnLoc = loc; }
1069-
1070- // Create the return block for this scope, or return the existing one.
1071- // This get-or-create logic is necessary to handle multiple return
1072- // statements within the same scope, which can happen if some of them are
1073- // dead code or if there is a `goto` into the middle of the scope.
1074+ llvm::ArrayRef<mlir::Block *> getRetBlocks () { return retBlocks; }
1075+ mlir::Location getRetLoc (mlir::Block *b) { return retLocs.at (b); }
1076+ void updateRetLoc (mlir::Block *b, mlir::Location loc) {
1077+ retLocs.insert_or_assign (b, loc);
1078+ }
1079+
10741080 mlir::Block *getOrCreateRetBlock (CIRGenFunction &cgf, mlir::Location loc) {
1075- if (returnBlock == nullptr ) {
1076- returnBlock = createRetBlock (cgf, loc);
1077- return returnBlock;
1081+ // Check if we're inside a case region
1082+ if (auto caseOp = mlir::dyn_cast_if_present<cir::CaseOp>(
1083+ cgf.builder .getBlock ()->getParentOp ())) {
1084+ auto iter = retBlockInCaseIndex.find (caseOp);
1085+ if (iter != retBlockInCaseIndex.end ()) {
1086+ // Reuse existing return block
1087+ mlir::Block *ret = retBlocks[iter->second ];
1088+ updateRetLoc (ret, loc);
1089+ return ret;
1090+ }
1091+ // Create new return block
1092+ mlir::Block *ret = createRetBlock (cgf, loc);
1093+ retBlockInCaseIndex[caseOp] = retBlocks.size () - 1 ;
1094+ return ret;
10781095 }
1079- updateRetLoc (returnBlock, loc);
1080- return returnBlock;
1096+
1097+ if (normalRetBlockIndex) {
1098+ mlir::Block *ret = retBlocks[*normalRetBlockIndex];
1099+ updateRetLoc (ret, loc);
1100+ return ret;
1101+ }
1102+
1103+ mlir::Block *ret = createRetBlock (cgf, loc);
1104+ normalRetBlockIndex = retBlocks.size () - 1 ;
1105+ return ret;
10811106 }
10821107
10831108 mlir::Block *getEntryBlock () { return entryBlock; }
0 commit comments