
XLA translation rule for primitive 'two_phase_op_lin' not found #24

@phinate


Hi fax team!

We're still using the two_phase method to implicitly differentiate through a maximum likelihood fit. Building on that, I recently tried adding a further root-finding step after the fit, using jax.lax.custom_root so that the output of that step is also differentiable.
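For context, a minimal sketch of the jax.lax.custom_root pattern described above, assuming a scalar residual. The functions newton_solve, scalar_tangent_solve and sqrt_via_root are hypothetical stand-ins, not the code from this issue: custom_root runs the Newton loop for the forward value but supplies the gradient by implicit differentiation of the residual, so the loop itself is never differentiated.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_solve(f, x0):
    # Hypothetical solver: a fixed number of Newton steps on the residual f.
    def step(_, x):
        return x - f(x) / jax.grad(f)(x)
    return lax.fori_loop(0, 20, step, x0)


def scalar_tangent_solve(g, y):
    # g is f linearized at the root; for a scalar, g(x) = g(1) * x.
    return y / g(1.0)


def sqrt_via_root(a):
    # Hypothetical residual whose root is sqrt(a).
    a = jnp.asarray(a, dtype=jnp.float32)
    f = lambda x: x ** 2 - a
    return lax.custom_root(f, jnp.ones_like(a), newton_solve, scalar_tangent_solve)


value = sqrt_via_root(4.0)            # approximately 2.0
slope = jax.grad(sqrt_via_root)(4.0)  # implicit derivative 1 / (2 * sqrt(a))
```

In this standalone form jax.grad goes through custom_root's implicit-differentiation rule, which is why the loop body never needs a derivative of its own.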

The whole forward computation evaluates without issue, but calling jax.grad on its result raises this error:
NotImplementedError: XLA translation rule for primitive 'two_phase_op_lin' not found

I'm not very familiar with the low-level details of jax or fax, but since the primitive name contains two_phase, I suspect the error originates on the fax side.

The code is somewhat involved right now, but if the error alone isn't informative enough, I can try to reduce it to a minimal reproduction.
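One way to start such a reduction, sketched under assumptions: take gradients of each stage on its own and of the composed pipeline, and record which call raises NotImplementedError. The helper grad_status and the stages fit_stage and root_stage below are hypothetical placeholders; in practice they would wrap the actual two_phase fit and the custom_root step.

```python
import jax
import jax.numpy as jnp


def grad_status(fn, x):
    """Return "ok" if jax.grad(fn)(x) runs, else the NotImplementedError message."""
    try:
        jax.grad(fn)(x)
        return "ok"
    except NotImplementedError as err:
        return f"NotImplementedError: {err}"


# Hypothetical placeholders for the real pipeline stages.
def fit_stage(theta):
    return jnp.tanh(theta)  # stands in for the two_phase fit


def root_stage(mle):
    return mle ** 2  # stands in for the custom_root step


stages = {
    "fit only": fit_stage,
    "root only": root_stage,
    "fit then root": lambda t: root_stage(fit_stage(t)),
}
report = {name: grad_status(fn, 0.5) for name, fn in stages.items()}
```

If only the composed entry fails, the problem lies in how the two stages interact under differentiation rather than in either stage alone.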

Thanks again for a great tool :)
