Skip to content

Commit 1c0187b

Browse files
Fmt 4
1 parent 83caf8b commit 1c0187b

File tree

3 files changed

+37
-46
lines changed

3 files changed

+37
-46
lines changed

src/enzyme_ad/jax/BUILD

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -641,11 +641,11 @@ gentbl_cc_library(
641641
tblgen = "@llvm-project//mlir:mlir-tblgen",
642642
td_file = "Dialect/Tessera/Ops.td",
643643
deps = [
644-
"@llvm-project//mlir:FunctionInterfacesTdFiles",
644+
":TesseraDialectTdFiles",
645+
"@llvm-project//mlir:CallOpInterfaces",
645646
"@llvm-project//mlir:ControlFlowInterfaces",
647+
"@llvm-project//mlir:FunctionInterfacesTdFiles",
646648
"@llvm-project//mlir:IR",
647-
"@llvm-project//mlir:CallOpInterfaces",
648-
":TesseraDialectTdFiles",
649649
],
650650
)
651651

@@ -672,8 +672,8 @@ gentbl_cc_library(
672672
tblgen = "@llvm-project//mlir:mlir-tblgen",
673673
td_file = "Passes/Tessera/Passes.td",
674674
deps = [
675-
"@llvm-project//mlir:FunctionInterfacesTdFiles",
676675
":TesseraPassesTdFiles",
676+
"@llvm-project//mlir:FunctionInterfacesTdFiles",
677677
],
678678
)
679679

src/enzyme_ad/jax/Passes/Tessera/FuncToTessera.cpp

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//===----------------------------------------------------------------------===//
22
//
3-
// This file implements patterns to convert operations in the Func dialect to
3+
// This file implements patterns to convert operations in the Func dialect to
44
// operations in the Tessera dialect.
55
//
66
//===----------------------------------------------------------------------===//
@@ -23,8 +23,7 @@
2323
using namespace mlir;
2424
using namespace mlir::enzyme;
2525

26-
namespace {
27-
} // namespace
26+
namespace {} // namespace
2827

