-
Notifications
You must be signed in to change notification settings - Fork 23
refactor: move broadcast derivative rule to tablegen #1552
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| ) | ||
| ], | ||
| ( | ||
| SelectIfActive $x, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
select if active shouldn't be required here
however in this case I'm surprised that the hloidentityop part failed, since this is what it should do (it also supports the reverse mode tablegen so we should do that separately as well)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will remove
this doesn't fix the batched ad issue, i am just moving the reverse mode cpp version to tablegen
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah all that readonlyidentityop does is say that for forward mode, redo the op with corresponding shadow indices [which should be equivalent to this here]
732648d to
fd83341
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
EnzymeJAX Benchmarks
| Benchmark suite | Current: b7ae6d3 | Previous: 5c0b45a | Ratio |
|---|---|---|---|
scatter_sum / JaX / cpu / Primal |
0.0000043282088998239484 s |
0.000004331737999746111 s |
1.00 |
scatter_sum / JaXPipe / cpu / Primal |
0.000004300268000224605 s |
0.000003905613000097219 s |
1.10 |
scatter_sum / JaX / tpu / Primal |
0.0001375585123998 s |
0.0001482329919002 s |
0.93 |
scatter_sum / JaXPipe / tpu / Primal |
0.0001402481363002 s |
0.0001482432529999 s |
0.95 |
This comment was automatically generated by workflow using github-action-benchmark.
fd83341 to
e10412f
Compare
e10412f to
b7ae6d3
Compare
| ) | ||
| ], | ||
| ( | ||
| BroadcastInDim (ResultTypeWithBatch), (Shadow $x), (getBroadcastDimensionsWithBatch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cam you explicitly add batch tests? In principle the tablegen interface should automatically do batching on top of so we should check it doesn't accidentally conflict
No description provided.