-
Hi, for reasons, I need to symbolically evaluate JAX functions. E.g., say I have a JAX function
and return a string that represents the result of the computation
In general, I want to support all JAX functions, and also support symbolic weights. I've experimented with rolling my own "shadow" functions, but this won't scale, because I would need to re-implement the JAX API. Instead, I'd like to support strings and string operations at a deeper level in the JAX system, so that all the existing functionality can operate on symbolic expressions. Any pointers much appreciated!
-
Hi - thanks for the question. This is pretty far afield from what JAX is designed to do, but you may be able to take advantage of JAX's transformation framework to define a custom transform that would do this kind of thing. I don't have any pointers to relevant documentation or examples, because any such solution would rely on a lot of internal APIs. Edit: here's a project with a similar goal that you might use and/or get ideas from: https://git.informatik.uni-hamburg.de/0moin/google-research/-/tree/master/jax2tex. I don't think the code is maintained, though.
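For readers wondering what such a custom interpreter might look like, here is a minimal sketch. It is not the approach of any project mentioned in this thread: it traces the function with `jax.make_jaxpr` and walks the resulting equations, building a string per variable. The names `SYMBOLIC_RULES` and `symbolic_eval` are hypothetical, only a handful of primitives are covered, all inputs are assumed to be float32 scalars, and the jaxpr attribute names (`eqns`, `invars`, `outvars`, `constvars`, `primitive.name`) are assumed to match recent JAX releases.

```python
import jax
import jax.numpy as jnp

# Hypothetical rule table: primitive name -> function that formats its
# (already symbolic) string arguments. Extend as needed.
SYMBOLIC_RULES = {
    "add": lambda a, b: f"({a} + {b})",
    "sub": lambda a, b: f"({a} - {b})",
    "mul": lambda a, b: f"({a} * {b})",
    "neg": lambda a: f"(-{a})",
    "sin": lambda a: f"sin({a})",
}


def symbolic_eval(fn, *arg_names):
    """Trace `fn` on scalar placeholders and return its outputs as strings."""
    # Assumption: every argument is a float32 scalar.
    placeholders = [jnp.zeros(()) for _ in arg_names]
    closed = jax.make_jaxpr(fn)(*placeholders)
    jaxpr = closed.jaxpr

    env = {}  # maps jaxpr variables to symbolic strings

    def read(v):
        # Literals carry their value in `.val`; ordinary variables do not.
        if hasattr(v, "val"):
            return str(v.val)
        return env[v]

    for var, name in zip(jaxpr.invars, arg_names):
        env[var] = name
    for var, const in zip(jaxpr.constvars, closed.consts):
        env[var] = str(const)

    for eqn in jaxpr.eqns:
        args = [read(v) for v in eqn.invars]
        rule = SYMBOLIC_RULES.get(eqn.primitive.name)
        if rule is None:
            # Fallback for primitives without a rule: generic call syntax.
            expr = f"{eqn.primitive.name}({', '.join(args)})"
        else:
            expr = rule(*args)
        for i, out in enumerate(eqn.outvars):
            env[out] = expr if len(eqn.outvars) == 1 else f"{expr}[{i}]"

    return [read(v) for v in jaxpr.outvars]


def f(x, w):
    return jnp.sin(x) * w + x


print(symbolic_eval(f, "x", "w"))
```

Because the strings are built from the traced jaxpr rather than from shadow copies of `jax.numpy` functions, only the primitives need rules, which is a much smaller surface than the full JAX API.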
-
Thanks for the reply! I had seen I plan, therefore, to generate symbolic representations by interpreting the |
-
If you're looking for a way to pass strings across JAX API boundaries, then:

```python
from equinox.internal import str2jax
from jax import jit

hello = str2jax("hello")
world = str2jax("world")


@jit
def concat_strings(x, y):
    return str2jax(str(x) + str(y))


concat_strings(hello, world)
```

All the string manipulations necessarily happen during trace time, i.e. they're "static". If you're looking for something deeper than that, then I don't think that'd really be supported by any JAX code that is lowered to XLA (i.e. jit'd in the normal way).
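To illustrate what "static" means here without relying on Equinox, the following sketch uses plain `jax.jit` with `static_argnames`. It is an illustration of trace-time behaviour only, not of how `str2jax` works internally; `apply_op` and `trace_log` are hypothetical names. The Python body, including every string operation, runs only while tracing, and a new string value triggers a retrace.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_log = []  # records each time the Python body actually runs


@partial(jax.jit, static_argnames="op")
def apply_op(x, op):
    # String handling happens here, in Python, at trace time only.
    trace_log.append(op)
    if op == "double":
        return x * 2
    return x + 1


x = jnp.arange(3)
doubled = apply_op(x, op="double")        # traces with op="double"
doubled_again = apply_op(x, op="double")  # cache hit, body does not rerun
incremented = apply_op(x, op="inc")       # new static value, retraces

print(trace_log)
```

The compiled computation never sees the string itself; it only sees whichever branch the string selected during tracing.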
-
In case this is relevant to others, I implemented a symbolic evaluator for JAX expressions in this repo: The relevant Python files are: Note that I've implemented symbolic versions only for the primitive JAX operations that I need, but it's not difficult to add more, so the overall structure might be useful to others. I construct neural nets that learn discrete boolean functions, and then use the symbolic evaluator to extract the boolean function that's been learnt. See the (draft) paper: Thanks for writing JAX!