Skip to content

Commit 38f8a46

Browse files
Fixed all bugs so that it correctly builds
1 parent 08bda01 commit 38f8a46

File tree

11 files changed

+106
-60
lines changed

11 files changed

+106
-60
lines changed

src/enzyme_ad/jax/BUILD

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,10 @@ gentbl_cc_library(
625625
tblgen = "@llvm-project//mlir:mlir-tblgen",
626626
td_file = "Dialect/Tessera/Ops.td",
627627
deps = [
628+
"@llvm-project//mlir:FunctionInterfacesTdFiles",
629+
"@llvm-project//mlir:ControlFlowInterfaces",
630+
"@llvm-project//mlir:IR",
631+
"@llvm-project//mlir:CallOpInterfaces",
628632
":TesseraDialectTdFiles",
629633
],
630634
)
@@ -651,7 +655,10 @@ gentbl_cc_library(
651655
],
652656
tblgen = "@llvm-project//mlir:mlir-tblgen",
653657
td_file = "Passes/Tessera/Passes.td",
654-
deps = [":TesseraPassesTdFiles"],
658+
deps = [
659+
"@llvm-project//mlir:FunctionInterfacesTdFiles",
660+
":TesseraPassesTdFiles",
661+
],
655662
)
656663

657664
cc_library(

src/enzyme_ad/jax/Dialect/Tessera/Dialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "Dialect.h"
2+
#include "Ops.h"
23

34
#include "mlir/IR/Builders.h"
45
#include "llvm/ADT/TypeSwitch.h"

src/enzyme_ad/jax/Dialect/Tessera/Dialect.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,4 @@
1717
// Include the dialect
1818
#include "src/enzyme_ad/jax/Dialect/Tessera/TesseraDialect.h.inc"
1919

20-
// Operations
21-
#define GET_OP_CLASSES
22-
#include "src/enzyme_ad/jax/Dialect/Tessera/TesseraOps.h.inc"
23-
2420
#endif // ENZYME_AD_JAX_DIALECT_TESSERA_DIALECT_H

src/enzyme_ad/jax/Dialect/Tessera/Ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
#include "llvm/ADT/TypeSwitch.h"
33

44
#include "Dialect.h"
5+
#include "Ops.h"
56
#include "mlir/IR/IRMapping.h"
67
#include "mlir/Interfaces/FunctionImplementation.h"
78
#include "mlir/Interfaces/FunctionInterfaces.h"
9+
#include "mlir/Interfaces/CallInterfaces.h"
10+
#include "mlir/Interfaces/SideEffectInterfaces.h"
811

912
using namespace mlir;
1013
using namespace mlir::enzyme::tessera;
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef ENZYME_AD_JAX_TESSERA_OPS_H
2+
#define ENZYME_AD_JAX_TESSERA_OPS_H
3+
4+
#include "mlir/IR/OpDefinition.h"
5+
#include "mlir/IR/BuiltinOps.h"
6+
#include "mlir/IR/Dialect.h"
7+
#include "mlir/Interfaces/FunctionInterfaces.h"
8+
#include "mlir/Interfaces/CallInterfaces.h"
9+
#include "mlir/Interfaces/SideEffectInterfaces.h"
10+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
11+
12+
#define GET_OP_CLASSES
13+
#include "src/enzyme_ad/jax/Dialect/Tessera/TesseraOps.h.inc"
14+
15+
#endif // ENZYME_AD_JAX_TESSERA_OPS_H

src/enzyme_ad/jax/Dialect/Tessera/Ops.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def DefineOp : TesseraOp<"define", [
9393
//===------------------------------------------------------------------===//
9494

9595
bool isDeclaration() { return isExternal(); }
96+
9697
}];
9798
let hasCustomAssemblyFormat = 1;
9899
}

src/enzyme_ad/jax/Passes/Tessera/FuncToTessera.cpp

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,26 @@
55
//
66
//===----------------------------------------------------------------------===//
77

8+
#include "mlir/Bytecode/BytecodeOpInterface.h"
9+
#include "src/enzyme_ad/jax/Dialect/Tessera/Ops.h"
810
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/IR/PatternMatch.h"
12+
#include "mlir/IR/IRMapping.h"
13+
#include "mlir/Transforms/DialectConversion.h"
14+
#include "mlir/Interfaces/FunctionInterfaces.h"
15+
#include "mlir/Interfaces/CallInterfaces.h"
16+
#include "mlir/Pass/Pass.h"
917
#include "mlir/IR/BuiltinOps.h"
1018
#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
1119
#include "src/enzyme_ad/jax/Passes/Tessera/Passes.h"
12-
#include "src/enzyme_ad/jax/Passes/Passes.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1321

1422
using namespace mlir;
15-
using namespace mlir::enzyme::tessera;
23+
using namespace mlir::enzyme;
1624

1725
namespace {
1826
} // namespace
1927

