When a JAXOp ends up in the final graph of a non-JAX backend, we may want to rewrite it for efficiency. For example, we could rewrite Blockwise(JAXOp) into a single JAXOp whose inner function is vectorized.
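A minimal JAX sketch of why this rewrite pays off. It assumes the JAXOp wraps a plain JAX function (called `inner_fn` here, a hypothetical name), and it only compares the per-item loop that Blockwise conceptually performs with a single `jax.vmap` call. The PyTensor rewrite that would swap one for the other is not shown.

```python
import jax
import jax.numpy as jnp


# Hypothetical inner function of a JAXOp: maps a core vector to a scalar.
def inner_fn(x):
    return jnp.sum(jnp.sin(x) ** 2)


batch = jnp.arange(12.0).reshape(3, 4)  # 3 batch entries, core shape (4,)

# Conceptually what Blockwise(JAXOp) does: call the core function once per batch entry.
looped = jnp.stack([inner_fn(row) for row in batch])

# What a rewritten JAXOp could call instead: one vectorized, jitted function.
vectorized_fn = jax.jit(jax.vmap(inner_fn))
vectorized = vectorized_fn(batch)
```

Both paths produce the same values. The vmapped version leaves batching to JAX in a single compiled call instead of going through the backend once per batch entry.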
If the graph needs both the Op's output and its gradient, we could rewrite them into a single Op that uses value_and_grad under the hood, so the forward pass is computed only once.
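A minimal sketch of the fused alternative, assuming the wrapped function returns a scalar (which `jax.value_and_grad` requires). `loss` is a hypothetical stand-in for the JAXOp's inner function. The two separate jitted functions correspond to the Op and its gradient Op. The fused one returns both from one call.

```python
import jax
import jax.numpy as jnp


# Hypothetical scalar-valued inner function of a JAXOp.
def loss(x):
    return jnp.sum(x ** 2 + 3.0 * x)


# Unfused: the value Op and the gradient Op each evaluate the function separately.
value_only = jax.jit(loss)
grad_only = jax.jit(jax.grad(loss))

# Fused: a single call returns (value, gradient) and shares the forward pass.
fused = jax.jit(jax.value_and_grad(loss))

x = jnp.array([1.0, 2.0, 3.0])
val, g = fused(x)  # val is 32.0, g is 2*x + 3 = [5., 7., 9.]
```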
Similarly, if only the output shape is needed, we could rewrite into an Op whose inner function computes just the shape. This last rewrite is only relevant if the original Op does not remain in the graph; otherwise the full computation runs anyway.
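A minimal sketch of a shape-only evaluation in JAX, assuming the JAXOp's inner function (hypothetically `inner_fn`) can be traced abstractly. `jax.eval_shape` traces the function with abstract `ShapeDtypeStruct` inputs and returns the output shape and dtype without computing any values. How a PyTensor Op would expose this is left open.

```python
import jax
import jax.numpy as jnp


# Hypothetical inner function of a JAXOp.
def inner_fn(x, y):
    return jnp.outer(x, y)


# Abstract inputs: only shapes and dtypes, no data is allocated.
x_spec = jax.ShapeDtypeStruct((5,), jnp.float32)
y_spec = jax.ShapeDtypeStruct((7,), jnp.float32)

# Traces inner_fn abstractly and returns a ShapeDtypeStruct for the output.
out_spec = jax.eval_shape(inner_fn, x_spec, y_spec)
```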