66// ===----------------------------------------------------------------------===//
77
88#include " mlir/Bytecode/BytecodeOpInterface.h"
9- #include " src/enzyme_ad/jax/Dialect/Tessera/Ops.h"
109#include " mlir/Dialect/Func/IR/FuncOps.h"
11- #include " mlir/IR/PatternMatch.h"
10+ #include " mlir/IR/BuiltinDialect.h"
11+ #include " mlir/IR/BuiltinOps.h"
1212#include " mlir/IR/IRMapping.h"
13- #include " mlir/Transforms/DialectConversion.h"
14- #include " mlir/Interfaces/FunctionInterfaces.h"
13+ #include " mlir/IR/PatternMatch.h"
1514#include " mlir/Interfaces/CallInterfaces.h"
15+ #include " mlir/Interfaces/FunctionInterfaces.h"
1616#include " mlir/Pass/Pass.h"
17- #include " mlir/IR/BuiltinOps.h"
17+ #include " mlir/Transforms/DialectConversion.h"
18+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
1819#include " src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
20+ #include " src/enzyme_ad/jax/Dialect/Tessera/Ops.h"
1921#include " src/enzyme_ad/jax/Passes/Tessera/Passes.h"
20- #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2122
2223using namespace mlir ;
2324using namespace mlir ::enzyme;
@@ -59,6 +60,22 @@ class FuncOpRewrite final : public OpRewritePattern<func::FuncOp> {
5960 funcOp.getBody ().cloneInto (&tesseraDefineOp.getBody (),
6061 tesseraDefineOp.getBody ().end (),
6162 mapper);
63+
64+ // Now walk through the cloned operations and convert func.return to
65+ // tessera.return
66+ tesseraDefineOp.walk ([&](func::ReturnOp returnOp) {
67+ rewriter.setInsertionPoint (returnOp);
68+ rewriter.replaceOpWithNewOp <tessera::ReturnOp>(returnOp,
69+ returnOp.getOperands ());
70+ });
71+
72+ // Convert func.call to tessera.call
73+ tesseraDefineOp.walk ([&](func::CallOp callOp) {
74+ rewriter.setInsertionPoint (callOp);
75+ rewriter.replaceOpWithNewOp <tessera::CallOp>(
76+ callOp, callOp.getResultTypes (), callOp.getOperands (),
77+ callOp->getAttrs ());
78+ });
6279 }
6380
6481 rewriter.eraseOp (funcOp);
@@ -81,7 +98,7 @@ class CallOpRewrite final : public OpRewritePattern<func::CallOp> {
8198 Operation *calleeOp = SymbolTable::lookupSymbolIn (moduleOp, calleeAttr);
8299
83100 // Only convert if the callee is a Tessera DefineOp
84- if (isa<tessera::DefineOp>(calleeOp))
101+ if (! isa<tessera::DefineOp>(calleeOp))
85102 return rewriter.notifyMatchFailure (callOp, " Callee is not a Tessera DefineOp" );
86103
87104 rewriter.replaceOpWithNewOp <tessera::CallOp>(callOp, callOp.getResultTypes (),
@@ -122,14 +139,28 @@ namespace mlir::enzyme::tessera {
122139struct FuncToTesseraPass
123140 : public PassWrapper<FuncToTesseraPass, OperationPass<ModuleOp>> {
124141
142+ StringRef getArgument () const final { return " func-to-tessera" ; }
143+ StringRef getDescription () const final {
144+ return " Convert func dialect to tessera dialect." ;
145+ }
146+
147+ void getDependentDialects (DialectRegistry ®istry) const override {
148+ registry.insert <tessera::TesseraDialect>();
149+ }
150+
125151 void runOnOperation () override {
126152 MLIRContext *ctx = &getContext ();
153+
154+ ConversionTarget target (*ctx);
155+ target.addLegalDialect <tessera::TesseraDialect>();
156+ target.addLegalDialect <BuiltinDialect>();
157+ target.addIllegalDialect <func::FuncDialect>();
158+
127159 RewritePatternSet patterns (ctx);
128160
129161 patterns.add <FuncOpRewrite, CallOpRewrite, ReturnOpRewrite>(ctx);
130162
131- if (failed (applyPatternsAndFoldGreedily (getOperation (),
132- std::move (patterns))))
163+ if (failed (applyFullConversion (getOperation (), target, std::move (patterns))))
133164 signalPassFailure ();
134165}
135166};
0 commit comments