Replies: 2 comments 5 replies
-
I'm not sure whether this will improve compile speed, but your approach of manually flattening and unflattening the parameters seems a bit brittle: you might try using `jax.flatten_util.ravel_pytree` instead:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

a = jnp.arange(3.0)
b = jnp.ones((2, 1))
c = jnp.eye(3)

flattened, unravel = ravel_pytree((a, b, c))
print(flattened)
# [0. 1. 2. 1. 1. 1. 0. 0. 0. 1. 0. 0. 0. 1.]

a, b, c = unravel(flattened)
print(a)
# DeviceArray([0., 1., 2.], dtype=float32)
print(b)
# DeviceArray([[1.],
#              [1.]], dtype=float32)
print(c)
# DeviceArray([[1., 0., 0.],
#              [0., 1., 0.],
#              [0., 0., 1.]], dtype=float32)
```
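`ravel_pytree` is also handy when you need to hand a pytree of parameters to something that expects a single flat vector (a scipy-style optimizer, for instance). A minimal sketch, where the loss and the parameter names are made up for illustration:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical loss over a params pytree (a dict of arrays).
def loss(params):
    return jnp.sum(params["w"] ** 2) + jnp.sum(params["b"] ** 2)

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
flat, unravel = ravel_pytree(params)

# A flat-vector view of the same loss: unravel inside, so callers
# only ever see a single 1-D parameter vector.
flat_loss = lambda x: loss(unravel(x))
g = jax.grad(flat_loss)(flat)
print(g.shape)  # (9,) — one gradient entry per flattened parameter
```

Since `loss` here is just the sum of squares of every leaf, the gradient of `flat_loss` is `2 * flat`, which makes it easy to sanity-check the round trip.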
-
I found that the main source of compile time is the complexity of my model: when I increase or decrease the number of models, the compilation time changes linearly.
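That linear growth usually means the model body is being traced once per copy (e.g. in a Python loop). If the models share the same structure, one common workaround is to stack their parameters along a leading axis and `jax.vmap` over them, so XLA compiles the body once. A hedged sketch, with a made-up one-line model standing in for the real one:

```python
import jax
import jax.numpy as jnp

# A toy "model": y = w @ x. Stand-in for the real model; the point
# is the batching structure, not the math.
def predict(w, x):
    return w @ x

# Stack N copies of the parameters along a leading axis and vmap,
# so the model body is traced and compiled once instead of N times.
N = 8
ws = jnp.stack([jnp.eye(3) * (i + 1) for i in range(N)])  # (N, 3, 3)
x = jnp.arange(3.0)

batched = jax.jit(jax.vmap(predict, in_axes=(0, None)))
ys = batched(ws, x)
print(ys.shape)  # (8, 3)
```

With this structure, adding more models grows the data, not the traced program, so compile time should stay roughly flat.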
-
I read the GitBook, and I think this may be my problem:
> Sometimes it isn't obvious how to rewrite your code to avoid Python loops because your code makes use of many arrays with different shapes. The recommended solution in this case is to make use of functions like `jax.numpy.where()` to do your computation on padded arrays with fixed shape. The JAX team is exploring a "masking" transformation to make such code easier to write.
My code:
How can I replace the `reshape` calls to get a good compile speed?
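For reference, the padded-array approach the quoted passage describes can be sketched like this (the ragged data and the masked mean are made-up examples, not your actual code):

```python
import jax
import jax.numpy as jnp

# Ragged data: rows of different lengths, padded to one fixed width.
# A boolean mask marks the valid entries, so every call sees the same
# shape and jit compiles the function exactly once.
PAD = 5
rows = [jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0, 5.0])]
padded = jnp.stack([jnp.pad(r, (0, PAD - r.shape[0])) for r in rows])
mask = jnp.stack([jnp.arange(PAD) < r.shape[0] for r in rows])

@jax.jit
def masked_mean(x, m):
    # jnp.where zeroes out the padding slots, and we divide by the
    # true per-row count rather than the padded width.
    return jnp.sum(jnp.where(m, x, 0.0), axis=1) / jnp.sum(m, axis=1)

print(masked_mean(padded, mask))  # [1.5 4. ]
```

The same idea applies to reshapes: keep every array at one fixed padded shape and use the mask wherever the valid length matters, instead of reshaping to a different size per input.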