@@ -94,106 +94,117 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
9494 bool abortOnFailedAssert = true ;
9595};
9696
97- // / The cf->LLVM lowerings for branching ops require that the blocks they jump
98- // / to first have updated types which should be handled by a pattern operating
99- // / on the parent op.
100- static LogicalResult verifyMatchingValues (ConversionPatternRewriter &rewriter,
101- ValueRange operands,
102- ValueRange blockArgs, Location loc,
103- llvm::StringRef messagePrefix) {
104- for (const auto &idxAndTypes :
105- llvm::enumerate (llvm::zip (blockArgs, operands))) {
106- int64_t i = idxAndTypes.index ();
107- Value argValue =
108- rewriter.getRemappedValue (std::get<0 >(idxAndTypes.value ()));
109- Type operandType = std::get<1 >(idxAndTypes.value ()).getType ();
110- // In the case of an invalid jump, the block argument will have been
111- // remapped to an UnrealizedConversionCast. In the case of a valid jump,
112- // there might still be a no-op conversion cast with both types being equal.
113- // Consider both of these details to see if the jump would be invalid.
114- if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
115- argValue.getDefiningOp ())) {
116- if (op.getOperandTypes ().front () != operandType) {
117- return rewriter.notifyMatchFailure (loc, [&](Diagnostic &diag) {
118- diag << messagePrefix;
119- diag << " mismatched types from operand # " << i << " " ;
120- diag << operandType;
121- diag << " not compatible with destination block argument type " ;
122- diag << op.getOperandTypes ().front ();
123- diag << " which should be converted with the parent op." ;
124- });
125- }
126- }
127- }
128- return success ();
97+ // / Helper function for converting branch ops. This function converts the
98+ // / signature of the given block. If the new block signature is different from
99+ // / `expectedTypes`, returns "failure".
100+ static FailureOr<Block *> getConvertedBlock (ConversionPatternRewriter &rewriter,
101+ const TypeConverter *converter,
102+ Operation *branchOp, Block *block,
103+ TypeRange expectedTypes) {
104+ assert (converter && " expected non-null type converter" );
105+ assert (!block->isEntryBlock () && " entry blocks have no predecessors" );
106+
107+ // There is nothing to do if the types already match.
108+ if (block->getArgumentTypes () == expectedTypes)
109+ return block;
110+
111+ // Compute the new block argument types and convert the block.
112+ std::optional<TypeConverter::SignatureConversion> conversion =
113+ converter->convertBlockSignature (block);
114+ if (!conversion)
115+ return rewriter.notifyMatchFailure (branchOp,
116+ " could not compute block signature" );
117+ if (expectedTypes != conversion->getConvertedTypes ())
118+ return rewriter.notifyMatchFailure (
119+ branchOp,
120+ " mismatch between adaptor operand types and computed block signature" );
121+ return rewriter.applySignatureConversion (block, *conversion, converter);
129122}
130123
131- // / Ensure that all block types were updated and then create an LLVM::BrOp
124+ // / Convert the destination block signature (if necessary) and lower the branch
125+ // / op to llvm.br.
132126struct BranchOpLowering : public ConvertOpToLLVMPattern <cf::BranchOp> {
133127 using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
134128
135129 LogicalResult
136130 matchAndRewrite (cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
137131 ConversionPatternRewriter &rewriter) const override {
138- if ( failed ( verifyMatchingValues (rewriter, adaptor. getDestOperands (),
139- op.getSuccessor ()-> getArguments (),
140- op. getLoc (),
141- /* messagePrefix= */ " " ) ))
132+ FailureOr<Block *> convertedBlock =
133+ getConvertedBlock (rewriter, getTypeConverter (), op, op.getSuccessor (),
134+ TypeRange (adaptor. getOperands ()));
135+ if ( failed (convertedBlock ))
142136 return failure ();
143-
144- rewriter.replaceOpWithNewOp <LLVM::BrOp>(
145- op, adaptor.getOperands (), op->getSuccessors (), op->getAttrs ());
137+ Operation *newOp = rewriter.replaceOpWithNewOp <LLVM::BrOp>(
138+ op, adaptor.getOperands (), *convertedBlock);
139+ // TODO: We should not just forward all attributes like that. But there are
140+ // existing Flang tests that depend on this behavior.
141+ newOp->setAttrs (op->getAttrDictionary ());
146142 return success ();
147143 }
148144};
149145
150- // / Ensure that all block types were updated and then create an LLVM::CondBrOp
146+ // / Convert the destination block signatures (if necessary) and lower the
147+ // / branch op to llvm.cond_br.
151148struct CondBranchOpLowering : public ConvertOpToLLVMPattern <cf::CondBranchOp> {
152149 using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
153150
154151 LogicalResult
155152 matchAndRewrite (cf::CondBranchOp op,
156153 typename cf::CondBranchOp::Adaptor adaptor,
157154 ConversionPatternRewriter &rewriter) const override {
158- if (failed (verifyMatchingValues (rewriter, adaptor.getFalseDestOperands (),
159- op.getFalseDest ()->getArguments (),
160- op.getLoc (), " in false case branch " )))
155+ FailureOr<Block *> convertedTrueBlock =
156+ getConvertedBlock (rewriter, getTypeConverter (), op, op.getTrueDest (),
157+ TypeRange (adaptor.getTrueDestOperands ()));
158+ if (failed (convertedTrueBlock))
161159 return failure ();
162- if (failed (verifyMatchingValues (rewriter, adaptor.getTrueDestOperands (),
163- op.getTrueDest ()->getArguments (),
164- op.getLoc (), " in true case branch " )))
160+ FailureOr<Block *> convertedFalseBlock =
161+ getConvertedBlock (rewriter, getTypeConverter (), op, op.getFalseDest (),
162+ TypeRange (adaptor.getFalseDestOperands ()));
163+ if (failed (convertedFalseBlock))
165164 return failure ();
166-
167- rewriter.replaceOpWithNewOp <LLVM::CondBrOp>(
168- op, adaptor.getOperands (), op->getSuccessors (), op->getAttrs ());
165+ Operation *newOp = rewriter.replaceOpWithNewOp <LLVM::CondBrOp>(
166+ op, adaptor.getCondition (), *convertedTrueBlock,
167+ adaptor.getTrueDestOperands (), *convertedFalseBlock,
168+ adaptor.getFalseDestOperands ());
169+ // TODO: We should not just forward all attributes like that. But there are
170+ // existing Flang tests that depend on this behavior.
171+ newOp->setAttrs (op->getAttrDictionary ());
169172 return success ();
170173 }
171174};
172175
173- // / Ensure that all block types were updated and then create an LLVM::SwitchOp
176+ // / Convert the destination block signatures (if necessary) and lower the
177+ // / switch op to llvm.switch.
174178struct SwitchOpLowering : public ConvertOpToLLVMPattern <cf::SwitchOp> {
175179 using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
176180
177181 LogicalResult
178182 matchAndRewrite (cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
179183 ConversionPatternRewriter &rewriter) const override {
180- if (failed (verifyMatchingValues (rewriter, adaptor.getDefaultOperands (),
181- op.getDefaultDestination ()->getArguments (),
182- op.getLoc (), " in switch default case " )))
184+ // Get or convert default block.
185+ FailureOr<Block *> convertedDefaultBlock = getConvertedBlock (
186+ rewriter, getTypeConverter (), op, op.getDefaultDestination (),
187+ TypeRange (adaptor.getDefaultOperands ()));
188+ if (failed (convertedDefaultBlock))
183189 return failure ();
184190
185- for (const auto &i : llvm::enumerate (
186- llvm::zip (adaptor.getCaseOperands (), op.getCaseDestinations ()))) {
187- if (failed (verifyMatchingValues (
188- rewriter, std::get<0 >(i.value ()),
189- std::get<1 >(i.value ())->getArguments (), op.getLoc (),
190- " in switch case " + std::to_string (i.index ()) + " " ))) {
191+ // Get or convert all case blocks.
192+ SmallVector<Block *> caseDestinations;
193+ SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands ();
194+ for (auto it : llvm::enumerate (op.getCaseDestinations ())) {
195+ Block *b = it.value ();
196+ FailureOr<Block *> convertedBlock =
197+ getConvertedBlock (rewriter, getTypeConverter (), op, b,
198+ TypeRange (caseOperands[it.index ()]));
199+ if (failed (convertedBlock))
191200 return failure ();
192- }
201+ caseDestinations. push_back (*convertedBlock);
193202 }
194203
195204 rewriter.replaceOpWithNewOp <LLVM::SwitchOp>(
196- op, adaptor.getOperands (), op->getSuccessors (), op->getAttrs ());
205+ op, adaptor.getFlag (), *convertedDefaultBlock,
206+ adaptor.getDefaultOperands (), adaptor.getCaseValuesAttr (),
207+ caseDestinations, caseOperands);
197208 return success ();
198209 }
199210};
@@ -230,14 +241,22 @@ struct ConvertControlFlowToLLVM
230241
231242 // / Run the dialect converter on the module.
232243 void runOnOperation () override {
233- LLVMConversionTarget target (getContext ());
234- RewritePatternSet patterns (&getContext ());
235-
236- LowerToLLVMOptions options (&getContext ());
244+ MLIRContext *ctx = &getContext ();
245+ LLVMConversionTarget target (*ctx);
246+ // This pass lowers only CF dialect ops, but it also modifies block
247+ // signatures inside other ops. These ops should be treated as legal. They
248+ // are lowered by other passes.
249+ target.markUnknownOpDynamicallyLegal ([&](Operation *op) {
250+ return op->getDialect () !=
251+ ctx->getLoadedDialect <cf::ControlFlowDialect>();
252+ });
253+
254+ LowerToLLVMOptions options (ctx);
237255 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout )
238256 options.overrideIndexBitwidth (indexBitwidth);
239257
240- LLVMTypeConverter converter (&getContext (), options);
258+ LLVMTypeConverter converter (ctx, options);
259+ RewritePatternSet patterns (ctx);
241260 mlir::cf::populateControlFlowToLLVMConversionPatterns (converter, patterns);
242261
243262 if (failed (applyPartialConversion (getOperation (), target,
0 commit comments