-
Notifications
You must be signed in to change notification settings - Fork 26
Add ReverseConstProp pattern to fold reverse of constant tensors #1655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| // Splat constants are already handled by NoopReverse - the reverse | ||
| // of a splat is the same splat, so we can skip them here. | ||
| if (inp.isSplat()) | ||
| return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot if it is a splat we can just replace the original op with the operand (since a reverse of a splat is the same splat). Add this functionality, and a test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in commit b009964. Updated ReverseConstProp to handle splat constants by replacing the reverse op with the operand, and added a test case for this behavior.
|
hey this actually kinda worked! cc @avik-pal if you can help me do a quick review of these to see if I missed something obvious |
| patterns.add<IotaSimplify, BroadcastInDimSimplify, ConcatConstProp, | ||
| DynamicUpdateSliceConstProp, PadSimplify, ScatterConstFold>( | ||
| max_constant_expansion, context, PatternBenefit(65000)); | ||
| DynamicUpdateSliceConstProp, PadSimplify, ScatterConstFold, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in commit 992788b. Added ApplyReverseConstPropPatterns to TransformOps.td following the same pattern as other parameterized constant propagation patterns.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot add this pattern in primitives.py along with the other const prop patterns
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in commit 577ea72. Added reverse_const_prop({max_constant_threshold}) to the constant propagation patterns section in primitives.py.
test/lit_tests/reverseconstant.mlir
Outdated
| // Test 5: Reverse a splat constant - should replace with operand | ||
| module { | ||
| func.func @main() -> tensor<4xf64> { | ||
| %cst = stablehlo.constant dense<5.0> : tensor<4xf64> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change the test to be on a huge constant, that is clearly above any max threshold -- which should always have the optimization apply
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot change the test to be on a huge constant, that is clearly above any max threshold -- which should always have the optimization apply
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in commit 6df09d5. Updated the splat constant test to use a tensor with 1,000,000 elements, which is clearly above any max_constant_expansion threshold.
|
hm is there a reason why ci isn't running here? [cc @giordano ] |
|
it looks like its not on a fork, which sure |
577ea72 to
5aad9db
Compare
|
It was waiting on the queue, apparently the condition Enzyme-JAX/.github/workflows/build.yml Line 38 in 31f39c5
|
|
Also, for some reason I can't cancel the old job: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/19659114147/job/56385671649. Apparently CI jobs triggered by Copilot can't be cancelled at all? |
|
@copilot fix the CI errors |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
EnzymeJAX Benchmarks
Details
| Benchmark suite | Current: e977cfb | Previous: fadcc11 | Ratio |
|---|---|---|---|
scatter_sum / JaX / cpu / Primal |
0.000004400513600012345 s |
0.000004860522899980424 s |
0.91 |
scatter_sum / JaXPipe / cpu / Primal |
0.000004369506700004422 s |
0.000004766752000432461 s |
0.92 |
scatter_sum / JaX / tpu / Primal |
0.0001556710854999 s |
0.0001507932107 s |
1.03 |
scatter_sum / JaXPipe / tpu / Primal |
0.0001546421919999 s |
0.0001417516048 s |
1.09 |
This comment was automatically generated by workflow using github-action-benchmark.
Co-authored-by: wsmoses <[email protected]>
Co-authored-by: wsmoses <[email protected]>
Co-authored-by: wsmoses <[email protected]>
Co-authored-by: avik-pal <[email protected]>
Co-authored-by: avik-pal <[email protected]>
…to fix linker error Co-authored-by: avik-pal <[email protected]>
0653d98 to
e977cfb
Compare
wsmoses
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ReverseConstProppattern to foldreverse(constant)->constantaddReverseConstPropfunction declaration/implementation andpopulatePatternsmethodChanges
Fixed CI linker error by adding:
addReverseConstPropfunction declaration inEnzymeHLOPatterns.haddReverseConstPropfunction implementation inEnzymeHLOOpt.cppApplyReverseConstPropPatterns::populatePatternsmethod inTransformOps.cppOriginal prompt
This section details on the original issue you should resolve
<issue_title>Reverse of constant -> constant</issue_title>
<issue_description>```
module @reactant_jac2 attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<4xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 1 : i32}) -> (tensor<4x2xf64>, tensor<4xf64>) attributes {enzymexla.memory_effects = []} {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<1x2xf64>
%cst_0 = stablehlo.constant dense<[[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
%cst_1 = stablehlo.constant dense<[[1.000000e+00, 0.000000e+00]]> : tensor<1x2xf64>
%cst_2 = stablehlo.constant dense<[1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]> : tensor<4xf64>
%cst_3 = stablehlo.constant dense<[0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00]> : tensor<4xf64>
%cst_4 = stablehlo.constant dense<[0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00]> : tensor<4xf64>
%cst_5 = stablehlo.constant dense<[0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00]> : tensor<4xf64>
%0 = stablehlo.reverse %cst_2, dims = [0] : tensor<4xf64>
%1 = stablehlo.reverse %arg0, dims = [0] : tensor<4xf64>
%2 = stablehlo.reverse %cst_3, dims = [0] : tensor<4xf64>
%3 = stablehlo.reverse %1, dims = [0] {enzymexla.finite = [#enzymexla]} : tensor<4xf64>
%4 = stablehlo.reverse %cst_4, dims = [0] : tensor<4xf64>
%5 = stablehlo.reverse %3, dims = [0] {enzymexla.finite = [#enzymexla]} : tensor<4xf64>
%6 = stablehlo.reverse %cst_5, dims = [0] : tensor<4xf64>
%7 = stablehlo.reverse %5, dims = [0] : tensor<4xf64>
%8 = stablehlo.slice %arg0 [0:2] : (tensor<4xf64>) -> tensor<2xf64>
%9 = stablehlo.cosine %8 : tensor<2xf64>
%10 = stablehlo.reshape %9 : (tensor<2xf64>) -> tensor<1x2xf64>
%11 = stablehlo.multiply %cst_1, %10 : tensor<1x2xf64>
%12 = stablehlo.slice %1 [0:2] : (tensor<4xf64>) -> tensor<2xf64>
%13 = stablehlo.reshape %12 : (tensor<2xf64>) -> tensor<1x2xf64>
%14 = stablehlo.cosine %13 : tensor<1x2xf64>
%15 = stablehlo.multiply %cst_0, %14 : tensor<1x2xf64>
%16 = stablehlo.slice %3 [0:2] : (tensor<4xf64>) -> tensor<2xf64>
%17 = stablehlo.reshape %16 : (tensor<2xf64>) -> tensor<1x2xf64>
%18 = stablehlo.cosine %17 {enzymexla.finite = [#enzymexla], enzymexla.no_nan = [#enzymexla]} : tensor<1x2xf64>
%19 = stablehlo.multiply %cst, %18 {enzymexla.no_nan = [#enzymexla]} : tensor<1x2xf64>
%20 = stablehlo.slice %5 [0:2] : (tensor<4xf64>) -> tensor<2xf64>
%21 = stablehlo.reshape %20 : (tensor<2xf64>) -> tensor<1x2xf64>
%22 = stabl...
💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.