@@ -1103,44 +1103,69 @@ class CIRGenFunction : public CIRGenTypeCache {
11031103 // ---
11041104
11051105 private:
1106- // `returnBlock`, `returnLoc`, and all the functions that deal with them
1107- // will change and become more complicated when `switch` statements are
1108- // upstreamed. `case` statements within the `switch` are in the same scope
1109- // but have their own regions. Therefore the LexicalScope will need to
1110- // keep track of multiple return blocks.
1111- mlir::Block *returnBlock = nullptr ;
1112- std::optional<mlir::Location> returnLoc;
1113-
1114- // See the comment on `getOrCreateRetBlock`.
1106+ // On switches we need one return block per region, since cases don't
1107+ // have their own scopes but are distinct regions nonetheless.
1108+
1109+ // TODO: This implementation should change once we have support for early
1110+ // exits in MLIR structured control flow (llvm-project#161575)
1111+ llvm::SmallVector<mlir::Block *> retBlocks;
1112+ llvm::DenseMap<mlir::Block *, mlir::Location> retLocs;
1113+ llvm::DenseMap<cir::CaseOp, unsigned > retBlockInCaseIndex;
1114+ std::optional<unsigned > normalRetBlockIndex;
1115+
1116+ // There's usually only one ret block per scope, but this needs to be
1117+ // get or create because of potential unreachable return statements, note
1118+ // that for those, all source location maps to the first one found.
11151119 mlir::Block *createRetBlock (CIRGenFunction &cgf, mlir::Location loc) {
1116- assert (returnBlock == nullptr && " only one return block per scope" );
1117- // Create the cleanup block but don't hook it up just yet.
1120+ assert ((isa_and_nonnull<cir::CaseOp>(
1121+ cgf.builder .getBlock ()->getParentOp ()) ||
1122+ retBlocks.size () == 0 ) &&
1123+ " only switches can hold more than one ret block" );
1124+
1125+ // Create the return block but don't hook it up just yet.
11181126 mlir::OpBuilder::InsertionGuard guard (cgf.builder );
1119- returnBlock =
1120- cgf. builder . createBlock (cgf. builder . getBlock ()-> getParent () );
1121- updateRetLoc (returnBlock , loc);
1122- return returnBlock ;
1127+ auto *b = cgf. builder . createBlock (cgf. builder . getBlock ()-> getParent ());
1128+ retBlocks. push_back (b );
1129+ updateRetLoc (b , loc);
1130+ return b ;
11231131 }
11241132
11251133 cir::ReturnOp emitReturn (mlir::Location loc);
11261134 void emitImplicitReturn ();
11271135
11281136 public:
1129- mlir::Block *getRetBlock () { return returnBlock; }
1130- mlir::Location getRetLoc (mlir::Block *b) { return *returnLoc; }
1131- void updateRetLoc (mlir::Block *b, mlir::Location loc) { returnLoc = loc; }
1132-
1133- // Create the return block for this scope, or return the existing one.
1134- // This get-or-create logic is necessary to handle multiple return
1135- // statements within the same scope, which can happen if some of them are
1136- // dead code or if there is a `goto` into the middle of the scope.
1137+ llvm::ArrayRef<mlir::Block *> getRetBlocks () { return retBlocks; }
1138+ mlir::Location getRetLoc (mlir::Block *b) { return retLocs.at (b); }
1139+ void updateRetLoc (mlir::Block *b, mlir::Location loc) {
1140+ retLocs.insert_or_assign (b, loc);
1141+ }
1142+
11371143 mlir::Block *getOrCreateRetBlock (CIRGenFunction &cgf, mlir::Location loc) {
1138- if (returnBlock == nullptr ) {
1139- returnBlock = createRetBlock (cgf, loc);
1140- return returnBlock;
1144+ // Check if we're inside a case region
1145+ if (auto caseOp = mlir::dyn_cast_if_present<cir::CaseOp>(
1146+ cgf.builder .getBlock ()->getParentOp ())) {
1147+ auto iter = retBlockInCaseIndex.find (caseOp);
1148+ if (iter != retBlockInCaseIndex.end ()) {
1149+ // Reuse existing return block
1150+ mlir::Block *ret = retBlocks[iter->second ];
1151+ updateRetLoc (ret, loc);
1152+ return ret;
1153+ }
1154+ // Create new return block
1155+ mlir::Block *ret = createRetBlock (cgf, loc);
1156+ retBlockInCaseIndex[caseOp] = retBlocks.size () - 1 ;
1157+ return ret;
11411158 }
1142- updateRetLoc (returnBlock, loc);
1143- return returnBlock;
1159+
1160+ if (normalRetBlockIndex) {
1161+ mlir::Block *ret = retBlocks[*normalRetBlockIndex];
1162+ updateRetLoc (ret, loc);
1163+ return ret;
1164+ }
1165+
1166+ mlir::Block *ret = createRetBlock (cgf, loc);
1167+ normalRetBlockIndex = retBlocks.size () - 1 ;
1168+ return ret;
11441169 }
11451170
11461171 mlir::Block *getEntryBlock () { return entryBlock; }
0 commit comments