Skip to content

Commit d7c66f3

Browse files
committed
Add raising skeleton
1 parent 47d57c1 commit d7c66f3

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===- EnzymeBatchToStableHLOPass.cpp ------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to print the MLIR module
10+
//===----------------------------------------------------------------------===//
11+
12+
#include "src/enzyme_ad/jax/Passes/Passes.h"
13+
#include "stablehlo/dialect/StablehloOps.h"
14+
#include "Enzyme/MLIR/Dialect/Dialect.h"
15+
#include "Enzyme/MLIR/Dialect/Ops.h"
16+
#include "src/enzyme_ad/jax/Utils.h"
17+
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/Transforms/DialectConversion.h"
20+
21+
namespace mlir {
22+
namespace enzyme {
23+
#define GEN_PASS_DEF_ENZYMEBATCHTOSTABLEHLOPASS
24+
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
25+
} // namespace enzyme
26+
} // namespace mlir
27+
28+
using namespace mlir;
29+
using namespace mlir::enzyme;
30+
using namespace enzyme;
31+
namespace {
32+
struct EnzymeBatchToStableHLOPass
33+
: public enzyme::impl::EnzymeBatchToStableHLOPassBase<
34+
EnzymeBatchToStableHLOPass> {
35+
void runOnOperation() override {
36+
MLIRContext *context = &getContext();
37+
RewritePatternSet patterns(context);
38+
ConversionTarget target(*context);
39+
target.addLegalDialect<stablehlo::StablehloDialect>();
40+
target.addLegalDialect<enzyme::EnzymeDialect>();
41+
target.addIllegalOp<enzyme::ConcatOp, enzyme::ExtractOp>();
42+
43+
if (failed(applyPartialConversion(getOperation(), target,
44+
std::move(patterns)))) {
45+
signalPassFailure();
46+
}
47+
};
48+
};
49+
} // namespace

src/enzyme_ad/jax/Passes/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,4 +1008,13 @@ def SCFCPUify : Pass<"cpuify"> {
10081008
Option<"method", "method", "std::string", /*default=*/"\"distribute\"", "Method of doing distribution">
10091009
];
10101010
}
1011+
1012+
def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> {
1013+
let summary = "Legalize batching specific enzyme ops to stablehlo dialect";
1014+
let dependentDialects = [
1015+
"stablehlo::StablehloDialect",
1016+
"enzyme::EnzymeDialect"
1017+
];
1018+
}
1019+
10111020
#endif

0 commit comments

Comments
 (0)