Skip to content

Commit d3ed39c

Browse files
EganJwsmoses
authored andcommitted
Distributed passes boilerplate
1 parent f4bb2c6 commit d3ed39c

File tree

4 files changed

+57
-0
lines changed

4 files changed

+57
-0
lines changed

src/enzyme_ad/jax/BUILD

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,31 @@ gentbl_cc_library(
584584
],
585585
)
586586

587+
td_library(
588+
name = "DistributedPassesTdFiles",
589+
srcs = [
590+
],
591+
deps = [
592+
"@llvm-project//mlir:PassBaseTdFiles",
593+
],
594+
)
595+
596+
gentbl_cc_library(
597+
name = "DistributedPassesIncGen",
598+
tbl_outs = [
599+
(
600+
[
601+
"-gen-pass-decls",
602+
"-name=distributed",
603+
],
604+
"Passes/Distributed/Passes.h.inc",
605+
),
606+
],
607+
tblgen = "@llvm-project//mlir:mlir-tblgen",
608+
td_file = "Passes/Distributed/Passes.td",
609+
deps = [":DistributedPassesTdFiles"],
610+
)
611+
587612
td_library(
588613
name = "TesseraDialectTdFiles",
589614
srcs = [
@@ -717,6 +742,7 @@ cc_library(
717742
srcs = glob([
718743
"Implementations/*.cpp",
719744
"Passes/*.cpp",
745+
"Passes/Distributed/*.cpp",
720746
"Dialect/*.cpp",
721747
"Dialect/Distributed/*.cpp",
722748
"Dialect/Tessera/*.cpp",
@@ -726,6 +752,7 @@ cc_library(
726752
hdrs = glob([
727753
"Implementations/*.h",
728754
"Passes/*.h",
755+
"Passes/Distributed/*.h",
729756
"Dialect/*.h",
730757
"Dialect/Distributed/*.h",
731758
"Dialect/Tessera/*.h",
@@ -746,6 +773,7 @@ cc_library(
746773
":DistributedDialectIncGen",
747774
":DistributedInterfacesIncGen",
748775
":DistributedOpsIncGen",
776+
":DistributedPassesIncGen",
749777
":DistributedTypesIncGen",
750778
":EnzymeHLOPatternsIncGen",
751779
":EnzymeXLAAttrsIncGen",
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#ifndef ENZYMEXLA_DISTRIBUTED_PASSES_H
2+
#define ENZYMEXLA_DISTRIBUTED_PASSES_H
3+
4+
namespace mlir::enzyme::distributed {
5+
6+
#define GEN_PASS_DECL
7+
#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc"
8+
#define GEN_PASS_REGISTRATION
9+
#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h.inc"
10+
11+
} // namespace mlir::enzyme::distributed
12+
13+
#endif // ENZYMEXLA_DISTRIBUTED_PASSES_H
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef ENZYMEXLA_DISTRIBUTED_PASSES
2+
#define ENZYMEXLA_DISTRIBUTED_PASSES
3+
4+
include "mlir/Pass/PassBase.td"
5+
6+
def EliminateConstantCommunicationPass : Pass<"eliminate-constant-communication"> {
7+
let summary = "Replaces communicated constants with local constants";
8+
let description = [{
9+
This pass identifies send instructions with constant operands and replaces
10+
the corresponding receive instructions with local constants.
11+
}];
12+
let dependentDialects = ["enzyme::distributed::DistributedDialect"];
13+
}
14+
15+
#endif // ENZYMEXLA_DISTRIBUTED_PASSES

src/enzyme_ad/jax/RegistryUtils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
8787

8888
#include "src/enzyme_ad/jax/Dialect/Ops.h"
89+
#include "src/enzyme_ad/jax/Passes/Distributed/Passes.h"
8990
#include "src/enzyme_ad/jax/Passes/Passes.h"
9091

9192
#include "src/enzyme_ad/jax/Dialect/Distributed/Dialect.h"

0 commit comments

Comments
 (0)