Skip to content

Commit 710f238

Browse files
Fmt 1
1 parent 838368d commit 710f238

File tree

5 files changed

+35
-36
lines changed

5 files changed

+35
-36
lines changed

src/enzyme_ad/jax/Dialect/Tessera/Ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
#include "Dialect.h"
55
#include "Ops.h"
66
#include "mlir/IR/IRMapping.h"
7+
#include "mlir/Interfaces/CallInterfaces.h"
78
#include "mlir/Interfaces/FunctionImplementation.h"
89
#include "mlir/Interfaces/FunctionInterfaces.h"
9-
#include "mlir/Interfaces/CallInterfaces.h"
1010
#include "mlir/Interfaces/SideEffectInterfaces.h"
1111

1212
using namespace mlir;

src/enzyme_ad/jax/Dialect/Tessera/Ops.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#ifndef ENZYME_AD_JAX_TESSERA_OPS_H
22
#define ENZYME_AD_JAX_TESSERA_OPS_H
33

4-
#include "mlir/IR/OpDefinition.h"
54
#include "mlir/IR/BuiltinOps.h"
65
#include "mlir/IR/Dialect.h"
7-
#include "mlir/Interfaces/FunctionInterfaces.h"
6+
#include "mlir/IR/OpDefinition.h"
87
#include "mlir/Interfaces/CallInterfaces.h"
9-
#include "mlir/Interfaces/SideEffectInterfaces.h"
108
#include "mlir/Interfaces/ControlFlowInterfaces.h"
9+
#include "mlir/Interfaces/FunctionInterfaces.h"
10+
#include "mlir/Interfaces/SideEffectInterfaces.h"
1111

1212
#define GET_OP_CLASSES
1313
#include "src/enzyme_ad/jax/Dialect/Tessera/TesseraOps.h.inc"

src/enzyme_ad/jax/Passes/Tessera/FuncToTessera.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class FuncOpRewrite final : public OpRewritePattern<func::FuncOp> {
4646
// Create the `tessera.define` op
4747
auto tesseraDefineOp = rewriter.create<tessera::DefineOp>(
4848
funcOp.getLoc(), funcOp.getName(), fnType);
49-
49+
5050
// Copy over all attributes other than the function name and type.
5151
for (const auto &namedAttr : funcOp->getAttrs()) {
5252
if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
@@ -58,8 +58,7 @@ class FuncOpRewrite final : public OpRewritePattern<func::FuncOp> {
5858
if (!funcOp.isExternal()) {
5959
IRMapping mapper;
6060
funcOp.getBody().cloneInto(&tesseraDefineOp.getBody(),
61-
tesseraDefineOp.getBody().end(),
62-
mapper);
61+
tesseraDefineOp.getBody().end(), mapper);
6362

6463
// Now walk through the cloned operations and convert func.return to
6564
// tessera.return
@@ -149,23 +148,24 @@ struct FuncToTesseraPass
149148
}
150149

151150
void runOnOperation() override {
152-
MLIRContext *ctx = &getContext();
151+
MLIRContext *ctx = &getContext();
153152

154-
ConversionTarget target(*ctx);
155-
target.addLegalDialect<tessera::TesseraDialect>();
156-
target.addLegalDialect<BuiltinDialect>();
157-
target.addIllegalDialect<func::FuncDialect>();
153+
ConversionTarget target(*ctx);
154+
target.addLegalDialect<tessera::TesseraDialect>();
155+
target.addLegalDialect<BuiltinDialect>();
156+
target.addIllegalDialect<func::FuncDialect>();
158157

159-
RewritePatternSet patterns(ctx);
158+
RewritePatternSet patterns(ctx);
160159

161-
patterns.add<FuncOpRewrite, CallOpRewrite, ReturnOpRewrite>(ctx);
160+
patterns.add<FuncOpRewrite, CallOpRewrite, ReturnOpRewrite>(ctx);
162161

163-
if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
164-
signalPassFailure();
165-
}
162+
if (failed(
163+
applyFullConversion(getOperation(), target, std::move(patterns))))
164+
signalPassFailure();
165+
}
166166
};
167167

168168
std::unique_ptr<mlir::Pass> createFuncToTesseraPass() {
169169
return std::make_unique<FuncToTesseraPass>();
170170
}
171-
}
171+
} // namespace mlir::enzyme::tessera

src/enzyme_ad/jax/Passes/Tessera/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#ifndef TESSERA_PASSES_H
22
#define TESSERA_PASSES_H
33

4-
#include "src/enzyme_ad/jax/Dialect/Tessera/Ops.h"
54
#include "mlir/Pass/Pass.h"
5+
#include "src/enzyme_ad/jax/Dialect/Tessera/Ops.h"
66

77
namespace mlir {
88
namespace enzyme {

src/enzyme_ad/jax/Passes/Tessera/TesseraToFunc.cpp

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@
66
//===----------------------------------------------------------------------===//
77

88
#include "mlir/Bytecode/BytecodeOpInterface.h"
9-
#include "src/enzyme_ad/jax/Dialect/Tessera/Ops.h"
109
#include "mlir/Dialect/Func/IR/FuncOps.h"
11-
#include "mlir/IR/PatternMatch.h"
10+
#include "mlir/IR/BuiltinOps.h"
1211
#include "mlir/IR/IRMapping.h"
13-
#include "mlir/Transforms/DialectConversion.h"
14-
#include "mlir/Interfaces/FunctionInterfaces.h"
12+
#include "mlir/IR/PatternMatch.h"
1513
#include "mlir/Interfaces/CallInterfaces.h"
14+
#include "mlir/Interfaces/FunctionInterfaces.h"
1615
#include "mlir/Pass/Pass.h"
17-
#include "mlir/IR/BuiltinOps.h"
16+
#include "mlir/Transforms/DialectConversion.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1818
#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
19+
#include "src/enzyme_ad/jax/Dialect/Tessera/Ops.h"
1920
#include "src/enzyme_ad/jax/Passes/Tessera/Passes.h"
20-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2121

2222
using namespace mlir;
2323
using namespace mlir::enzyme;
@@ -57,9 +57,8 @@ class DefineOpRewrite final : public OpRewritePattern<tessera::DefineOp> {
5757
// Clone body of function
5858
if (!defineOp.isExternal()) {
5959
IRMapping mapper;
60-
defineOp.getBody().cloneInto(&funcOp.getBody(),
61-
funcOp.getBody().end(),
62-
mapper);
60+
defineOp.getBody().cloneInto(&funcOp.getBody(), funcOp.getBody().end(),
61+
mapper);
6362
}
6463

6564
rewriter.eraseOp(defineOp);
@@ -120,18 +119,18 @@ struct TesseraToFuncPass
120119
}
121120

122121
void runOnOperation() override {
123-
MLIRContext *ctx = &getContext();
124-
RewritePatternSet patterns(ctx);
122+
MLIRContext *ctx = &getContext();
123+
RewritePatternSet patterns(ctx);
125124

126-
patterns.add<DefineOpRewrite, CallOpRewrite, ReturnOpRewrite>(ctx);
125+
patterns.add<DefineOpRewrite, CallOpRewrite, ReturnOpRewrite>(ctx);
127126

128-
if (failed(applyPatternsAndFoldGreedily(getOperation(),
129-
std::move(patterns))))
130-
signalPassFailure();
131-
}
127+
if (failed(
128+
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
129+
signalPassFailure();
130+
}
132131
};
133132

134133
std::unique_ptr<mlir::Pass> createTesseraToFuncPass() {
135134
return std::make_unique<TesseraToFuncPass>();
136135
}
137-
}
136+
} // namespace mlir::enzyme::tessera

0 commit comments

Comments
 (0)