Skip to content

Conversation

Copy link
Contributor

Copilot AI commented Nov 25, 2025

  • Understand the existing code structure and patterns in EnzymeHLOOpt.cpp
  • Create a new ReverseConstProp pattern to fold reverse(constant) -> constant
  • Handle splat constants by replacing the op with its operand
  • Handle non-splat constants within max_constant_expansion limit
  • Add test cases for the new optimization including splat constants
  • Address PR feedback to handle splat constants directly in ReverseConstProp
  • Add pattern to transform dialect in TransformOps.td
  • Update splat constant test to use huge tensor size (1000000 elements) that is clearly above any max threshold
  • Add pattern to primitives.py along with other const prop patterns
  • Fix CI linker error by adding addReverseConstProp function declaration/implementation and populatePatterns method

Changes

Fixed CI linker error by adding:

  • addReverseConstProp function declaration in EnzymeHLOPatterns.h
  • addReverseConstProp function implementation in EnzymeHLOOpt.cpp
  • ApplyReverseConstPropPatterns::populatePatterns method in TransformOps.cpp
Original prompt

This section details on the original issue you should resolve

<issue_title>Reverse of constant -> constant</issue_title>
<issue_description>```
module @reactant_jac2 attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<4xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 1 : i32}) -> (tensor<4x2xf64>, tensor<4xf64>) attributes {enzymexla.memory_effects = []} {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<1x2xf64>
%cst_0 = stablehlo.constant dense<[[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
%cst_1 = stablehlo.constant dense<[[1.000000e+00, 0.000000e+00]]> : tensor<1x2xf64>
%cst_2 = stablehlo.constant dense<[1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]> : tensor<4xf64>
%cst_3 = stablehlo.constant dense<[0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00]> : tensor<4xf64>
%cst_4 = stablehlo.constant dense<[0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00]> : tensor<4xf64>
%cst_5 = stablehlo.constant dense<[0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00]> : tensor<4xf64>
%0 = stablehlo.reverse %cst_2, dims = [0] : tensor<4xf64>
%1 = stablehlo.reverse %arg0, dims = [0] : tensor<4xf64>
%2 = stablehlo.reverse %cst_3, dims = [0] : tensor<4xf64>
%3 = stablehlo.reverse %1, dims = [0] {enzymexla.finite = [#enzymexla]} : tensor<4xf64>
%4 = stablehlo.reverse %cst_4, dims = [0] : tensor<4xf64>
%5 = stablehlo.reverse %3, dims = [0] {enzymexla.finite = [#enzymexla]} : tensor<4xf64>
%6 = stablehlo.reverse %cst_5, dims = [0] : tensor<4xf64>
%7 = stablehlo.reverse %5, dims = [0] : tensor<4xf64>
%8 = stablehlo.slice %arg0 [0:2] : (tensor<4xf64>) -> tensor<2xf64>
%9 = stablehlo.cosine %8 : tensor<2xf64>
%10 = stablehlo.reshape %9 : (tensor<2xf64>) -> tensor<1x2xf64>
%11 = stablehlo.multiply %cst_1, %10 : tensor<1x2xf64>
%12 = stablehlo.slice %1 [0:2] : (tensor<4xf64>) -> tensor<2xf64>
%13 = stablehlo.reshape %12 : (tensor<2xf64>) -> tensor<1x2xf64>
%14 = stablehlo.cosine %13 : tensor<1x2xf64>
%15 = stablehlo.multiply %cst_0, %14 : tensor<1x2xf64>
%16 = stablehlo.slice %3 [0:2] : (tensor<4xf64>) -> tensor<2xf64>
%17 = stablehlo.reshape %16 : (tensor<2xf64>) -> tensor<1x2xf64>
%18 = stablehlo.cosine %17 {enzymexla.finite = [#enzymexla], enzymexla.no_nan = [#enzymexla]} : tensor<1x2xf64>
%19 = stablehlo.multiply %cst, %18 {enzymexla.no_nan = [#enzymexla]} : tensor<1x2xf64>
%20 = stablehlo.slice %5 [0:2] : (tensor<4xf64>) -> tensor<2xf64>
%21 = stablehlo.reshape %20 : (tensor<2xf64>) -> tensor<1x2xf64>
%22 = stabl...


💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.

Copilot AI changed the title [WIP] Fix reverse of constant to constant issue Add ReverseConstProp pattern to fold reverse of constant tensors Nov 25, 2025
Copilot AI requested a review from wsmoses November 25, 2025 05:22
// Splat constants are already handled by NoopReverse - the reverse
// of a splat is the same splat, so we can skip them here.
if (inp.isSplat())
return failure();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot if it is a splat we can just replace the original op with the operand (since a reverse of a splat is the same splat). Add this functionality, and a test

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit b009964. Updated ReverseConstProp to handle splat constants by replacing the reverse op with the operand, and added a test case for this behavior.

@wsmoses
Copy link
Member

wsmoses commented Nov 25, 2025

hey this actually kinda worked!

cc @avik-pal if you can help me do a quick review of these to see if I missed something obvious

patterns.add<IotaSimplify, BroadcastInDimSimplify, ConcatConstProp,
DynamicUpdateSliceConstProp, PadSimplify, ScatterConstFold>(
max_constant_expansion, context, PatternBenefit(65000));
DynamicUpdateSliceConstProp, PadSimplify, ScatterConstFold,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot you also need to add this to the list of patterns in the transform dialect, look at earlier PRs like #1653 for examples (see src/enzyme_ad/jax/TransformOps/TransformOps.td)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit 992788b. Added ApplyReverseConstPropPatterns to TransformOps.td following the same pattern as other parameterized constant propagation patterns.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot add this pattern in primitives.py along with the other const prop patterns

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit 577ea72. Added reverse_const_prop({max_constant_threshold}) to the constant propagation patterns section in primitives.py.

Copilot AI requested a review from wsmoses November 25, 2025 21:16
// Test 5: Reverse a splat constant - should replace with operand
module {
func.func @main() -> tensor<4xf64> {
%cst = stablehlo.constant dense<5.0> : tensor<4xf64>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change the test to be on a huge constant, that is clearly above any max threshold -- which should always have the optimization apply

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot change the test to be on a huge constant, that is clearly above any max threshold -- which should always have the optimization apply

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit 6df09d5. Updated the splat constant test to use a tensor with 1,000,000 elements, which is clearly above any max_constant_expansion threshold.

@wsmoses
Copy link
Member

wsmoses commented Nov 25, 2025

hm is there a reason why ci isn't running here? [cc @giordano ]

@wsmoses
Copy link
Member

wsmoses commented Nov 25, 2025

it looks like its not on a fork, which sure

@wsmoses wsmoses force-pushed the copilot/fix-reverse-of-constant branch from 577ea72 to 5aad9db Compare November 25, 2025 23:37
@giordano
Copy link
Member

It was waiting on the queue, apparently the condition

cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
is false for PRs opened by Copilot

@giordano
Copy link
Member

Also, for some reason I can't cancel the old job: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/19659114147/job/56385671649. Apparently CI jobs triggered by Copilot can't be cancelled at all?

@avik-pal
Copy link
Collaborator

@copilot fix the CI errors

ERROR: /__w/Enzyme-JAX/Enzyme-JAX/BUILD:35:10: Linking enzymexlamlir-opt failed: (Exit 1): clang failed: error executing CppLink command (from target //:enzymexlamlir-opt) /usr/lib/llvm-18/bin/clang @bazel-out/aarch64-opt/bin/enzymexlamlir-opt-2.params

Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging
ld.lld: error: undefined symbol: mlir::transform::ApplyReverseConstPropPatterns::populatePatterns(mlir::RewritePatternSet&)
>>> referenced by TransformOps.cpp
>>>               bazel-out/aarch64-opt/bin/src/enzyme_ad/jax/_objs/TransformOps/TransformOps.o:(mlir::transform::detail::PatternDescriptorOpInterfaceInterfaceTraits::Model<mlir::transform::ApplyReverseConstPropPatterns>::populatePatterns(mlir::transform::detail::PatternDescriptorOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::RewritePatternSet&))
>>> referenced by TransformOps.cpp
>>>               bazel-out/aarch64-opt/bin/src/enzyme_ad/jax/_objs/TransformOps/TransformOps.o:(mlir::transform::detail::PatternDescriptorOpInterfaceInterfaceTraits::Model<mlir::transform::ApplyReverseConstPropPatterns>::populatePatternsWithState(mlir::transform::detail::PatternDescriptorOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::RewritePatternSet&, mlir::transform::TransformState&))
clang: error: linker command failed with exit code 1 (use -v to see invocation)
Target //:enzymexlamlir-opt failed to build
Use --verbose_failures to see the command lines of failed build steps.

Copy link
Contributor Author

Copilot AI commented Nov 26, 2025

@copilot fix the CI errors

...

Fixed in commit 0653d98. Added addReverseConstProp function declaration in EnzymeHLOPatterns.h, its implementation in EnzymeHLOOpt.cpp, and the ApplyReverseConstPropPatterns::populatePatterns method in TransformOps.cpp.

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EnzymeJAX Benchmarks

Details
Benchmark suite Current: e977cfb Previous: fadcc11 Ratio
scatter_sum / JaX / cpu / Primal 0.000004400513600012345 s 0.000004860522899980424 s 0.91
scatter_sum / JaXPipe / cpu / Primal 0.000004369506700004422 s 0.000004766752000432461 s 0.92
scatter_sum / JaX / tpu / Primal 0.0001556710854999 s 0.0001507932107 s 1.03
scatter_sum / JaXPipe / tpu / Primal 0.0001546421919999 s 0.0001417516048 s 1.09

This comment was automatically generated by workflow using github-action-benchmark.

@wsmoses wsmoses force-pushed the copilot/fix-reverse-of-constant branch from 0653d98 to e977cfb Compare December 1, 2025 15:27
Copy link
Member

@wsmoses wsmoses left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks good to me,

@avik-pal @giordano are you okay with it?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Reverse of constant -> constant

4 participants