Skip to content

Conversation

@andrinr
Copy link
Contributor

@andrinr andrinr commented Nov 4, 2025

Relevant issue or PR

Tesseract-jax, in some cases, cannot properly deal with traced arrays that are not differentiable. This PR adds a new test that triggers the exact undesired behavior. Note that it works without JIT but fails when JIT is enabled. This is because there is no way to distinguish a traced array due to jitting and a traced array due to differentiability. Even when we call jax.lax.stop_gradient on an array, the tracer still remains due to jitting.

Description of changes

The attempt is to add an argument to the apply function where it marks specific arguments as static. This is, however, not successful yet. Possibly, we need to come up with another solution.

Testing done

@andrinr andrinr marked this pull request as draft November 4, 2025 13:42
@andrinr andrinr changed the title Fix wrong AD inputs Fix: wrong AD inputs Nov 4, 2025
@dionhaefner
Copy link
Contributor

@andrinr What is the issue here? Please add a minimal description on what this is attempting to fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants