Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13571,6 +13571,48 @@ struct NoopReverse final
}
};

// Fold reverse of a constant tensor to a new constant tensor.
// Splats fold unconditionally; non-splat constants are materialized only
// while the element count stays under the max_constant_expansion budget.
struct ReverseConstProp final
    : CheckedOpRewritePattern<stablehlo::ReverseOp, ReverseConstProp> {
  using CheckedOpRewritePattern::CheckedOpRewritePattern;

  size_t max_constant_expansion;
  ReverseConstProp(size_t max_constant_expansion, MLIRContext *context,
                   PatternBenefit benefit = 1,
                   ArrayRef<StringRef> generatedNames = {})
      : CheckedOpRewritePattern(context, benefit, generatedNames),
        max_constant_expansion(max_constant_expansion) {}

  LogicalResult matchAndRewriteImpl(stablehlo::ReverseOp op,
                                    PatternRewriter &rewriter) const {
    // Only fire when the operand is a compile-time constant.
    DenseElementsAttr input;
    if (!matchPattern(op.getOperand(), m_Constant(&input)))
      return failure();

    // A splat is invariant under reversal along any dimension, so simply
    // forward the operand. No size check is needed: nothing new is
    // materialized on this path.
    if (input.isSplat()) {
      rewriter.replaceOp(op, op.getOperand());
      return success();
    }

    // Refuse to materialize a reversed copy larger than the configured
    // expansion budget.
    size_t numElems = 1;
    for (auto dim : op.getType().getShape())
      numElems *= dim;
    if (numElems >= max_constant_expansion)
      return failure();

    // Evaluate the reversal with the stablehlo reference interpreter and
    // replace the op with the resulting constant.
    auto reversed = fromTensor(stablehlo::reverseOp(
        stablehlo::constantOp(input), stablehlo::Axes(op.getDimensions()),
        op.getType()));
    rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, op.getType(),
                                                       reversed);
    return success();
  }
};

/// Converts gather ops to slice ops in case we have a single set of constant
/// indices.
struct GatherOpCanon final
Expand Down Expand Up @@ -26272,6 +26314,13 @@ void mlir::transform::addDynamicUpdateSliceConstProp(
benefit);
}

// Registers the ReverseConstProp pattern, which folds stablehlo.reverse of a
// constant operand into a constant (bounded by maxConstantExpansion elements
// for non-splat inputs).
void mlir::transform::addReverseConstProp(RewritePatternSet &patterns,
                                          int64_t maxConstantExpansion,
                                          MLIRContext &context,
                                          PatternBenefit benefit) {
  patterns.insert<ReverseConstProp>(maxConstantExpansion, &context, benefit);
}

void mlir::transform::addWhileSimplify(RewritePatternSet &patterns,
bool hoistAll, MLIRContext &context,
PatternBenefit benefit) {
Expand Down Expand Up @@ -26557,8 +26606,9 @@ struct EnzymeHLOOptPass
PatternBenefit(65000));

patterns.add<IotaSimplify, BroadcastInDimSimplify, ConcatConstProp,
DynamicUpdateSliceConstProp, PadSimplify, ScatterConstFold>(
max_constant_expansion, context, PatternBenefit(65000));
DynamicUpdateSliceConstProp, PadSimplify, ScatterConstFold,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot you also need to add this to the list of patterns in the transform dialect, look at earlier PRs like #1653 for examples (see src/enzyme_ad/jax/TransformOps/TransformOps.td)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit 992788b. Added ApplyReverseConstPropPatterns to TransformOps.td following the same pattern as other parameterized constant propagation patterns.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot add this pattern in primitives.py along with the other const prop patterns

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit 577ea72. Added reverse_const_prop({max_constant_threshold}) to the constant propagation patterns section in primitives.py.

ReverseConstProp>(max_constant_expansion, context,
PatternBenefit(65000));

patterns.add<
ConvertConcat, DynamicUpdateToConcat, SliceOfDynamicUpdate,
Expand Down
3 changes: 3 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ void addDynamicUpdateSliceConstProp(RewritePatternSet &patterns,
int64_t maxConstantExpansion,
MLIRContext &context,
PatternBenefit benefit);
// Adds the ReverseConstProp pattern (constant folding of stablehlo.reverse,
// capped at maxConstantExpansion elements for non-splat constants).
void addReverseConstProp(RewritePatternSet &patterns,
                         int64_t maxConstantExpansion, MLIRContext &context,
                         PatternBenefit benefit);
void addWhileSimplify(RewritePatternSet &patterns, bool hoist_all,
MLIRContext &context, PatternBenefit benefit);
void addWhileLICM(RewritePatternSet &patterns, bool hoist_all,
Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ void ApplyDynamicUpdateSliceConstPropPatterns::populatePatterns(
addDynamicUpdateSliceConstProp(patterns, getParameter(), *getContext(),
PatternBenefit(getBenefit().value_or(1)));
}
// Transform-dialect hook: wires the reverse_const_prop op's integer parameter
// through to addReverseConstProp as the max expansion limit.
void ApplyReverseConstPropPatterns::populatePatterns(
    RewritePatternSet &patterns) {
  addReverseConstProp(patterns, getParameter(), *getContext(),
                      PatternBenefit(getBenefit().value_or(1)));
}
void ApplyBroadcastInDimSimplifyPatterns::populatePatterns(
RewritePatternSet &patterns) {
addBroadcastInDimSimplify(patterns, getParameter(), *getContext(),
Expand Down
14 changes: 14 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,20 @@ def ApplyDynamicUpdateSliceConstPropPatterns : EnzymeHLOParameterizedPatternOp<
}
}];
}
// Transform op exposing ReverseConstProp; `parameter` is the
// max_constant_expansion element limit for non-splat folds.
def ApplyReverseConstPropPatterns : EnzymeHLOParameterizedPatternOp<
  "reverse_const_prop"> {
  let arguments = (ins OptionalAttr<I64Attr>:$benefit, I64Attr:$parameter);
  let assemblyFormat = "attr-dict";
  // TODO: this should be made better searchable.
  let extraClassDeclaration = [{
    // Default attribute set used for pattern discovery (1024-element limit).
    ::llvm::SmallVector<::mlir::DictionaryAttr>
    static getPossibleAttrCombinations(::mlir::Builder &builder) {
      return {builder.getDictionaryAttr(
          builder.getNamedAttr("parameter",
                               builder.getI64IntegerAttr(1024)))};
    }
  }];
}
def ApplyConcatFusePatterns : EnzymeHLOPatternOp<
"concat_fuse"> {
let patterns = ["ConcatFuse"];
Expand Down
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def optimization_passes(
"const_prop_through_barrier<16>",
f"concat_const_prop<1>({max_constant_threshold})",
f"dynamic_update_slice_const_prop({max_constant_threshold})",
f"reverse_const_prop({max_constant_threshold})",
"clamp_const_prop",
]

Expand Down
72 changes: 72 additions & 0 deletions test/lit_tests/reverseconstant.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt)" %s | FileCheck %s

// Lit tests for the ReverseConstProp pattern: stablehlo.reverse of a constant
// folds to a constant (tests 1-4), and reverse of a splat is replaced by the
// operand regardless of size (test 5).

// Test 1: Reverse a non-splat constant along one dimension
module {
  func.func @main() -> tensor<4xf64> {
    %cst = stablehlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
    %0 = stablehlo.reverse %cst, dims = [0] : tensor<4xf64>
    return %0 : tensor<4xf64>
  }
}

// CHECK: func.func @main() -> tensor<4xf64> {
// CHECK-NEXT{LITERAL}: %cst = stablehlo.constant dense<[4.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]> : tensor<4xf64>
// CHECK-NEXT: return %cst : tensor<4xf64>
// CHECK-NEXT: }

// Test 2: Reverse a 2D constant along both dimensions
module {
  func.func @main() -> tensor<2x3xf64> {
    %cst = stablehlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>
    %0 = stablehlo.reverse %cst, dims = [0, 1] : tensor<2x3xf64>
    return %0 : tensor<2x3xf64>
  }
}

// CHECK: func.func @main() -> tensor<2x3xf64> {
// CHECK-NEXT{LITERAL}: %cst = stablehlo.constant dense<[[6.000000e+00, 5.000000e+00, 4.000000e+00], [3.000000e+00, 2.000000e+00, 1.000000e+00]]> : tensor<2x3xf64>
// CHECK-NEXT: return %cst : tensor<2x3xf64>
// CHECK-NEXT: }

// Test 3: Reverse a 2D constant along just the first dimension
module {
  func.func @main() -> tensor<2x3xf64> {
    %cst = stablehlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>
    %0 = stablehlo.reverse %cst, dims = [0] : tensor<2x3xf64>
    return %0 : tensor<2x3xf64>
  }
}

// CHECK: func.func @main() -> tensor<2x3xf64> {
// CHECK-NEXT{LITERAL}: %cst = stablehlo.constant dense<[[4.000000e+00, 5.000000e+00, 6.000000e+00], [1.000000e+00, 2.000000e+00, 3.000000e+00]]> : tensor<2x3xf64>
// CHECK-NEXT: return %cst : tensor<2x3xf64>
// CHECK-NEXT: }

// Test 4: Reverse a 2D constant along just the second dimension
module {
  func.func @main() -> tensor<2x3xf64> {
    %cst = stablehlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64>
    %0 = stablehlo.reverse %cst, dims = [1] : tensor<2x3xf64>
    return %0 : tensor<2x3xf64>
  }
}

// CHECK: func.func @main() -> tensor<2x3xf64> {
// CHECK-NEXT{LITERAL}: %cst = stablehlo.constant dense<[[3.000000e+00, 2.000000e+00, 1.000000e+00], [6.000000e+00, 5.000000e+00, 4.000000e+00]]> : tensor<2x3xf64>
// CHECK-NEXT: return %cst : tensor<2x3xf64>
// CHECK-NEXT: }

// Test 5: Reverse a splat constant - should replace with operand
// Use a huge tensor shape that is clearly above any max threshold
// (exercises the splat fast-path, which skips the size-limit check)
module {
  func.func @main() -> tensor<1000000xf64> {
    %cst = stablehlo.constant dense<5.0> : tensor<1000000xf64>
    %0 = stablehlo.reverse %cst, dims = [0] : tensor<1000000xf64>
    return %0 : tensor<1000000xf64>
  }
}

// CHECK: func.func @main() -> tensor<1000000xf64> {
// CHECK-NEXT: %cst = stablehlo.constant dense<5.000000e+00> : tensor<1000000xf64>
// CHECK-NEXT: return %cst : tensor<1000000xf64>
// CHECK-NEXT: }
Loading