@@ -437,15 +437,14 @@ struct CustomCallOpPattern : public OpConversionPattern<CustomCallOp> {
437437struct DefineCallbackOpPattern : public OpConversionPattern <CallbackOp> {
438438 using OpConversionPattern::OpConversionPattern;
439439
440- LogicalResult match (CallbackOp op) const
440+ LogicalResult matchAndRewrite (CallbackOp op, CallbackOpAdaptor adaptor,
441+ ConversionPatternRewriter &rewriter) const override
441442 {
442443 // Only match with ops without an entry block
443- return !op.empty () ? failure () : success ();
444- }
444+ if (!op.empty ()) {
445+ return failure ();
446+ }
445447
446- void rewrite (CallbackOp op, CallbackOpAdaptor adaptor,
447- ConversionPatternRewriter &rewriter) const
448- {
449448 Block *entry;
450449 rewriter.modifyOpInPlace (op, [&] { entry = op.addEntryBlock (); });
451450 PatternRewriter::InsertionGuard guard (rewriter);
@@ -487,20 +486,21 @@ struct DefineCallbackOpPattern : public OpConversionPattern<CallbackOp> {
487486 }
488487 rewriter.create <LLVM::CallOp>(loc, customCallFnOp, callArgs);
489488 rewriter.create <func::ReturnOp>(loc, TypeRange{}, ValueRange{});
489+ return success ();
490490 }
491491};
492492
493493struct ReplaceCallbackOpWithFuncOp : public OpConversionPattern <CallbackOp> {
494494 using OpConversionPattern::OpConversionPattern;
495495
496- LogicalResult match (CallbackOp op) const
496+ LogicalResult matchAndRewrite (CallbackOp op, CallbackOpAdaptor adaptor,
497+ ConversionPatternRewriter &rewriter) const override
497498 {
498499 // Only match with ops with an entry block
499- return !op.empty () ? success () : failure ();
500- }
501- void rewrite (CallbackOp op, CallbackOpAdaptor adaptor,
502- ConversionPatternRewriter &rewriter) const
503- {
500+ if (op.empty ()) {
501+ return failure ();
502+ }
503+
504504 ModuleOp mod = op->getParentOfType <ModuleOp>();
505505 rewriter.setInsertionPointToStart (mod.getBody ());
506506
@@ -515,6 +515,7 @@ struct ReplaceCallbackOpWithFuncOp : public OpConversionPattern<CallbackOp> {
515515 auto typeConverter = getTypeConverter ();
516516 gradient::wrapMemRefArgsFunc (func, typeConverter, rewriter, op.getLoc ());
517517 rewriter.eraseOp (op);
518+ return success ();
518519 }
519520};
520521
@@ -545,7 +546,8 @@ struct CallbackCallOpPattern : public OpConversionPattern<CallbackCallOp> {
545546struct CustomGradOpPattern : public OpConversionPattern <gradient::CustomGradOp> {
546547 using OpConversionPattern::OpConversionPattern;
547548
548- LogicalResult match (gradient::CustomGradOp op) const
549+ LogicalResult matchAndRewrite (gradient::CustomGradOp op, gradient::CustomGradOpAdaptor adaptor,
550+ ConversionPatternRewriter &rewriter) const override
549551 {
550552 // only match after all three are func.func
551553 auto callee = op.getCalleeAttr ();
@@ -556,22 +558,14 @@ struct CustomGradOpPattern : public OpConversionPattern<gradient::CustomGradOp>
556558 auto forwardOp = mod.lookupSymbol <func::FuncOp>(forward);
557559 auto reverseOp = mod.lookupSymbol <func::FuncOp>(reverse);
558560 auto ready = calleeOp && forwardOp && reverseOp;
559- return ready ? success () : failure ();
560- }
561+ if (!ready) {
562+ return failure ();
563+ }
561564
562- void rewrite (gradient::CustomGradOp op, gradient::CustomGradOpAdaptor adaptor,
563- ConversionPatternRewriter &rewriter) const
564- {
565565 auto loc = op.getLoc ();
566- ModuleOp mod = op->getParentOfType <ModuleOp>();
567- auto callee = op.getCalleeAttr ();
568- auto forward = op.getForwardAttr ();
569- auto reverse = op.getReverseAttr ();
570- auto calleeOp = mod.lookupSymbol <func::FuncOp>(callee);
571- auto forwardOp = mod.lookupSymbol <func::FuncOp>(forward);
572- auto reverseOp = mod.lookupSymbol <func::FuncOp>(reverse);
573566 gradient::insertEnzymeCustomGradient (rewriter, mod, loc, calleeOp, forwardOp, reverseOp);
574567 rewriter.eraseOp (op);
568+ return success ();
575569 }
576570};
577571
0 commit comments