Class with a jitted method: is the following solution good practice? #10598
-
Hello, I know that using […]. Well, I would like to submit the following simple code for your comments/critique: it seems to work, but more experienced JAX developers may be able to improve it. Note that in my real use case I have about ten "a" and "b" variables... The "a" variable is an archetype of a variable that does not change while a function like `f` is being processed, while "b" is an archetype of a variable that does change during processing. Now that you know the spirit/context, here is the snippet:
import jax
import jax.numpy as jnp
from jax import vmap, jit
from functools import partial

# `jax` itself is imported just before this snippet.
jax.config.update("jax_enable_x64", True)


class A():
    """Container for a constant input ``a`` and a mutable state ``b``.

    ``a`` is not expected to change while ``f`` runs; ``b`` is replaced
    between calls through ``set_b``.
    """

    def __init__(self, a: jnp.ndarray) -> None:
        self.a = a
        self.Init()

    def Init(self) -> None:
        """(Re)initialise the mutable state ``b`` to ``None``."""
        self.b = None

    def set_b(self, x):
        """Store ``x`` as the new state ``b``."""
        self.b = x

    # NOTE(review): ``static_argnums=(0,)`` makes ``self`` a static argument,
    # so jit caches the compiled function under ``hash(self)``. For a plain
    # class that hash is identity-based: each new instance gets its own
    # compilation, and reassigning ``self.a`` on an existing instance does NOT
    # trigger a retrace -- the value seen at the first call keeps being used.
    @partial(jit, static_argnums=(0,))
    def f(self, var: float) -> jnp.ndarray:
        """Return ``self.a * var`` (``self.a`` is read at trace time)."""
        return self.a * var


objA = A(jnp.array([2.0]))
print("1)", objA.a, objA.b)
b = objA.f(10.)
print("2)", b)
objA.set_b(b)
print("3)", objA.a, objA.b)
new_objA = A(jnp.array([3.0]))
print("4)", objA.a, objA.b)
print("5)", new_objA.a, new_objA.b)
new_b = new_objA.f(20.)
new_objA.set_b(new_b)
print("6)", objA.a, objA.b)
print("7)", new_objA.a, new_objA.b)
# This gives, as expected:
So now, 1-2-3, start commenting... Thanks.
Beta Was this translation helpful? Give feedback.
Replies: 8 comments 32 replies
-
I think the best practice is to make your class a pytree […]
Beta Was this translation helpful? Give feedback.
-
import jax.numpy as jnp
from jax import jit
from functools import partial


class A():
    """The question's pattern, with ``b`` now used inside ``f``.

    This snippet demonstrates the pitfall of marking ``self`` as static.
    """

    def __init__(self, a: jnp.ndarray):
        self.a = a
        self.b = jnp.zeros_like(a)

    def set_b(self, x):
        """Store ``x`` as the new state ``b``."""
        self.b = x

    # ``self`` is static: jit keys its cache on ``hash(self)`` (object identity
    # for this class) and reads ``self.a`` / ``self.b`` only while tracing.
    @partial(jit, static_argnums=(0,))
    def f(self, var: float) -> jnp.ndarray:
        # Python side effect: runs only when jit (re)traces, so it reveals
        # each compilation.
        print('compiling')
        b = self.a * var + self.b
        return b