2928
//===----------------------------------------------------------------------===//
3029
// Rewrite Patterns
@@ -37,12 +36,10 @@ class FuncOpRewrite final : public OpRewritePattern<func::FuncOp> {
3736
public:
3837
using OpRewritePattern<func::FuncOp>::OpRewritePattern;
3938

40-
LogicalResult
41-
matchAndRewrite(func::FuncOp funcOp,
42-
PatternRewriter &rewriter) const override {
39+
LogicalResult matchAndRewrite(func::FuncOp funcOp,
40+
PatternRewriter &rewriter) const override {
4341
FunctionType fnType = funcOp.getFunctionType();
4442

45-
4643
// Create the `tessera.define` op
4744
auto tesseraDefineOp = rewriter.create<tessera::DefineOp>(
4845
funcOp.getLoc(), funcOp.getName(), fnType);
@@ -88,21 +85,21 @@ class CallOpRewrite final : public OpRewritePattern<func::CallOp> {
8885
public:
8986
using OpRewritePattern<func::CallOp>::OpRewritePattern;
9087

91-
LogicalResult
92-
matchAndRewrite(func::CallOp callOp,
93-
PatternRewriter &rewriter) const override {
88+
LogicalResult matchAndRewrite(func::CallOp callOp,
89+
PatternRewriter &rewriter) const override {
9490

9591
auto calleeAttr = callOp.getCalleeAttr();
9692
Operation *moduleOp = callOp->getParentOfType<ModuleOp>();
9793
Operation *calleeOp = SymbolTable::lookupSymbolIn(moduleOp, calleeAttr);
98-
94+
9995
// Only convert if the callee is a Tessera DefineOp
10096
if (!isa<tessera::DefineOp>(calleeOp))
101-
return rewriter.notifyMatchFailure(callOp, "Callee is not a Tessera DefineOp");
102-
103-
rewriter.replaceOpWithNewOp<tessera::CallOp>(callOp, callOp.getResultTypes(),
104-
callOp.getOperands(),
105-
callOp->getAttrs());
97+
return rewriter.notifyMatchFailure(callOp,
98+
"Callee is not a Tessera DefineOp");
99+
100+
rewriter.replaceOpWithNewOp<tessera::CallOp>(
101+
callOp, callOp.getResultTypes(), callOp.getOperands(),
102+
callOp->getAttrs());
106103

107104
return success();
108105
}
@@ -113,17 +110,17 @@ class ReturnOpRewrite final : public OpRewritePattern<func::ReturnOp> {
113110
public:
114111
using OpRewritePattern<func::ReturnOp>::OpRewritePattern;
115112

116-
LogicalResult
117-
matchAndRewrite(func::ReturnOp returnOp,
118-
PatternRewriter &rewriter) const override {
113+
LogicalResult matchAndRewrite(func::ReturnOp returnOp,
114+
PatternRewriter &rewriter) const override {
119115
Operation *parent = returnOp->getParentOp();
120-
116+
121117
// Only convert if the function is a Tessera DefineOp
122118
if (!isa<tessera::DefineOp>(parent))
123-
return rewriter.notifyMatchFailure(returnOp, "Parent is not a Tessera DefineOp");
124-
119+
return rewriter.notifyMatchFailure(returnOp,
120+
"Parent is not a Tessera DefineOp");
121+
125122
rewriter.replaceOpWithNewOp<tessera::ReturnOp>(returnOp,
126-
returnOp.getOperands());
123+
returnOp.getOperands());
127124
return success();
128125
}
129126
};
@@ -136,7 +133,7 @@ class ReturnOpRewrite final : public OpRewritePattern<func::ReturnOp> {
136133
namespace mlir::enzyme::tessera {
137134

138135
struct FuncToTesseraPass
139-
: public PassWrapper<FuncToTesseraPass, OperationPass<ModuleOp>> {
136+
: public PassWrapper<FuncToTesseraPass, OperationPass<ModuleOp>> {
140137

141138
StringRef getArgument() const final { return "func-to-tessera"; }
142139
StringRef getDescription() const final {

src/enzyme_ad/jax/Passes/Tessera/TesseraToFunc.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//===----------------------------------------------------------------------===//
22
//
3-
// This file implements patterns to convert operations in the Tessera dialect to
3+
// This file implements patterns to convert operations in the Tessera dialect to
44
// operations in the Func dialect.
55
//
66
//===----------------------------------------------------------------------===//
@@ -22,9 +22,7 @@
2222
using namespace mlir;
2323
using namespace mlir::enzyme;
2424

25-
namespace {
26-
} // namespace
27-
25+
namespace {} // namespace
2826

2927
//===----------------------------------------------------------------------===//
3028
// Rewrite Patterns
@@ -37,12 +35,10 @@ class DefineOpRewrite final : public OpRewritePattern<tessera::DefineOp> {
3735
public:
3836
using OpRewritePattern<tessera::DefineOp>::OpRewritePattern;
3937

40-
LogicalResult
41-
matchAndRewrite(tessera::DefineOp defineOp,
42-
PatternRewriter &rewriter) const override {
38+
LogicalResult matchAndRewrite(tessera::DefineOp defineOp,
39+
PatternRewriter &rewriter) const override {
4340
FunctionType fnType = defineOp.getFunctionType();
4441

45-
4642
// Create the `func.func` op
4743
auto funcOp = rewriter.create<func::FuncOp>(defineOp.getLoc(),
4844
defineOp.getName(), fnType);
@@ -72,13 +68,12 @@ class CallOpRewrite final : public OpRewritePattern<tessera::CallOp> {
7268
public:
7369
using OpRewritePattern<tessera::CallOp>::OpRewritePattern;
7470

75-
LogicalResult
76-
matchAndRewrite(tessera::CallOp callOp,
77-
PatternRewriter &rewriter) const override {
71+
LogicalResult matchAndRewrite(tessera::CallOp callOp,
72+
PatternRewriter &rewriter) const override {
7873

7974
rewriter.replaceOpWithNewOp<func::CallOp>(callOp, callOp.getResultTypes(),
80-
callOp.getOperands(),
81-
callOp->getAttrs());
75+
callOp.getOperands(),
76+
callOp->getAttrs());
8277

8378
return success();
8479
}
@@ -89,12 +84,11 @@ class ReturnOpRewrite final : public OpRewritePattern<tessera::ReturnOp> {
8984
public:
9085
using OpRewritePattern<tessera::ReturnOp>::OpRewritePattern;
9186

92-
LogicalResult
93-
matchAndRewrite(tessera::ReturnOp returnOp,
94-
PatternRewriter &rewriter) const override {
87+
LogicalResult matchAndRewrite(tessera::ReturnOp returnOp,
88+
PatternRewriter &rewriter) const override {
9589

9690
rewriter.replaceOpWithNewOp<func::ReturnOp>(returnOp,
97-
returnOp.getOperands());
91+
returnOp.getOperands());
9892
return success();
9993
}
10094
};
@@ -107,7 +101,7 @@ class ReturnOpRewrite final : public OpRewritePattern<tessera::ReturnOp> {
107101
namespace mlir::enzyme::tessera {
108102

109103
struct TesseraToFuncPass
110-
: public PassWrapper<TesseraToFuncPass, OperationPass<ModuleOp>> {
104+
: public PassWrapper<TesseraToFuncPass, OperationPass<ModuleOp>> {
111105

112106
StringRef getArgument() const final { return "tessera-to-func"; }
113107
StringRef getDescription() const final {

0 commit comments

Comments
 (0)