Skip to content

Commit e9c6e3c

Browse files
committed
Distributed passes boilerplate
1 parent 55998c5 commit e9c6e3c

File tree

4 files changed

+57
-0
lines changed

4 files changed

+57
-0
lines changed

src/enzyme_ad/jax/BUILD

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,31 @@ gentbl_cc_library(
584584
],
585585
)
586586

587+
td_library(
588+
name = "DistributedPassesTdFiles",
589+
srcs = [
590+
],
591+
deps = [
592+
"@llvm-project//mlir:PassBaseTdFiles",
593+
],
594+
)
595+
596+
gentbl_cc_library(
597+
name = "DistributedPassesIncGen",
598+
tbl_outs = [
599+
(
600+
[
601+
"-gen-pass-decls",
602+
"-name=distributed",
603+
],
604+
"Passes/Distributed/Passes.h.inc",
605+
),
606+
],
607+
tblgen = "@llvm-project//mlir:mlir-tblgen",
608+
td_file = "Passes/Distributed/Passes.td",
609+
deps = [":DistributedPassesTdFiles"],
610+
)
611+
587612
td_library(
588613
name = "TesseraDialectTdFiles",
589614
srcs = [
@@ -715,6 +740,7 @@ cc_library(
715740
srcs = glob([
716741
"Implementations/*.cpp",
717742
"Passes/*.cpp",
743+
"Passes/Distributed/*.cpp",
718744
"Dialect/*.cpp",
719745
"Dialect/Distributed/*.cpp",
720746
"Dialect/Tessera/*.cpp",
@@ -724,6 +750,7 @@ cc_library(
724750
hdrs = glob([
725751
"Implementations/*.h",
726752
"Passes/*.h",
753+
"Passes/Distributed/*.h",
727754
"Dialect/*.h",
728755
"Dialect/Distributed/*.h",
729756
"Dialect/Tessera/*.h",
@@ -744,6 +771,7 @@ cc_library(
744771
":DistributedDialectIncGen",
745772
":DistributedInterfacesIncGen",
746773
":DistributedOpsIncGen",
774+
":DistributedPassesIncGen",
747775
":DistributedTypesIncGen",
748776
":EnzymeHLOPatternsIncGen",
749777
":EnzymeXLAAttrsIncGen",
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#ifndef ENZYMEXLA_DISTRIBUTED_PASSES_H
2+
#define ENZYMEXLA_DISTRIBUTED_PASSES_H
3+
4+
namespace mlir::enzyme::distributed {
5+
6+
#define GEN_PASS_DECL
7+
#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc"
8+
#define GEN_PASS_REGISTRATION
9+
#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc"
10+
11+
} // namespace mlir::enzyme::distributed
12+
13+
#endif // ENZYMEXLA_DISTRIBUTED_PASSES_H
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef ENZYMEXLA_DISTRIBUTED_PASSES
2+
#define ENZYMEXLA_DISTRIBUTED_PASSES
3+
4+
include "mlir/Pass/PassBase.td"
5+
6+
def EliminateConstantCommunicationPass : Pass<"eliminate-constant-communication"> {
7+
let summary = "Replaces communicated constants with local constants";
8+
let description = [{
9+
This pass identifies send instructions with constant operands and replaces
10+
the corresponding receive instructions with local constants.
11+
}];
12+
let dependentDialects = ["enzyme::distributed::DistributedDialect"];
13+
}
14+
15+
#endif // ENZYMEXLA_DISTRIBUTED_PASSES

src/enzyme_ad/jax/RegistryUtils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
8787

8888
#include "src/enzyme_ad/jax/Dialect/Ops.h"
89+
#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h"
8990
#include "src/enzyme_ad/jax/Passes/Passes.h"
9091

9192
#include "src/enzyme_ad/jax/Dialect/Distributed/Dialect.h"

0 commit comments

Comments
 (0)