pure_callback for custom class instances #13418
-
Hello,

Is there a way to tell …? I am looking for something similar to using …

```python
import jax

# This works
class_instance = CustomClass(*arg, **kwargs)

@jax.jit
def func(x):
    return class_instance.do_computation(x)
```

I want to instantiate …

```python
# This does not work
@jax.jit
def func(x):
    class_instance = CustomClass(*arg, **kwargs)
    return class_instance.do_computation(x)
```

So my idea is to wrap the instantiation by … Is this the right approach for this situation? Any pointers to possible solutions or workarounds would be appreciated!
Replies: 3 comments 3 replies
-
My primary suggestion would be to modify … If this is not possible, then perhaps you could instead do a pure callback around a function that both creates and calls the class? I.e. something like this:

```python
import jax

def g(x):
    # Runs on the host via the callback, so the class is never traced by JAX
    class_instance = CustomClass(*arg, **kwargs)
    return class_instance.do_computation(x)

@jax.jit
def func(x):
    ...
    # result_shape_dtype describes g's output, e.g. a jax.ShapeDtypeStruct
    return jax.pure_callback(g, result_shape_dtype, x)
```

Without more details it's hard to guess what the best approach would be. What do you think?
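A self-contained sketch of this callback approach, for illustration only. It assumes `CustomClass` is ordinary Python/NumPy code that cannot run on JAX tracers; the class body, its `scale` parameter (standing in for `*arg, **kwargs`), and the output shape are hypothetical. The output is declared with `jax.ShapeDtypeStruct`, since JAX cannot trace `g` to infer it.

```python
import numpy as np
import jax
import jax.numpy as jnp


class CustomClass:
    """Hypothetical stand-in for the non-JAX class from the question.

    Its method uses plain NumPy, so it only ever sees concrete host
    arrays inside the callback, never JAX tracers.
    """

    def __init__(self, scale):
        self.scale = scale

    def do_computation(self, x):
        # Ordinary NumPy code, executed on the host.
        return np.asarray(x) * self.scale


def g(x):
    # Create and use the instance entirely on the host side.
    class_instance = CustomClass(scale=2.0)
    result = class_instance.do_computation(x)
    # The returned dtype must match the declared result_shape_dtype.
    return result.astype(x.dtype)


@jax.jit
def func(x):
    # Declare the callback's output shape/dtype, since JAX cannot trace g.
    result_shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(g, result_shape_dtype, x)


out = func(jnp.arange(3.0))
```

Two consequences of this design: `g` is opaque to JAX, so `func` cannot be differentiated through the callback without a custom JVP/VJP, and a new instance is built on every call, which may matter if construction is expensive.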
-
You might be able to try `equinox.filter_pure_callback`, which is equivalent to `jax.pure_callback` but handles arbitrary Python objects.
-
Sorry for not answering; in the end we went with another library and no longer had that problem. But I found @patrick-kidger's answer helpful, got it working on another instance, and marked it as the answer.