20-
2128
//===----------------------------------------------------------------------===//
2229
// Rewrite Patterns
2330
//===----------------------------------------------------------------------===//
@@ -39,30 +46,19 @@ class FuncOpRewrite final : public OpRewritePattern<func::FuncOp> {
3946
auto tesseraDefineOp = rewriter.create<tessera::DefineOp>(
4047
funcOp.getLoc(), funcOp.getName(), fnType);
4148

42-
4349
// Copy over all attributes other than the function name and type.
4450
for (const auto &namedAttr : funcOp->getAttrs()) {
4551
if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
4652
namedAttr.getName() != SymbolTable::getSymbolAttrName())
4753
tesseraDefineOp->setAttr(namedAttr.getName(), namedAttr.getValue());
4854
}
4955

50-
// Add `extern` to specifiers if `func.func` is declaration only.
51-
if (funcOp.isDeclaration()) {
52-
ArrayAttr specifiers = rewriter.getStrArrayAttr({"extern"});
53-
tesseraDefineOp.setSpecifiersAttr(specifiers);
54-
}
55-
56-
// Add `static` to specifiers if `func.func` is private but not a
57-
// declaration.
58-
if (funcOp.isPrivate() && !funcOp.isDeclaration()) {
59-
ArrayAttr specifiers = rewriter.getStrArrayAttr({"static"});
60-
tesseraDefineOp.setSpecifiersAttr(specifiers);
61-
}
62-
63-
if (!funcOp.isDeclaration()) {
56+
// Clone body of function
57+
if (!funcOp.isExternal()) {
58+
IRMapping mapper;
6459
funcOp.getBody().cloneInto(&tesseraDefineOp.getBody(),
65-
tesseraDefineOp.getBody().end());
60+
tesseraDefineOp.getBody().end(),
61+
mapper);
6662
}
6763

6864
rewriter.eraseOp(funcOp);
@@ -121,18 +117,24 @@ class ReturnOpRewrite final : public OpRewritePattern<func::ReturnOp> {
121117
// Pass to convert Func operations into Tessera operations
122118
//===----------------------------------------------------------------------===//
123119

120+
namespace mlir::enzyme::tessera {
121+
124122
struct FuncToTesseraPass
125123
: public PassWrapper<FuncToTesseraPass, OperationPass<ModuleOp>> {
126124

127125
void runOnOperation() override {
128-
MLIRContext &ctx = patterns.getContext();
129-
RewritePatternSet patterns(&ctx);
126+
MLIRContext *ctx = &getContext();
127+
RewritePatternSet patterns(ctx);
130128

131-
patterns.add<FuncOpRewrite, CallOpRewrite, ReturnOpRewrite>(&ctx);
129+
patterns.add<FuncOpRewrite, CallOpRewrite, ReturnOpRewrite>(ctx);
132130

133-
if (failed(applyPatternsAndFoldGreedily(getOperation(),
134-
std::move(patterns))))
135-
signalPassFailure();
136-
}
131+
if (failed(applyPatternsAndFoldGreedily(getOperation(),
132+
std::move(patterns))))
133+
signalPassFailure();
134+
}
137135
};
138136

137+
std::unique_ptr<mlir::Pass> createFuncToTesseraPass() {
138+
return std::make_unique<FuncToTesseraPass>();
139+
}
140+
}
Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,32 @@
11
#ifndef TESSERA_PASSES_H
22
#define TESSERA_PASSES_H
33

4+
#include "src/enzyme_ad/jax/Dialect/Tessera/Ops.h"
45
#include "mlir/Pass/Pass.h"
56

67
namespace mlir {
8+
namespace enzyme {
79
namespace tessera {
810

11+
std::unique_ptr<mlir::Pass> createTesseraToFuncPass();
12+
std::unique_ptr<mlir::Pass> createFuncToTesseraPass();
13+
14+
} // namespace tessera
15+
} // namespace enzyme
16+
} // namespace mlir
17+
918
#define GEN_PASS_DECLS
10-
#include "Tessera/Passes/Tessera/Passes.h.inc"
19+
#include "src/enzyme_ad/jax/Passes/Tessera/Passes.h.inc"
20+
21+
namespace mlir {
22+
namespace enzyme {
23+
namespace tessera {
1124

1225
#define GEN_PASS_REGISTRATION
13-
#include "Tessera/Passes/Tessera/Passes.h.inc"
26+
#include "src/enzyme_ad/jax/Passes/Tessera/Passes.h.inc"
1427

1528
} // namespace tessera
29+
} // namespace enzyme
1630
} // namespace mlir
1731

1832
#endif

src/enzyme_ad/jax/Passes/Tessera/Passes.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
#define ENZYME_AD_JAX_TESSERA_PASSES
33

44
include "mlir/Pass/PassBase.td"
5+
include "mlir/Interfaces/FunctionInterfaces.td"
56

