Skip to content

Commit a611a1d

Browse files
Added conditions that parent/callee must be Tessera DefineOps
1 parent c4d10eb commit a611a1d

File tree

3 files changed

+163
-16
lines changed

3 files changed

+163
-16
lines changed

src/enzyme_ad/jax/Passes/Tessera/FuncToTessera.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
//===----------------------------------------------------------------------===//
22
//
33
// This file implements patterns to convert the Func dialect to the Tessera
4-
// dialect and from the Tessera dialect to the Func dialect.
4+
// dialect.
55
//
66
//===----------------------------------------------------------------------===//
77

88
#include "mlir/Dialect/Func/IR/FuncOps.h"
9+
#include "mlir/IR/BuiltinOps.h"
910
#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
1011
#include "src/enzyme_ad/jax/Passes/Passes.h"
1112

@@ -32,10 +33,6 @@ class FuncOpRewrite final : public OpRewritePattern<func::FuncOp> {
3233
PatternRewriter &rewriter) const override {
3334
FunctionType fnType = funcOp.getFunctionType();
3435

35-
if (fnType.getNumResults() > 1)
36-
return rewriter.notifyMatchFailure(
37-
funcOp, "only functions with zero or one result can be rewritten");
38-
3936

4037
// Create the `tessera.define` op
4138
auto tesseraDefineOp = rewriter.create<tessera::DefineOp>(
@@ -63,11 +60,10 @@ class FuncOpRewrite final : public OpRewritePattern<func::FuncOp> {
6360
}
6461

6562
if (!funcOp.isDeclaration()) {
66-
rewriter.inlineRegionBefore(funcOp.getBody(), tesseraDefineOp.getBody(),
67-
tesseraDefineOp.end());
63+
funcOp.getBody().cloneInto(&tesseraDefineOp.getBody(),
64+
tesseraDefineOp.getBody().end());
6865
}
6966

70-
7167
rewriter.eraseOp(funcOp);
7268

7369
return success();
@@ -83,7 +79,15 @@ class CallOpRewrite final : public OpRewritePattern<func::CallOp> {
8379
matchAndRewrite(func::CallOp callOp,
8480
PatternRewriter &rewriter) const override {
8581

86-
rewriter.replaceOpWithNewOp<tessera::CallOp>(callOp, callOp.getResultTypes(),
82+
auto calleeAttr = callOp.getCalleeAttr();
83+
Operation *moduleOp = callOp->getParentOfType<ModuleOp>();
84+
Operation *calleeOp = SymbolTable::lookupSymbolIn(moduleOp, calleeAttr);
85+
86+
// Only convert if the callee is a Tessera DefineOp
87+
if (isa<tessera::DefineOp>(calleeOp))
88+
return rewriter.notifyMatchFailure(callOp, "Callee is not a Tessera DefineOp");
89+
90+
rewriter.replaceOpWithNewOp<tessera::CallOp>(callOp, callOp.getResultTypes(),
8791
callOp.getOperands(),
8892
callOp->getAttrs());
8993

@@ -99,7 +103,12 @@ class ReturnOpRewrite final : public OpRewritePattern<func::ReturnOp> {
99103
LogicalResult
100104
matchAndRewrite(func::ReturnOp returnOp,
101105
PatternRewriter &rewriter) const override {
102-
106+
Operation *parent = returnOp->getParentOp();
107+
108+
// Only convert if the function is a Tessera DefineOp
109+
if (!isa<tessera::DefineOp>(parent))
110+
return rewriter.notifyMatchFailure(returnOp, "Parent is not a Tessera DefineOp");
111+
103112
rewriter.replaceOpWithNewOp<tessera::ReturnOp>(returnOp,
104113
returnOp.getOperands());
105114
return success();
@@ -118,7 +127,7 @@ struct FuncToTesseraPass
118127
MLIRContext &ctx = patterns.getContext();
119128
RewritePatternSet patterns(&ctx);
120129

121-
patterns.add<CallOpRewrite, FuncOpRewrite, ReturnOpRewrite>(&ctx);
130+
patterns.add<FuncOpRewrite, CallOpRewrite, ReturnOpRewrite>(&ctx);
122131

123132
if (failed(applyPatternsAndFoldGreedily(getOperation(),
124133
std::move(patterns))))
Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,29 @@
1-
#ifndef ENZYME_AD_JAX_PASSES_TESSERA_PASSES_TD
2-
#define ENZYME_AD_JAX_PASSES_TESSERA_PASSES_TD
1+
#ifndef ENZYME_AD_JAX_TESSERA_PASSES
2+
#define ENZYME_AD_JAX_TESSERA_PASSES
33

44
include "mlir/Pass/PassBase.td"
55

66
def FuncToTesseraPass : Pass<"func-to-tessera"> {
7-
let summary = "Convert operations in the FuncDialect to operations in the TesseraDialect and vice versa";
7+
let summary = "Convert operations in the Func Dialect to operations in the Tessera Dialect";
8+
let description = [{
9+
This pass checks if an operation is marked with a custom
10+
annotation and if so, creates a TesseraOp from the FuncOp.
11+
}]
812
let dependentDialects = [
913
"func::FuncDialect",
10-
"tessera::TesseraDialect"
14+
"enzyme::tessera::TesseraDialect"
1115
];
1216
}
1317

14-
#endif // ENZYME_AD_JAX_PASSES_TESSERA_PASSES_TD
18+
def TesseraToFuncPass : Pass<"tessera-to-func"> {
19+
let summary = "Convert operations in the Tessera Dialect to operations in the Func Dialect";
20+
let description = [{
21+
This pass converts a TesseraOp back into a FuncOp.
22+
}]
23+
let dependentDialects = [
24+
"func::FuncDialect",
25+
"enzyme::tessera::TesseraDialect"
26+
];
27+
}
28+
29+
#endif // ENZYME_AD_JAX_TESSERA_PASSES
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This file implements patterns to convert the Tessera dialect to the Func
4+
// dialect.
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
#include "mlir/Dialect/Func/IR/FuncOps.h"
9+
#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
10+
#include "src/enzyme_ad/jax/Passes/Passes.h"
11+
12+
using namespace mlir;
13+
using namespace mlir::enzyme::tessera;
14+
15+
namespace {
16+
} // namespace
17+
18+
19+
//===----------------------------------------------------------------------===//
20+
// Rewrite Patterns
21+
//===----------------------------------------------------------------------===//
22+
23+
namespace {
24+
25+
// Rewrite 'tessera.define' -> 'func.func'
26+
class DefineOpRewrite final : public OpRewritePattern<tessera::DefineOp> {
27+
public:
28+
using OpRewritePattern<tessera::DefineOp>::OpRewritePattern;
29+
30+
LogicalResult
31+
matchAndRewrite(tessera::DefineOp defineOp,
32+
PatternRewriter &rewriter) const override {
33+
FunctionType fnType = defineOp.getFunctionType();
34+
35+
36+
// Create the `func.func` op
37+
auto funcOp = rewriter.create<tessera::DefineOp>(
38+
defineOp.getLoc(), defineOp.getName(), fnType);
39+
40+
41+
// Copy over all attributes other than the function name and type.
42+
for (const auto &namedAttr : defineOp->getAttrs()) {
43+
if (namedAttr.getName() != defineOp.getFunctionTypeAttrName() &&
44+
namedAttr.getName() != SymbolTable::getSymbolAttrName())
45+
funcOp->setAttr(namedAttr.getName(), namedAttr.getValue());
46+
}
47+
48+
// Add `extern` to specifiers if `tessera.define` is declaration only.
49+
if (defineOp.isDeclaration()) {
50+
ArrayAttr specifiers = rewriter.getStrArrayAttr({"extern"});
51+
funcOp.setSpecifiersAttr(specifiers);
52+
}
53+
54+
// Add `static` to specifiers if `tessera.define` is private but not a
55+
// declaration.
56+
if (defineOp.isPrivate() && !defineOp.isDeclaration()) {
57+
ArrayAttr specifiers = rewriter.getStrArrayAttr({"static"});
58+
funcOp.setSpecifiersAttr(specifiers);
59+
}
60+
61+
if (!defineOp.isDeclaration()) {
62+
defineOp.getBody().cloneInto(&funcOp.getBody(),
63+
funcOp.getBody().end());
64+
}
65+
66+
rewriter.eraseOp(defineOp);
67+
68+
return success();
69+
}
70+
};
71+
72+
// Rewrite 'tessera.call' -> 'func.call'
73+
class CallOpRewrite final : public OpRewritePattern<tessera::CallOp> {
74+
public:
75+
using OpRewritePattern<tessera::CallOp>::OpRewritePattern;
76+
77+
LogicalResult
78+
matchAndRewrite(tessera::CallOp callOp,
79+
PatternRewriter &rewriter) const override {
80+
81+
rewriter.replaceOpWithNewOp<func::CallOp>(callOp, callOp.getResultTypes(),
82+
callOp.getOperands(),
83+
callOp->getAttrs());
84+
85+
return success();
86+
}
87+
};
88+
89+
// Rewrite 'tessera.return' -> 'func.return'
90+
class ReturnOpRewrite final : public OpRewritePattern<tessera::ReturnOp> {
91+
public:
92+
using OpRewritePattern<tessera::ReturnOp>::OpRewritePattern;
93+
94+
LogicalResult
95+
matchAndRewrite(tessera::ReturnOp returnOp,
96+
PatternRewriter &rewriter) const override {
97+
98+
rewriter.replaceOpWithNewOp<func::ReturnOp>(returnOp,
99+
returnOp.getOperands());
100+
return success();
101+
}
102+
};
103+
} // namespace
104+
105+
//===----------------------------------------------------------------------===//
106+
// Pass to convert Func operations into Tessera operations
107+
//===----------------------------------------------------------------------===//
108+
109+
struct TesseraToFuncPass
110+
: public PassWrapper<TesseraToFuncPass, OperationPass<ModuleOp>> {
111+
112+
void runOnOperation() override {
113+
MLIRContext &ctx = patterns.getContext();
114+
RewritePatternSet patterns(&ctx);
115+
116+
patterns.add<DefineOpRewrite, CallOpRewrite, ReturnOpRewrite>(&ctx);
117+
118+
if (failed(applyPatternsAndFoldGreedily(getOperation(),
119+
std::move(patterns))))
120+
signalPassFailure();
121+
}
122+
};
123+

0 commit comments

Comments
 (0)