55//
66// ===----------------------------------------------------------------------===//
77
8+ #include " mlir/Bytecode/BytecodeOpInterface.h"
9+ #include " src/enzyme_ad/jax/Dialect/Tessera/Ops.h"
810#include " mlir/Dialect/Func/IR/FuncOps.h"
11+ #include " mlir/IR/PatternMatch.h"
12+ #include " mlir/IR/IRMapping.h"
13+ #include " mlir/Transforms/DialectConversion.h"
14+ #include " mlir/Interfaces/FunctionInterfaces.h"
15+ #include " mlir/Interfaces/CallInterfaces.h"
16+ #include " mlir/Pass/Pass.h"
917#include " mlir/IR/BuiltinOps.h"
1018#include " src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
1119#include " src/enzyme_ad/jax/Passes/Tessera/Passes.h"
12- #include " src/enzyme_ad/jax/Passes/Passes .h"
20+ #include " mlir/Transforms/GreedyPatternRewriteDriver .h"
1321
1422using namespace mlir ;
15- using namespace mlir ::enzyme::tessera ;
23+ using namespace mlir ::enzyme;
1624
1725namespace {
1826} // namespace
1927
20-
2128// ===----------------------------------------------------------------------===//
2229// Rewrite Patterns
2330// ===----------------------------------------------------------------------===//
@@ -39,30 +46,19 @@ class FuncOpRewrite final : public OpRewritePattern<func::FuncOp> {
3946 auto tesseraDefineOp = rewriter.create <tessera::DefineOp>(
4047 funcOp.getLoc (), funcOp.getName (), fnType);
4148
42-
4349 // Copy over all attributes other than the function name and type.
4450 for (const auto &namedAttr : funcOp->getAttrs ()) {
4551 if (namedAttr.getName () != funcOp.getFunctionTypeAttrName () &&
4652 namedAttr.getName () != SymbolTable::getSymbolAttrName ())
4753 tesseraDefineOp->setAttr (namedAttr.getName (), namedAttr.getValue ());
4854 }
4955
50- // Add `extern` to specifiers if `func.func` is declaration only.
51- if (funcOp.isDeclaration ()) {
52- ArrayAttr specifiers = rewriter.getStrArrayAttr ({" extern" });
53- tesseraDefineOp.setSpecifiersAttr (specifiers);
54- }
55-
56- // Add `static` to specifiers if `func.func` is private but not a
57- // declaration.
58- if (funcOp.isPrivate () && !funcOp.isDeclaration ()) {
59- ArrayAttr specifiers = rewriter.getStrArrayAttr ({" static" });
60- tesseraDefineOp.setSpecifiersAttr (specifiers);
61- }
62-
63- if (!funcOp.isDeclaration ()) {
56+ // Clone body of function
57+ if (!funcOp.isExternal ()) {
58+ IRMapping mapper;
6459 funcOp.getBody ().cloneInto (&tesseraDefineOp.getBody (),
65- tesseraDefineOp.getBody ().end ());
60+ tesseraDefineOp.getBody ().end (),
61+ mapper);
6662 }
6763
6864 rewriter.eraseOp (funcOp);
@@ -121,18 +117,24 @@ class ReturnOpRewrite final : public OpRewritePattern<func::ReturnOp> {
121117// Pass to convert Func operations into Tessera operations
122118// ===----------------------------------------------------------------------===//
123119
120+ namespace mlir ::enzyme::tessera {
121+
124122struct FuncToTesseraPass
125123 : public PassWrapper<FuncToTesseraPass, OperationPass<ModuleOp>> {
126124
127125 void runOnOperation () override {
128- MLIRContext & ctx = patterns. getContext ();
129- RewritePatternSet patterns (& ctx);
126+ MLIRContext * ctx = & getContext ();
127+ RewritePatternSet patterns (ctx);
130128
131- patterns.add <FuncOpRewrite, CallOpRewrite, ReturnOpRewrite>(& ctx);
129+ patterns.add <FuncOpRewrite, CallOpRewrite, ReturnOpRewrite>(ctx);
132130
133- if (failed (applyPatternsAndFoldGreedily (getOperation (),
134- std::move (patterns))))
135- signalPassFailure ();
136- }
131+ if (failed (applyPatternsAndFoldGreedily (getOperation (),
132+ std::move (patterns))))
133+ signalPassFailure ();
134+ }
137135};
138136
137+ std::unique_ptr<mlir::Pass> createFuncToTesseraPass () {
138+ return std::make_unique<FuncToTesseraPass>();
139+ }
140+ }
0 commit comments