67
def FuncToTesseraPass : Pass<"func-to-tessera"> {
78
let summary = "Convert operations in the Func Dialect to operations in the Tessera Dialect";
89
let description = [{
910
This pass checks if an operation is marked with a custom
1011
annotation and if so, creates a TesseraOp from the FuncOp.
11-
}]
12+
}];
13+
let constructor = "mlir::enzyme::tessera::createFuncToTesseraPass()";
1214
let dependentDialects = [
1315
"func::FuncDialect",
1416
"enzyme::tessera::TesseraDialect"
@@ -19,7 +21,8 @@ def TesseraToFuncPass : Pass<"tessera-to-func"> {
1921
let summary = "Convert operations in the Tessera Dialect to operations in the Func Dialect";
2022
let description = [{
2123
This pass converts a TesseraOp back into a FuncOp.
22-
}]
24+
}];
25+
let constructor = "mlir::enzyme::tessera::createTesseraToFuncPass()";
2326
let dependentDialects = [
2427
"func::FuncDialect",
2528
"enzyme::tessera::TesseraDialect"

src/enzyme_ad/jax/Passes/Tessera/TesseraToFunc.cpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,22 @@
55
//
66
//===----------------------------------------------------------------------===//
77

8+
#include "mlir/Bytecode/BytecodeOpInterface.h"
9+
#include "src/enzyme_ad/jax/Dialect/Tessera/Ops.h"
810
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/IR/PatternMatch.h"
12+
#include "mlir/IR/IRMapping.h"
13+
#include "mlir/Transforms/DialectConversion.h"
14+
#include "mlir/Interfaces/FunctionInterfaces.h"
15+
#include "mlir/Interfaces/CallInterfaces.h"
16+
#include "mlir/Pass/Pass.h"
917
#include "mlir/IR/BuiltinOps.h"
1018
#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
1119
#include "src/enzyme_ad/jax/Passes/Tessera/Passes.h"
12-
#include "src/enzyme_ad/jax/Passes/Passes.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1321

1422
using namespace mlir;
15-
using namespace mlir::enzyme::tessera;
23+
using namespace mlir::enzyme;
1624

1725
namespace {
1826
} // namespace
@@ -47,22 +55,12 @@ class DefineOpRewrite final : public OpRewritePattern<tessera::DefineOp> {
4755
funcOp->setAttr(namedAttr.getName(), namedAttr.getValue());
4856
}
4957

50-
// Add `extern` to specifiers if `tessera.define` is declaration only.
51-
if (defineOp.isDeclaration()) {
52-
ArrayAttr specifiers = rewriter.getStrArrayAttr({"extern"});
53-
funcOp.setSpecifiersAttr(specifiers);
54-
}
55-
56-
// Add `static` to specifiers if `tessera.define` is private but not a
57-
// declaration.
58-
if (defineOp.isPrivate() && !defineOp.isDeclaration()) {
59-
ArrayAttr specifiers = rewriter.getStrArrayAttr({"static"});
60-
funcOp.setSpecifiersAttr(specifiers);
61-
}
62-
63-
if (!defineOp.isDeclaration()) {
58+
// Clone body of function
59+
if (!defineOp.isExternal()) {
60+
IRMapping mapper;
6461
defineOp.getBody().cloneInto(&funcOp.getBody(),
65-
funcOp.getBody().end());
62+
funcOp.getBody().end(),
63+
mapper);
6664
}
6765

6866
rewriter.eraseOp(defineOp);
@@ -108,18 +106,24 @@ class ReturnOpRewrite final : public OpRewritePattern<tessera::ReturnOp> {
108106
// Pass to convert Tessera operations into Func operations
109107
//===----------------------------------------------------------------------===//
110108

109+
namespace mlir::enzyme::tessera {
110+
111111
struct TesseraToFuncPass
112112
: public PassWrapper<TesseraToFuncPass, OperationPass<ModuleOp>> {
113113

114114
void runOnOperation() override {
115-
MLIRContext &ctx = patterns.getContext();
116-
RewritePatternSet patterns(&ctx);
115+
MLIRContext *ctx = &getContext();
116+
RewritePatternSet patterns(ctx);
117117

118-
patterns.add<DefineOpRewrite, CallOpRewrite, ReturnOpRewrite>(&ctx);
118+
patterns.add<DefineOpRewrite, CallOpRewrite, ReturnOpRewrite>(ctx);
119119

120-
if (failed(applyPatternsAndFoldGreedily(getOperation(),
121-
std::move(patterns))))
122-
signalPassFailure();
123-
}
120+
if (failed(applyPatternsAndFoldGreedily(getOperation(),
121+
std::move(patterns))))
122+
signalPassFailure();
123+
}
124124
};
125125

126+
std::unique_ptr<mlir::Pass> createTesseraToFuncPass() {
127+
return std::make_unique<TesseraToFuncPass>();
128+
}
129+
}

0 commit comments

Comments
 (0)