Fix: wrong AD inputs #99
Draft
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Relevant issue or PR
Tesseract-jax, in some cases, cannot properly deal with traced arrays that are not differentiable. This PR adds a new test that triggers the exact undesired behavior. Note that it works without JIT but fails when JIT is enabled. This is because there is no way to distinguish a traced array due to jitting and a traced array due to differentiability. Even when we call jax.lax.stop_gradient on an array, the tracer still remains due to jitting.
Description of changes
The attempt is to add an argument to the apply function where it marks specific arguments as static. This is, however, not successful yet. Possibly, we need to come up with another solution.
Testing done