How to speed XLA code...a challenge #11324
-
@YouJiacheng I would like to understand the last point.
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class A:
    def __init__(self,
                 kernel,  # Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], float, float], jnp.ndarray]
                 kernel_hat,  # dict
                 ):
        self.kernel = kernel
        self.kernel_hat = kernel_hat

    def tree_flatten(self):
        # kernel_hat (a dict of arrays) is traced data, so it goes into children
        children = (self.kernel_hat,)
        # the kernel callable is static metadata, so it goes into aux_data
        aux_data = {'kernel': self.kernel}
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        kernel = aux_data['kernel']
        return cls(kernel, *children)

I imagine that you invite me to manage
but then, in my A_factory, is it correct to do a = A(Partial(kernel, params=kernel_hat), <other args>)? At least it seems to work...
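For readers following along, here is a minimal sketch of why wrapping a kernel with `jax.tree_util.Partial` can work under `jit`: `Partial` is itself a pytree whose bound arguments are leaves, while the wrapped function is treated as static. The `kernel` function, the `"scale"` key and the `apply` helper below are hypothetical stand-ins, not the code from this thread.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def kernel(x, params):
    # Hypothetical kernel: a scaled quadratic.
    return params["scale"] * x**2


# Hypothetical parameter dict standing in for kernel_hat.
kernel_hat = {"scale": jnp.asarray(3.0)}

# Bind the parameters; the result is a callable that is also a pytree.
bound_kernel = Partial(kernel, params=kernel_hat)


@jax.jit
def apply(f, x):
    # f is passed as a regular (pytree) argument, not as a static argument.
    return f(x)


result = apply(bound_kernel, jnp.asarray(2.0))

# The bound parameters show up as leaves; the function itself does not.
leaves = jax.tree_util.tree_leaves(bound_kernel)
```

Because the parameters are leaves, changing their values (with the same shapes and dtypes) should not trigger a recompilation of `apply`.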
-
Many thanks to @YouJiacheng. Even if I have not (yet) mastered PyTrees (i.e. it is still complicated to see what should go into children and what into aux_data), it is doing very well. We can close this thread.
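On the children versus aux_data question, a rough rule of thumb is that arrays JAX should trace or differentiate go into children, while static, hashable configuration (such as a plain Python function) goes into aux_data. The sketch below illustrates this under assumptions: `Model`, `quadratic` and `evaluate` are hypothetical names, and the kernel is stored directly as aux_data rather than inside a dict so that it stays hashable.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


def quadratic(x, params):
    # Hypothetical kernel used only for this illustration.
    return params["scale"] * x**2


@register_pytree_node_class
class Model:
    def __init__(self, kernel, kernel_hat):
        self.kernel = kernel          # static: a plain Python function
        self.kernel_hat = kernel_hat  # dynamic: a dict of arrays

    def tree_flatten(self):
        children = (self.kernel_hat,)  # traced and differentiable
        aux_data = self.kernel         # static; should be hashable and comparable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data, *children)


@jax.jit
def evaluate(model, x):
    return model.kernel(x, model.kernel_hat)


model = Model(quadratic, {"scale": jnp.asarray(3.0)})
value = evaluate(model, jnp.asarray(2.0))

# Gradients flow only through children; the result is again a Model.
grads = jax.grad(lambda m: evaluate(m, jnp.asarray(2.0)))(model)
```

Here `grads.kernel_hat["scale"]` holds the derivative with respect to the scale parameter, while the kernel function is carried through unchanged as metadata.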
-
Hi @YouJiacheng
import jax
from jax import jit
import jax.numpy as jnp
import numpy as np
from jax.tree_util import Partial, register_pytree_node_class

@jit
def K(x):
    return x**2

@jit
def complicated_a_func(a_obj, x):
    print("compiled...")  # runs only while tracing, i.e. on each (re)compilation
    return a_obj.val * a_obj.f(x)

@register_pytree_node_class
class A():
    def __init__(self, val, f):
        if not (type(val) is object or val is None or isinstance(val, A)):
            val = jnp.asarray(val)
        self.val = val
        self.f = f

    def __repr__(self):
        return "A(val={})".format(self.val)

    def tree_flatten(self):
        children = (self.val, self.f)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    # class attribute, so a.func(p) calls complicated_a_func(a, p)
    func = complicated_a_func

class A_factory():
    done = False
    _ws = {}

    @classmethod
    def make(cls, n_data=5_000):
        # build the list of A objects once and cache it
        if not A_factory.done:
            A_factory._ws = {'data': [A(val, Partial(K)) for val in np.random.uniform(0., 10, size=n_data)]}
            A_factory.done = True
        return A_factory._ws

list_of_A = A_factory.make()['data']

def f50(p):
    list_of_A = A_factory.make()['data']
    data = jnp.array([a.func(p) for a in list_of_A[:5_000]])
    return data

def f50b(p):
    list_of_A = A_factory.make()['data']
    # jax.tree_util.tree_map is spelled jax.tree_map in older JAX releases
    data = jnp.array(jax.tree_util.tree_map(lambda a: a.func(p), list_of_A[:5_000], is_leaf=lambda x: isinstance(x, A)))
    return data

The point is that
but here it is just 5000 evaluations of "a * x**2", with "a" a random number in [0, 10] and x a real number (= 10.). My question is: what mechanism delays the computation so much?
-
Your test for point 1 is specially optimized by JAX: If all items are not If there is any
-
Hello @YouJiacheng, I found discussion #5322 interesting and got an idea:
def pytrees_stack(pytrees, axis=0):
    # stack the matching leaves of all pytrees along a new axis
    results = jax.tree_util.tree_map(
        lambda *values: jnp.stack(values, axis=axis), *pytrees)
    return results

@jit
def test(p):
    data = jax.vmap(lambda x: complicated_a_func(x, p))(pytrees_stack(list_of_A))
    return data

Notice that 1) I do not use the "func" function of the
Then
with timer():
    test(10.)
with timer():
    test(10.)
gives
and I can use it in the following expressions:
jax.grad(lambda p: jnp.sum(test(p)))(10.)
jax.vmap(lambda p: jnp.sum(test(p)))(jnp.array([10., 20.]))
What do you think?
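A self-contained sketch of the stack-then-vmap idea, for anyone who wants to try it in isolation. It is not the code from this thread: `Scaled`, `evaluate` and `evaluate_all` are hypothetical stand-ins for `A`, `complicated_a_func` and `test`, and the stacked pytree is passed as an argument (an assumption made here so that the values are not baked into the compiled function as constants).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Scaled:
    # Hypothetical, simplified stand-in for the class A of the thread.
    def __init__(self, val):
        self.val = val

    def tree_flatten(self):
        return (self.val,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def evaluate(obj, x):
    return obj.val * x**2


def pytrees_stack(pytrees, axis=0):
    # Stack matching leaves of all pytrees along a new axis.
    return jax.tree_util.tree_map(
        lambda *values: jnp.stack(values, axis=axis), *pytrees)


objs = [Scaled(jnp.asarray(v)) for v in (1.0, 2.0, 3.0)]
stacked = pytrees_stack(objs)  # a single Scaled whose val has shape (3,)


@jax.jit
def evaluate_all(stacked_objs, p):
    # One vectorized call instead of one call per object.
    return jax.vmap(lambda o: evaluate(o, p))(stacked_objs)


out = evaluate_all(stacked, 10.0)
grad_p = jax.grad(lambda p: jnp.sum(evaluate_all(stacked, p)))(10.0)
```

Stacking once outside the jitted function means the stacking cost is paid a single time, and the compiled computation does not depend on how many objects were in the original list beyond the array shape.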
-
Hi, here is a snippet
Notice that
The problem is that the JIT compilation time of the f function grows linearly with the size of the portion of "list_of_A" used in "jax.tree_map".
Does anyone see a more effective way to proceed?
Thanks
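One likely contributor to the linear growth is that a Python-level loop (or a `tree_map` over a Python list) is unrolled while JAX traces the function, so every element adds its own operations to the traced program. The sketch below makes this visible by counting jaxpr equations; `looped`, `vectorized` and `count_eqns` are hypothetical helpers written for this illustration, not code from the thread.

```python
import jax
import jax.numpy as jnp


def looped(vals, p):
    # Python loop over a list: traced once per element (unrolled).
    return jnp.stack([v * p**2 for v in vals])


def vectorized(vals, p):
    # A single batched operation over one array.
    return jax.vmap(lambda v: v * p**2)(vals)


def count_eqns(fn, *args):
    # Number of primitive equations in the traced program.
    return len(jax.make_jaxpr(fn)(*args).jaxpr.eqns)


small_list = [jnp.float32(i) for i in range(10)]
large_list = [jnp.float32(i) for i in range(100)]

looped_small = count_eqns(looped, small_list, 10.0)
looped_large = count_eqns(looped, large_list, 10.0)

vec_small = count_eqns(vectorized, jnp.arange(10, dtype=jnp.float32), 10.0)
vec_large = count_eqns(vectorized, jnp.arange(100, dtype=jnp.float32), 10.0)
```

The equation count of `looped` grows with the list length, whereas that of `vectorized` stays the same, which is consistent with tracing and compilation time scaling with the number of objects in the list-based version.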