Skip to content

Commit df8a289

Browse files
committed
3 missed matchAndRewrite migration
1 parent 2a457a4 commit df8a289

File tree

1 file changed

+19
-25
lines changed

1 file changed

+19
-25
lines changed

mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -437,15 +437,14 @@ struct CustomCallOpPattern : public OpConversionPattern<CustomCallOp> {
437437
struct DefineCallbackOpPattern : public OpConversionPattern<CallbackOp> {
438438
using OpConversionPattern::OpConversionPattern;
439439

440-
LogicalResult match(CallbackOp op) const
440+
LogicalResult matchAndRewrite(CallbackOp op, CallbackOpAdaptor adaptor,
441+
ConversionPatternRewriter &rewriter) const override
441442
{
442443
// Only match with ops without an entry block
443-
return !op.empty() ? failure() : success();
444-
}
444+
if (!op.empty()) {
445+
return failure();
446+
}
445447

446-
void rewrite(CallbackOp op, CallbackOpAdaptor adaptor,
447-
ConversionPatternRewriter &rewriter) const
448-
{
449448
Block *entry;
450449
rewriter.modifyOpInPlace(op, [&] { entry = op.addEntryBlock(); });
451450
PatternRewriter::InsertionGuard guard(rewriter);
@@ -487,20 +486,21 @@ struct DefineCallbackOpPattern : public OpConversionPattern<CallbackOp> {
487486
}
488487
rewriter.create<LLVM::CallOp>(loc, customCallFnOp, callArgs);
489488
rewriter.create<func::ReturnOp>(loc, TypeRange{}, ValueRange{});
489+
return success();
490490
}
491491
};
492492

493493
struct ReplaceCallbackOpWithFuncOp : public OpConversionPattern<CallbackOp> {
494494
using OpConversionPattern::OpConversionPattern;
495495

496-
LogicalResult match(CallbackOp op) const
496+
LogicalResult matchAndRewrite(CallbackOp op, CallbackOpAdaptor adaptor,
497+
ConversionPatternRewriter &rewriter) const override
497498
{
498499
// Only match with ops with an entry block
499-
return !op.empty() ? success() : failure();
500-
}
501-
void rewrite(CallbackOp op, CallbackOpAdaptor adaptor,
502-
ConversionPatternRewriter &rewriter) const
503-
{
500+
if (op.empty()) {
501+
return failure();
502+
}
503+
504504
ModuleOp mod = op->getParentOfType<ModuleOp>();
505505
rewriter.setInsertionPointToStart(mod.getBody());
506506

@@ -515,6 +515,7 @@ struct ReplaceCallbackOpWithFuncOp : public OpConversionPattern<CallbackOp> {
515515
auto typeConverter = getTypeConverter();
516516
gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter, op.getLoc());
517517
rewriter.eraseOp(op);
518+
return success();
518519
}
519520
};
520521

@@ -545,7 +546,8 @@ struct CallbackCallOpPattern : public OpConversionPattern<CallbackCallOp> {
545546
struct CustomGradOpPattern : public OpConversionPattern<gradient::CustomGradOp> {
546547
using OpConversionPattern::OpConversionPattern;
547548

548-
LogicalResult match(gradient::CustomGradOp op) const
549+
LogicalResult matchAndRewrite(gradient::CustomGradOp op, gradient::CustomGradOpAdaptor adaptor,
550+
ConversionPatternRewriter &rewriter) const override
549551
{
550552
// only match after all three are func.func
551553
auto callee = op.getCalleeAttr();
@@ -556,22 +558,14 @@ struct CustomGradOpPattern : public OpConversionPattern<gradient::CustomGradOp>
556558
auto forwardOp = mod.lookupSymbol<func::FuncOp>(forward);
557559
auto reverseOp = mod.lookupSymbol<func::FuncOp>(reverse);
558560
auto ready = calleeOp && forwardOp && reverseOp;
559-
return ready ? success() : failure();
560-
}
561+
if (!ready) {
562+
return failure();
563+
}
561564

562-
void rewrite(gradient::CustomGradOp op, gradient::CustomGradOpAdaptor adaptor,
563-
ConversionPatternRewriter &rewriter) const
564-
{
565565
auto loc = op.getLoc();
566-
ModuleOp mod = op->getParentOfType<ModuleOp>();
567-
auto callee = op.getCalleeAttr();
568-
auto forward = op.getForwardAttr();
569-
auto reverse = op.getReverseAttr();
570-
auto calleeOp = mod.lookupSymbol<func::FuncOp>(callee);
571-
auto forwardOp = mod.lookupSymbol<func::FuncOp>(forward);
572-
auto reverseOp = mod.lookupSymbol<func::FuncOp>(reverse);
573566
gradient::insertEnzymeCustomGradient(rewriter, mod, loc, calleeOp, forwardOp, reverseOp);
574567
rewriter.eraseOp(op);
568+
return success();
575569
}
576570
};
577571

0 commit comments

Comments
 (0)