objA = A(jnp.array([2.0]))
print("1)", objA.a, objA.b)
b = objA.f(10.)
print("2)", b)
objA.set_b(b)
print("3)", objA.a, objA.b)
# Same object -> same hash -> the cached trace (made with b == 0) is reused,
# so the updated ``b`` is ignored.
print("3.1)", objA.f(10.))  # expect 40.
new_objA = A(jnp.array([3.0]))
print("4)", objA.a, objA.b)
print("5)", new_objA.a, new_objA.b)
# A different instance hashes differently, so jit traces again.
new_b = new_objA.f(20.)  # expect not re-compiling
new_objA.set_b(new_b)
print("6)", objA.a, objA.b)
print("7)", new_objA.a, new_objA.b)
# 1) [2.] [0.]
# compiling
# 2) [20.]
# 3) [2.] [20.]
# 3.1) [20.]
# 4) [2.] [20.]
# 5) [3.] [0.]
# compiling
# 6) [2.] [20.]
# 7) [3.] [60.]
Beta Was this translation helpful? Give feedback.
-
# Marking `self` as static (as the question does) gives wrong results once an
# attribute changes. Here `A` is the class from the question:
obj = A(1)
print(obj.f(2))
# 2
obj.a = 2
print(obj.f(2))  # should print 4, but prints 2
# 2
# The issue is that static arguments to jit are evaluated based on their hash,
# but your object's hash does not take into account the value of `a`.
# However, making […]. For this reason, I would instead make your class a
# custom pytree and no longer mark the `self` argument as static:
from jax import jit, tree_util
import jax.numpy as jnp


class A():
    """The question's class, registered as a pytree so ``self`` is traced."""

    def __init__(self, a: jnp.ndarray) -> None:
        self.a = a
        self.Init()

    def Init(self) -> None:
        """(Re)initialise the mutable state ``b`` to ``None``."""
        self.b = None

    def set_b(self, x):
        """Store ``x`` as the new state ``b``."""
        self.b = x

    # Plain ``@jit``: ``self`` is flattened into its leaves (``_tree_flatten``),
    # so a new value of ``self.a`` is a new input, not a baked-in constant.
    @jit
    def f(self, var: float) -> jnp.ndarray:
        b = self.a * var
        return b

    def _tree_flatten(self):
        # You might also want to store self.b in either the first group
        # (if it's not hashable) or the second group (if it's hashable)
        return (self.a,), ()

    @classmethod
    def _tree_unflatten(cls, aux, children):
        # Rebuilding through __init__ resets ``b`` to None.
        return cls(*children)


tree_util.register_pytree_node(A, A._tree_flatten, A._tree_unflatten)

obj = A(1)
print(obj.f(2))
# 2
obj.a = 2
print(obj.f(2))
# 4
# Note that I ignored the value of `self.b` […]
Beta Was this translation helpful? Give feedback.
-
# Best practice IMO:
import jax.numpy as jnp
from jax import jit
from jax.tree_util import register_pytree_node
from functools import partial


