@@ -94,106 +94,111 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
9494 bool abortOnFailedAssert = true ;
9595};
9696
97- // / The cf->LLVM lowerings for branching ops require that the blocks they jump
98- // / to first have updated types which should be handled by a pattern operating
99- // / on the parent op.
100- static LogicalResult verifyMatchingValues (ConversionPatternRewriter &rewriter,
101- ValueRange operands,
102- ValueRange blockArgs, Location loc,
103- llvm::StringRef messagePrefix) {
104- for (const auto &idxAndTypes :
105- llvm::enumerate (llvm::zip (blockArgs, operands))) {
106- int64_t i = idxAndTypes.index ();
107- Value argValue =
108- rewriter.getRemappedValue (std::get<0 >(idxAndTypes.value ()));
109- Type operandType = std::get<1 >(idxAndTypes.value ()).getType ();
110- // In the case of an invalid jump, the block argument will have been
111- // remapped to an UnrealizedConversionCast. In the case of a valid jump,
112- // there might still be a no-op conversion cast with both types being equal.
113- // Consider both of these details to see if the jump would be invalid.
114- if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
115- argValue.getDefiningOp ())) {
116- if (op.getOperandTypes ().front () != operandType) {
117- return rewriter.notifyMatchFailure (loc, [&](Diagnostic &diag) {
118- diag << messagePrefix;
119- diag << " mismatched types from operand # " << i << " " ;
120- diag << operandType;
121- diag << " not compatible with destination block argument type " ;
122- diag << op.getOperandTypes ().front ();
123- diag << " which should be converted with the parent op." ;
124- });
125- }
126- }
127- }
128- return success ();
97+ // / Helper function for converting branch ops. This function converts the
98+ // / signature of the given block. If the new block signature is different from
99+ // / `expectedTypes`, returns "failure".
100+ static FailureOr<Block *> getConvertedBlock (ConversionPatternRewriter &rewriter,
101+ const TypeConverter *converter,
102+ Operation *branchOp, Block *block,
103+ TypeRange expectedTypes) {
104+ assert (converter && " expected non-null type converter" );
105+ assert (!block->isEntryBlock () && " entry blocks have no predecessors" );
106+
107+ // There is nothing to do if the types already match.
108+ if (block->getArgumentTypes () == expectedTypes)
109+ return block;
110+
111+ // Compute the new block argument types and convert the block.
112+ std::optional<TypeConverter::SignatureConversion> conversion =
113+ converter->convertBlockSignature (block);
114+ if (!conversion)
115+ return rewriter.notifyMatchFailure (branchOp,
116+ " could not compute block signature" );
117+ if (expectedTypes != conversion->getConvertedTypes ())
118+ return rewriter.notifyMatchFailure (
119+ branchOp,
120+ " mismatch between adaptor operand types and computed block signature" );
121+ return rewriter.applySignatureConversion (block, *conversion, converter);
129122}
130123
131- // / Ensure that all block types were updated and then create an LLVM::BrOp
124+ // / Convert the destination block signature (if necessary) and lower the branch
125+ // / op to llvm.br.
132126struct BranchOpLowering : public ConvertOpToLLVMPattern <cf::BranchOp> {
133127 using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
134128
135129 LogicalResult
136130 matchAndRewrite (cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
137131 ConversionPatternRewriter &rewriter) const override {
138- if ( failed ( verifyMatchingValues (rewriter, adaptor. getDestOperands (),
139- op.getSuccessor ()-> getArguments (),
140- op. getLoc (),
141- /* messagePrefix= */ " " ) ))
132+ FailureOr<Block *> convertedBlock =
133+ getConvertedBlock (rewriter, getTypeConverter (), op, op.getSuccessor (),
134+ TypeRange (adaptor. getOperands ()));
135+ if ( failed (convertedBlock ))
142136 return failure ();
143-
144- rewriter.replaceOpWithNewOp <LLVM::BrOp>(
145- op, adaptor.getOperands (), op->getSuccessors (), op->getAttrs ());
137+ rewriter.replaceOpWithNewOp <LLVM::BrOp>(op, adaptor.getOperands (),
138+ *convertedBlock);
146139 return success ();
147140 }
148141};
149142
150- // / Ensure that all block types were updated and then create an LLVM::CondBrOp
143+ // / Convert the destination block signatures (if necessary) and lower the
144+ // / branch op to llvm.cond_br.
151145struct CondBranchOpLowering : public ConvertOpToLLVMPattern <cf::CondBranchOp> {
152146 using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
153147
154148 LogicalResult
155149 matchAndRewrite (cf::CondBranchOp op,
156150 typename cf::CondBranchOp::Adaptor adaptor,
157151 ConversionPatternRewriter &rewriter) const override {
158- if (failed (verifyMatchingValues (rewriter, adaptor.getFalseDestOperands (),
159- op.getFalseDest ()->getArguments (),
160- op.getLoc (), " in false case branch " )))
152+ FailureOr<Block *> convertedTrueBlock =
153+ getConvertedBlock (rewriter, getTypeConverter (), op, op.getTrueDest (),
154+ TypeRange (adaptor.getTrueDestOperands ()));
155+ if (failed (convertedTrueBlock))
161156 return failure ();
162- if (failed (verifyMatchingValues (rewriter, adaptor.getTrueDestOperands (),
163- op.getTrueDest ()->getArguments (),
164- op.getLoc (), " in true case branch " )))
157+ FailureOr<Block *> convertedFalseBlock =
158+ getConvertedBlock (rewriter, getTypeConverter (), op, op.getFalseDest (),
159+ TypeRange (adaptor.getFalseDestOperands ()));
160+ if (failed (convertedFalseBlock))
165161 return failure ();
166-
167162 rewriter.replaceOpWithNewOp <LLVM::CondBrOp>(
168- op, adaptor.getOperands (), op->getSuccessors (), op->getAttrs ());
163+ op, adaptor.getCondition (), *convertedTrueBlock,
164+ adaptor.getTrueDestOperands (), *convertedFalseBlock,
165+ adaptor.getFalseDestOperands ());
169166 return success ();
170167 }
171168};
172169
173- // / Ensure that all block types were updated and then create an LLVM::SwitchOp
170+ // / Convert the destination block signatures (if necessary) and lower the
171+ // / switch op to llvm.switch.
174172struct SwitchOpLowering : public ConvertOpToLLVMPattern <cf::SwitchOp> {
175173 using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
176174
177175 LogicalResult
178176 matchAndRewrite (cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
179177 ConversionPatternRewriter &rewriter) const override {
180- if (failed (verifyMatchingValues (rewriter, adaptor.getDefaultOperands (),
181- op.getDefaultDestination ()->getArguments (),
182- op.getLoc (), " in switch default case " )))
178+ // Get or convert default block.
179+ FailureOr<Block *> convertedDefaultBlock = getConvertedBlock (
180+ rewriter, getTypeConverter (), op, op.getDefaultDestination (),
181+ TypeRange (adaptor.getDefaultOperands ()));
182+ if (failed (convertedDefaultBlock))
183183 return failure ();
184184
185- for (const auto &i : llvm::enumerate (
186- llvm::zip (adaptor.getCaseOperands (), op.getCaseDestinations ()))) {
187- if (failed (verifyMatchingValues (
188- rewriter, std::get<0 >(i.value ()),
189- std::get<1 >(i.value ())->getArguments (), op.getLoc (),
190- " in switch case " + std::to_string (i.index ()) + " " ))) {
185+ // Get or convert all case blocks.
186+ SmallVector<Block *> caseDestinations;
187+ SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands ();
188+ for (auto it : llvm::enumerate (op.getCaseDestinations ())) {
189+ Block *b = it.value ();
190+ FailureOr<Block *> convertedBlock =
191+ getConvertedBlock (rewriter, getTypeConverter (), op, b,
192+ TypeRange (caseOperands[it.index ()]));
193+ if (failed (convertedBlock))
191194 return failure ();
192- }
195+ caseDestinations. push_back (*convertedBlock);
193196 }
194197
195198 rewriter.replaceOpWithNewOp <LLVM::SwitchOp>(
196- op, adaptor.getOperands (), op->getSuccessors (), op->getAttrs ());
199+ op, adaptor.getFlag (), *convertedDefaultBlock,
200+ adaptor.getDefaultOperands (), adaptor.getCaseValuesAttr (),
201+ caseDestinations, caseOperands);
197202 return success ();
198203 }
199204};
@@ -230,14 +235,22 @@ struct ConvertControlFlowToLLVM
230235
231236 // / Run the dialect converter on the module.
232237 void runOnOperation () override {
233- LLVMConversionTarget target (getContext ());
234- RewritePatternSet patterns (&getContext ());
235-
236- LowerToLLVMOptions options (&getContext ());
238+ MLIRContext *ctx = &getContext ();
239+ LLVMConversionTarget target (*ctx);
240+ // This pass lowers only CF dialect ops, but it also modifies block
241+ // signatures inside other ops. These ops should be treated as legal. They
242+ // are lowered by other passes.
243+ target.markUnknownOpDynamicallyLegal ([&](Operation *op) {
244+ return op->getDialect () !=
245+ ctx->getLoadedDialect <cf::ControlFlowDialect>();
246+ });
247+
248+ LowerToLLVMOptions options (ctx);
237249 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout )
238250 options.overrideIndexBitwidth (indexBitwidth);
239251
240- LLVMTypeConverter converter (&getContext (), options);
252+ LLVMTypeConverter converter (ctx, options);
253+ RewritePatternSet patterns (ctx);
241254 mlir::cf::populateControlFlowToLLVMConversionPatterns (converter, patterns);
242255
243256 if (failed (applyPartialConversion (getOperation (), target,
0 commit comments