Hi fax team!
We're still using the two_phase method to implicitly differentiate through a maximum likelihood fit. Building on that, I recently added a second root-finding step downstream of the fit. It uses jax.lax.custom_root so that its output stays differentiable.
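For context on the pattern described above, here is a minimal, self-contained sketch of differentiating through a root-finding step with jax.lax.custom_root. It is not the actual analysis code and does not involve fax's two_phase. The helper names newton_solve, scalar_tangent_solve and cube_root are hypothetical, and the toy problem (the root of x**3 - a) is an assumption chosen only for illustration. custom_root computes gradients via the implicit function theorem, so jax.grad never differentiates through the Newton loop itself.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical toy example; not the code from this issue.

def newton_solve(f, x0, num_iters=50):
    # Plain Newton iterations for a scalar f. custom_root does not
    # differentiate through this loop, so any solver would do here.
    def step(x, _):
        return x - f(x) / jax.grad(f)(x), None
    x, _ = lax.scan(step, x0, None, length=num_iters)
    return x

def scalar_tangent_solve(g, y):
    # g is the linearization of f at the root; for scalars g(x) = g(1.0) * x,
    # so solving g(x) = y is a single division.
    return y / g(1.0)

def cube_root(a):
    # The root of f(x) = x**3 - a is a ** (1/3); gradients w.r.t. `a`
    # flow through custom_root's implicit-differentiation rule.
    f = lambda x: x ** 3 - a
    x0 = jnp.array(1.0, dtype=jnp.float32)
    return lax.custom_root(f, x0, newton_solve, scalar_tangent_solve)

print(cube_root(8.0))            # forward pass
print(jax.grad(cube_root)(8.0))  # implicit gradient
```

For a = 8.0 the forward pass should give approximately 2.0 and the gradient approximately 1/12, matching d/da a^(1/3) = (1/3) a^(-2/3).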
The whole forward computation evaluates without any issue, but when I call jax.grad on the result, I get this error:
NotImplementedError: XLA translation rule for primitive 'two_phase_op_lin' not found
I'm not very familiar with the low-level details of JAX or fax, but since the primitive name contains two_phase, I suspect the error is coming from the fax side.
The code is fairly involved at the moment, but if the error alone isn't enough to go on, I can try to pare it down to a minimal reproduction.
Thanks again for a great tool :)