Skip to content

Commit 08bda01

Browse files
Added code to register passes
1 parent a611a1d commit 08bda01

File tree

5 files changed

+56
-5
lines changed

5 files changed

+56
-5
lines changed

src/enzyme_ad/jax/BUILD

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,31 @@ gentbl_cc_library(
629629
],
630630
)
631631

632+
td_library(
633+
name = "TesseraPassesTdFiles",
634+
srcs = [
635+
],
636+
deps = [
637+
"@llvm-project//mlir:PassBaseTdFiles",
638+
],
639+
)
640+
641+
gentbl_cc_library(
642+
name = "TesseraPassesIncGen",
643+
tbl_outs = [
644+
(
645+
[
646+
"-gen-pass-decls",
647+
"-name=tessera",
648+
],
649+
"Passes/Tessera/Passes.h.inc",
650+
),
651+
],
652+
tblgen = "@llvm-project//mlir:mlir-tblgen",
653+
td_file = "Passes/Tessera/Passes.td",
654+
deps = [":TesseraPassesTdFiles"],
655+
)
656+
632657
cc_library(
633658
name = "CheckedRewrite",
634659
hdrs = ["CheckedRewrite.h"],
@@ -643,6 +668,7 @@ cc_library(
643668
srcs = glob([
644669
"Implementations/*.cpp",
645670
"Passes/*.cpp",
671+
"Passes/Tessera/*.cpp",
646672
"Dialect/*.cpp",
647673
"Dialect/Distributed/*.cpp",
648674
"Dialect/Tessera/*.cpp",
@@ -652,6 +678,7 @@ cc_library(
652678
hdrs = glob([
653679
"Implementations/*.h",
654680
"Passes/*.h",
681+
"Passes/Tessera/*.h",
655682
"Dialect/*.h",
656683
"Dialect/Distributed/*.h",
657684
"Dialect/Tessera/*.h",
@@ -683,6 +710,7 @@ cc_library(
683710
":StablehloOptPatternsIncGen",
684711
":TesseraDialectIncGen",
685712
":TesseraOpsIncGen",
713+
":TesseraPassesIncGen",
686714
":chlo-derivatives",
687715
":mhlo-derivatives",
688716
":stablehlo-derivatives",

src/enzyme_ad/jax/Passes/Tessera/FuncToTessera.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
//===----------------------------------------------------------------------===//
22
//
3-
// This file implements patterns to convert the Func dialect to the Tessera
4-
// dialect.
3+
// This file implements patterns to convert operations in the Func dialect to
4+
// operations in the Tessera dialect.
55
//
66
//===----------------------------------------------------------------------===//
77

88
#include "mlir/Dialect/Func/IR/FuncOps.h"
99
#include "mlir/IR/BuiltinOps.h"
1010
#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
11+
#include "src/enzyme_ad/jax/Passes/Tessera/Passes.h"
1112
#include "src/enzyme_ad/jax/Passes/Passes.h"
1213

1314
using namespace mlir;
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef TESSERA_PASSES_H
2+
#define TESSERA_PASSES_H
3+
4+
#include "mlir/Pass/Pass.h"
5+
6+
namespace mlir {
7+
namespace tessera {
8+
9+
#define GEN_PASS_DECLS
10+
#include "Tessera/Passes/Tessera/Passes.h.inc"
11+
12+
#define GEN_PASS_REGISTRATION
13+
#include "Tessera/Passes/Tessera/Passes.h.inc"
14+
15+
} // namespace tessera
16+
} // namespace mlir
17+
18+
#endif

src/enzyme_ad/jax/Passes/Tessera/TesseraToFunc.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
//===----------------------------------------------------------------------===//
22
//
3-
// This file implements patterns to convert the Tessera dialect to the Func
4-
// dialect.
3+
// This file implements patterns to convert operations in the Tessera dialect to
4+
// operations in the Func dialect.
55
//
66
//===----------------------------------------------------------------------===//
77

88
#include "mlir/Dialect/Func/IR/FuncOps.h"
9+
#include "mlir/IR/BuiltinOps.h"
910
#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
11+
#include "src/enzyme_ad/jax/Passes/Tessera/Passes.h"
1012
#include "src/enzyme_ad/jax/Passes/Passes.h"
1113

1214
using namespace mlir;
@@ -103,7 +105,7 @@ class ReturnOpRewrite final : public OpRewritePattern<tessera::ReturnOp> {
103105
} // namespace
104106

105107
//===----------------------------------------------------------------------===//
106-
// Pass to convert Func operations into Tessera operations
108+
// Pass to convert Tessera operations into Func operations
107109
//===----------------------------------------------------------------------===//
108110

109111
struct TesseraToFuncPass

src/enzyme_ad/jax/RegistryUtils.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787

8888
#include "src/enzyme_ad/jax/Dialect/Ops.h"
8989
#include "src/enzyme_ad/jax/Passes/Passes.h"
90+
#include "src/enzyme_ad/jax/Passes/Tessera/Passes.h"
9091

9192
#include "src/enzyme_ad/jax/Dialect/Distributed/Dialect.h"
9293
#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
@@ -294,6 +295,7 @@ void registerInterfaces(mlir::DialectRegistry &registry) {
294295
void initializePasses() {
295296
registerenzymePasses();
296297
enzyme::registerenzymexlaPasses();
298+
enzyme::tessera::registerTesseraPasses();
297299

298300
// Register the standard passes we want.
299301
mlir::registerCSEPass();

0 commit comments

Comments
 (0)