Skip to content

Commit c8efb46

Browse files
authored
feat: update AutoBatching passes to be CheckedRewrite (#1434)
1 parent 87f3bf4 commit c8efb46

File tree

5 files changed

+123
-94
lines changed

5 files changed

+123
-94
lines changed

src/enzyme_ad/jax/BUILD

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,15 @@ gentbl_cc_library(
629629
],
630630
)
631631

632+
cc_library(
633+
name = "CheckedRewrite",
634+
hdrs = ["CheckedRewrite.h"],
635+
deps = [
636+
"@llvm-project//mlir:FunctionInterfaces",
637+
"@llvm-project//mlir:IR",
638+
],
639+
)
640+
632641
cc_library(
633642
name = "XLADerivatives",
634643
srcs = glob([
@@ -659,6 +668,7 @@ cc_library(
659668
],
660669
visibility = ["//visibility:public"],
661670
deps = [
671+
":CheckedRewrite",
662672
":DistributedDialectIncGen",
663673
":DistributedOpsIncGen",
664674
":DistributedTypesIncGen",
@@ -787,8 +797,9 @@ cc_library(
787797
hdrs = glob([
788798
"Implementations/*.h",
789799
"Passes/*.h",
800+
]) + [
790801
"RegistryUtils.h",
791-
]),
802+
],
792803
deps = [
793804
":TransformOps",
794805
":XLADerivatives",

src/enzyme_ad/jax/CheckedRewrite.h

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#pragma once
2+
3+
#include "mlir/IR/PatternMatch.h"
4+
#include "mlir/Interfaces/FunctionInterfaces.h"
5+
6+
namespace mlir {
7+
namespace enzyme {
8+
9+
static constexpr StringRef kDisablePatternAttrName =
10+
"enzymexla.disable_hlo_opts";
11+
12+
static LogicalResult failIfDynamicShape(Operation *op,
13+
PatternRewriter &rewriter) {
14+
for (auto type : op->getResultTypes()) {
15+
auto rType = dyn_cast<RankedTensorType>(type);
16+
if (!rType || !rType.hasStaticShape())
17+
return rewriter.notifyMatchFailure(
18+
op, "unsupported dynamic shape for output.");
19+
}
20+
21+
for (auto type : op->getOperandTypes()) {
22+
auto rType = dyn_cast<RankedTensorType>(type);
23+
if (!rType || !rType.hasStaticShape())
24+
return rewriter.notifyMatchFailure(
25+
op, "unsupported dynamic shape for input.");
26+
}
27+
28+
return success();
29+
}
30+
31+
static LogicalResult failIfFuncOpInterfaceHasAttr(Operation *op,
32+
StringRef attrName,
33+
PatternRewriter &rewriter) {
34+
if (auto func = op->getParentOfType<FunctionOpInterface>()) {
35+
if (func->hasAttrOfType<UnitAttr>(attrName))
36+
return rewriter.notifyMatchFailure(op, "disabled by attribute.");
37+
}
38+
39+
return success();
40+
}
41+
42+
template <typename OpTy, typename Child>
43+
struct CheckedOpRewritePattern : public OpRewritePattern<OpTy> {
44+
using Base = OpRewritePattern<OpTy>;
45+
using Base::Base;
46+
47+
LogicalResult
48+
matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override final {
49+
LogicalResult res =
50+
failIfFuncOpInterfaceHasAttr(op, kDisablePatternAttrName, rewriter);
51+
if (res.failed())
52+
return res;
53+
54+
if (!((Child *)this)->supportsDynamicShapes()) {
55+
LogicalResult res = failIfDynamicShape(op, rewriter);
56+
if (res.failed())
57+
return res;
58+
}
59+
60+
return ((Child *)this)->matchAndRewriteImpl(op, rewriter);
61+
}
62+
63+
bool supportsDynamicShapes() const { return false; }
64+
};
65+
66+
template <template <typename> class TraitType, typename Child>
67+
struct CheckedOpTraitRewritePattern : public OpTraitRewritePattern<TraitType> {
68+
using Base = OpTraitRewritePattern<TraitType>;
69+
using Base::Base;
70+
71+
LogicalResult
72+
matchAndRewrite(Operation *op,
73+
PatternRewriter &rewriter) const override final {
74+
LogicalResult res =
75+
failIfFuncOpInterfaceHasAttr(op, kDisablePatternAttrName, rewriter);
76+
if (res.failed())
77+
return res;
78+
79+
if (!((Child *)this)->supportsDynamicShapes()) {
80+
auto res = failIfDynamicShape(op, rewriter);
81+
if (res.failed())
82+
return res;
83+
}
84+
85+
return ((Child *)this)->matchAndRewriteImpl(op, rewriter);
86+
}
87+
88+
bool supportsDynamicShapes() const { return false; }
89+
};
90+
91+
} // namespace enzyme
92+
} // namespace mlir

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,8 @@ constructAndExtractBatchOperands(PatternRewriter &rewriter,
261261
return std::make_tuple(operands, operandIndexMap);
262262
}
263263

264-
LogicalResult
265-
ConcatInsertDimToBatchBase::matchAndRewrite(stablehlo::ConcatenateOp concatOp,
266-
PatternRewriter &rewriter) const {
264+
LogicalResult ConcatInsertDimToBatchBase::matchAndRewriteImpl(
265+
stablehlo::ConcatenateOp concatOp, PatternRewriter &rewriter) const {
267266
if (concatOp.getNumOperands() <= 1)
268267
return failure();
269268

@@ -391,8 +390,8 @@ bool ConcatInsertDimToBatchBase::validBroadcastInDimOpInsertDimForBatching(
391390
}
392391

393392
LogicalResult
394-
SliceToBatchBase::matchAndRewrite(stablehlo::SliceOp sliceOp,
395-
PatternRewriter &rewriter) const {
393+
SliceToBatchBase::matchAndRewriteImpl(stablehlo::SliceOp sliceOp,
394+
PatternRewriter &rewriter) const {
396395
Value sliceInput = sliceOp.getOperand();
397396
// Find all slices of the same input that feed into equivalent operations
398397
SmallVector<SliceInfo> relatedSlices;

src/enzyme_ad/jax/Passes/AutoBatching.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "mlir/IR/PatternMatch.h"
4+
#include "src/enzyme_ad/jax/CheckedRewrite.h"
45
#include "src/enzyme_ad/jax/Utils.h"
56
#include "stablehlo/dialect/StablehloOps.h"
67
#include "llvm/ADT/SmallVector.h"
@@ -16,8 +17,11 @@ struct BatchOperandConstructionInfo {
1617
};
1718

1819
struct ConcatInsertDimToBatchBase
19-
: public mlir::OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
20-
using Base = mlir::OpRewritePattern<mlir::stablehlo::ConcatenateOp>;
20+
: public mlir::enzyme::CheckedOpRewritePattern<
21+
mlir::stablehlo::ConcatenateOp, ConcatInsertDimToBatchBase> {
22+
using Base =
23+
mlir::enzyme::CheckedOpRewritePattern<mlir::stablehlo::ConcatenateOp,
24+
ConcatInsertDimToBatchBase>;
2125
using Base::Base;
2226

2327
ConcatInsertDimToBatchBase(
@@ -26,8 +30,8 @@ struct ConcatInsertDimToBatchBase
2630
: Base(ctx, benefit), isValidTargetOp(isValidTargetOp) {}
2731

2832
llvm::LogicalResult
29-
matchAndRewrite(mlir::stablehlo::ConcatenateOp concatOp,
30-
mlir::PatternRewriter &rewriter) const override;
33+
matchAndRewriteImpl(mlir::stablehlo::ConcatenateOp concatOp,
34+
mlir::PatternRewriter &rewriter) const;
3135

3236
protected:
3337
std::function<mlir::Operation *(mlir::Operation *)> isValidTargetOp;
@@ -67,8 +71,10 @@ struct ConcatInsertDimElementwiseToBatch : public ConcatInsertDimToBatchBase {
6771
};
6872

6973
struct SliceToBatchBase
70-
: public mlir::OpRewritePattern<mlir::stablehlo::SliceOp> {
71-
using Base = mlir::OpRewritePattern<mlir::stablehlo::SliceOp>;
74+
: public mlir::enzyme::CheckedOpRewritePattern<mlir::stablehlo::SliceOp,
75+
SliceToBatchBase> {
76+
using Base = mlir::enzyme::CheckedOpRewritePattern<mlir::stablehlo::SliceOp,
77+
SliceToBatchBase>;
7278
using Base::Base;
7379

7480
SliceToBatchBase(
@@ -77,8 +83,8 @@ struct SliceToBatchBase
7783
: Base(ctx, benefit), isValidTargetOp(isValidTargetOp) {}
7884

7985
llvm::LogicalResult
80-
matchAndRewrite(mlir::stablehlo::SliceOp sliceOp,
81-
mlir::PatternRewriter &rewriter) const override;
86+
matchAndRewriteImpl(mlir::stablehlo::SliceOp sliceOp,
87+
mlir::PatternRewriter &rewriter) const;
8288

8389
private:
8490
struct SliceInfo {

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 1 addition & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Pass/PassManager.h"
2626
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2727
#include "shardy/dialect/sdy/ir/utils.h"
28+
#include "src/enzyme_ad/jax/CheckedRewrite.h"
2829
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
2930
#include "src/enzyme_ad/jax/Dialect/Ops.h"
3031
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h"
@@ -306,86 +307,6 @@ class StaticSlice {
306307
}
307308
};
308309

309-
LogicalResult failIfDynamicShape(Operation *op, PatternRewriter &rewriter) {
310-
for (auto type : op->getResultTypes()) {
311-
auto rType = dyn_cast<RankedTensorType>(type);
312-
if (!rType || !rType.hasStaticShape())
313-
return rewriter.notifyMatchFailure(
314-
op, "unsupported dynamic shape for output.");
315-
}
316-
317-
for (auto type : op->getOperandTypes()) {
318-
auto rType = dyn_cast<RankedTensorType>(type);
319-
if (!rType || !rType.hasStaticShape())
320-
return rewriter.notifyMatchFailure(
321-
op, "unsupported dynamic shape for input.");
322-
}
323-
324-
return success();
325-
}
326-
327-
LogicalResult failIfFuncOpInterfaceHasAttr(Operation *op, StringRef attrName,
328-
PatternRewriter &rewriter) {
329-
if (auto func = op->getParentOfType<FunctionOpInterface>()) {
330-
if (func->hasAttrOfType<UnitAttr>(attrName))
331-
return rewriter.notifyMatchFailure(op, "disabled by attribute.");
332-
}
333-
334-
return success();
335-
}
336-
337-
static constexpr StringRef kDisablePatternAttrName =
338-
"enzymexla.disable_hlo_opts";
339-
340-
template <typename OpTy, typename Child>
341-
struct CheckedOpRewritePattern : public OpRewritePattern<OpTy> {
342-
using Base = OpRewritePattern<OpTy>;
343-
using Base::Base;
344-
345-
LogicalResult
346-
matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override final {
347-
LogicalResult res =
348-
failIfFuncOpInterfaceHasAttr(op, kDisablePatternAttrName, rewriter);
349-
if (res.failed())
350-
return res;
351-
352-
if (!((Child *)this)->supportsDynamicShapes()) {
353-
LogicalResult res = failIfDynamicShape(op, rewriter);
354-
if (res.failed())
355-
return res;
356-
}
357-
358-
return ((Child *)this)->matchAndRewriteImpl(op, rewriter);
359-
}
360-
361-
bool supportsDynamicShapes() { return false; }
362-
};
363-
364-
template <template <typename> class TraitType, typename Child>
365-
struct CheckedOpTraitRewritePattern : public OpTraitRewritePattern<TraitType> {
366-
using Base = OpTraitRewritePattern<TraitType>;
367-
using Base::Base;
368-
369-
LogicalResult
370-
matchAndRewrite(Operation *op,
371-
PatternRewriter &rewriter) const override final {
372-
LogicalResult res =
373-
failIfFuncOpInterfaceHasAttr(op, kDisablePatternAttrName, rewriter);
374-
if (res.failed())
375-
return res;
376-
377-
if (!((Child *)this)->supportsDynamicShapes()) {
378-
auto res = failIfDynamicShape(op, rewriter);
379-
if (res.failed())
380-
return res;
381-
}
382-
383-
return ((Child *)this)->matchAndRewriteImpl(op, rewriter);
384-
}
385-
386-
bool supportsDynamicShapes() { return false; }
387-
};
388-
389310
template <typename OpTy, typename Child>
390311
struct NoNanCheckedOpRewritePattern
391312
: public CheckedOpRewritePattern<OpTy, Child> {

0 commit comments

Comments
 (0)