class A(object):
    """Plain data holder; registered below as a pytree with children (a, b)."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    # classmethod is a descriptor, not a callable: it must be the outermost
    # decorator so it wraps the jitted function. The bound class arrives as
    # ``cls`` (argument 0), which jit treats as static; classes are hashable.
    @classmethod
    @partial(jit, static_argnums=(0,))
    def f(cls, obj: 'A', v):
        # Runs only while tracing, so it reveals each compilation.
        print('compiling')
        return obj.a * v + obj.b


# Both attributes are pytree children, so jit traces them as inputs.
register_pytree_node(
    A,
    lambda objA: ((objA.a, objA.b), None),
    lambda _, children: A(*children),
)

objA = A(jnp.array(1.0), jnp.array(5.0))
print(objA.f(objA, 10.0))
objA.b = 10.0
print(objA.f(objA, 10.0))
new_objA = A(jnp.array(6.0), jnp.array(5.0))
print(new_objA.f(new_objA, 10.0))
# compiling
# 15.0
# 20.0
# 65.0
Beta Was this translation helpful? Give feedback.
-
See https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods |
Beta Was this translation helpful? Give feedback.
-
For nearly all use cases, it is much easier to use a library that treats classes as PyTrees. I'd recommend Equinox (of which I am the author). This will allow you to use classes/[…]
Beta Was this translation helpful? Give feedback.
-
# Hi, following #1567 I have put together this snippet. Any ideas?
import jax
import jax.numpy as jnp
from jax import jit, tree_util

# The original used the private `jax._src.dispatch._xla_callable.cache_clear`;
# `jax.clear_caches()` is the public way to drop JAX's compilation caches.
clear_cache = jax.clear_caches


def u_print(idx: int, *args) -> int:
    """Print ``args`` prefixed with ``"<idx>):"`` and return ``idx + 1``."""
    print(f"{idx}):", *args)
    idx += 1
    return idx


########
class A():
    """State holder whose ``f`` returns a new ``A`` instead of mutating self."""

    def __init__(self, a, b=None):
        self.a = a
        self.b = b

    @jit
    def f(self, var):
        # Runs only while tracing, so it shows when jit (re)compiles.
        print("compile...")
        new_b = self.a * var
        return A(self.a, new_b)  # we return a new object


# Both attributes are pytree children; there is no static aux data.
tree_util.register_pytree_node(
    A,
    lambda x: ((x.a, x.b), None),
    lambda _, x: A(a=x[0], b=x[1]),
)

###############
# No time.sleep() is needed between steps: the "compile..." print happens
# synchronously during tracing, and printing an array waits for its value.
clear_cache()  # the original referenced `clear_cache` without calling it
idp = 0
objA = A(2.0)
idp = u_print(idp, objA.a, objA.b)  # 0
objA = objA.f(10.)
idp = u_print(idp, objA.a, objA.b)  # 1
objA = objA.f(11.)
idp = u_print(idp, objA.a, objA.b)  # 2
####
objA = A(3.0)
idp = u_print(idp, objA.a, objA.b)  # 3
objA = objA.f(20.)
idp = u_print(idp, objA.a, objA.b)  # 4
objA = objA.f(30.)
idp = u_print(idp, objA.a, objA.b)  # 5
objA.a = 30.  # allowed, but deliberate
objA = objA.f(30.)
idp = u_print(idp, objA.a, objA.b)  # 6
####
objA = A(3.0)
idp = u_print(idp, objA.a, objA.b)  # 7
objA = objA.f(20.)
idp = u_print(idp, objA.a, objA.b)  # 8
objA = objA.f(30.)
idp = u_print(idp, objA.a, objA.b)  # 9
print(">>>> try an array for var")
objA = objA.f(jnp.array([30.]))
idp = u_print(idp, objA.a, objA.b)  # 10
objA = objA.f(jnp.array([40.]))
idp = u_print(idp, objA.a, objA.b)  # 11
# objA.a = 40.
# objA = objA.f(jnp.array([40.]))
# idp = u_print(idp, objA.a, objA.b)  # 12
# ... leading to:
|
Beta Was this translation helpful? Give feedback.
-
# Hi again,
import jax
from jax import jit


def g(x):
    """Example helper passed to the jitted function below."""
    return x**2


class A():
    """Holds ``a`` and a helper ``g``; ``f`` stores its result in ``b``."""

    def __init__(self, a):
        print("Nouveau A")
        self.a = a
        self.g = g
        self.b = None

    def f(self, var):
        """Compute ``g(a) * var`` through the jitted ``_f`` and store it in ``b``."""
        kargs = {}
        kargs["var"] = var
        kargs["g"] = self.g
        self.b = _f(self.a, kargs)


# NOTE(review): every leaf of a jit argument must be a JAX type. ``kargs["g"]``
# is a Python function, so this call raises a TypeError before tracing.
# A working variant passes the helper separately as a static argument
# (functions are hashable), e.g. ``@partial(jit, static_argnames="g")`` with
# the signature ``_f(a, var, g)``.
@jit
def _f(a, kargs):
    print("compile...")
    var = kargs["var"]
    g = kargs["g"]
    res = g(a) * var
    return res


if __name__ == "__main__":
    # `u_print` is the helper defined in the previous snippet.
    jax.clear_caches()  # the original referenced `clear_cache` without calling it
    idp = 0
    objA = A(2.0)
    idp = u_print(idp, objA.a, objA.b)
    objA.f(10.)
# I get:
Is there a solution that uses a helper function? Thanks.
Beta Was this translation helpful? Give feedback.
See https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods