How to write a batching rule for tridiagonal solver? #10339
-
Hi! I've been trying to write a tridiagonal solver in JAX which I need to be able to forward differentiate through. #6843 implemented a primitive for the solver itself, with no forward differentiation rule. I tried using custom_linear_solve to fix this issue:
but when I call (for some trivial 4x4 tridiagonal matrix A and vector b):
and from here I am running into the error
Replies: 2 comments 5 replies
-
Since |
-
The reason that other batch rules in the file look like the one you created is that those primitives are closed under batching: that is, the primitive itself can handle batched inputs, so to compute the batched results you simply ensure the inputs are laid out in the expected way and then call the original primitive.

`tridiagonal_solve`, on the other hand, is not closed under batching: that is, you cannot use the primitive directly to compute batched results, so the batch rule is going to have to do something other than call back into the primitive.

Long term, the best way to support batched tridiagonal solves would be to make the primitive closed under batching. This would involve changes at the XLA or maybe C++ level. In the meantime, you could probably implement a working batch rule by calling back into the Python implementation, `_tridiagonal_solve_jax`, with

```python
from functools import partial
from jax import lax, vmap
from jax.interpreters import batching

def _tridiagonal_solve_batch_rule(vals_in, dims_in, **kw):
  return vmap(partial(_tridiagonal_solve_jax, **kw), dims_in)(*vals_in), (0,)
batching.primitive_batchers[lax.linalg.tridiagonal_solve_p] = _tridiagonal_solve_batch_rule
```

Note that the performance of this will likely not be that good, since it's calling the JAX implementation rather than the native implementation. An alternative would be to write a batch rule that explicitly loops over the batch dimensions.

Does that answer your question?
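To illustrate why falling back to a pure-JAX implementation sidesteps the batching problem, here is a minimal sketch of a tridiagonal solve written only with public JAX operations, so that `jax.vmap` and `jax.jvp` apply to it without any custom rules. The name `thomas_solve` is hypothetical and is not JAX's internal implementation; the sketch assumes a vector right-hand side and a system that needs no pivoting (for example, a diagonally dominant matrix).

```python
import jax
import jax.numpy as jnp
from jax import lax


def thomas_solve(dl, d, du, b):
    """Hypothetical helper: solve a tridiagonal system (Thomas algorithm).

    dl, d, du: lower, main and upper diagonals, each of shape (m,).
               dl[0] and du[-1] have no effect on the result.
    b:         right-hand side of shape (m,).
    Assumes no pivoting is required.
    """
    zero = jnp.zeros((), dtype=b.dtype)

    # Forward sweep: eliminate the lower diagonal row by row.
    def forward(carry, row):
        c_prev, r_prev = carry
        a, diag, c, rhs = row
        denom = diag - a * c_prev
        c_new = c / denom
        r_new = (rhs - a * r_prev) / denom
        return (c_new, r_new), (c_new, r_new)

    _, (c_prime, r_prime) = lax.scan(forward, (zero, zero), (dl, d, du, b))

    # Back substitution: x[i] = r'[i] - c'[i] * x[i + 1], from the last row up.
    def backward(x_next, row):
        c, r = row
        x = r - c * x_next
        return x, x

    _, x = lax.scan(backward, zero, (c_prime, r_prime), reverse=True)
    return x


# A small 4x4 example system.
dl = jnp.array([0.0, 1.0, 1.0, 1.0])
d = jnp.array([4.0, 4.0, 4.0, 4.0])
du = jnp.array([1.0, 1.0, 1.0, 0.0])
b = jnp.array([1.0, 2.0, 3.0, 4.0])

# Dense matrix, built only to check the results.
A = jnp.diag(d) + jnp.diag(dl[1:], k=-1) + jnp.diag(du[:-1], k=1)

x = thomas_solve(dl, d, du, b)

# Batching over right-hand sides works because only traceable JAX ops are used.
bs = jnp.stack([b, 2.0 * b, 3.0 * b])
xs = jax.vmap(thomas_solve, in_axes=(None, None, None, 0))(dl, d, du, bs)

# Forward-mode differentiation with respect to the right-hand side.
x_out, x_dot = jax.jvp(
    lambda rhs: thomas_solve(dl, d, du, rhs), (b,), (jnp.ones_like(b),)
)
```

Because the solve is expressed with `lax.scan`, it gets batching and forward-mode rules for free, at the cost of not using the native kernel. That is the same trade-off described in the reply above.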