Hello community, I am trying to project a vector in a very high-dimensional space. This function takes a long time (minutes) to compile, and I would like to learn how to make it compile faster. The module code can be found here: https://github.com/harsh306/continuation-jax/blob/main/cjax/continuation/methods/corrector/perturb_parc_evolve.py#L114

EDIT: Adding a self-contained example.

```python
import jax.numpy as jnp
from jax import jit, lax
from datetime import datetime
from functools import partial


@partial(jit, static_argnums=(0,))
def projection_affine(n_dim, u, n, u_0):
    """
    Args:
        n_dim: affine transformation space
        u: random point to be projected on n as L
        n: secant normal vector
        u_0: secant starting point
    Returns:
        projected vector
    """
    # jax.experimental.optimizers (and its l2_norm) is gone from recent JAX;
    # for a single 1-D array the Euclidean norm is the same quantity.
    n_norm = jnp.linalg.norm(n)
    I = jnp.eye(n_dim)
    # I think I can improve the next lines with lax.scan, but I get errors.
    # This Python loop runs at trace time, so jit records n_dim separate rows.
    p2 = [0 * k for k in range(n_dim)]
    for k in range(n_dim):
        p2[k] = (jnp.dot(n, I[k]) / n_norm ** 2) * n
    p2 = jnp.asarray([p2[i] for i in range(n_dim)])
    u_0 = lax.reshape(u_0, (n_dim, 1))
    # Homogeneous-coordinate blocks: translate by -u_0, apply p2, translate back.
    t1 = jnp.block([[I, u_0], [jnp.zeros(shape=(1, n_dim)), jnp.ones((1, 1))]])
    t2 = jnp.block(
        [[p2, jnp.zeros(shape=(n_dim, 1))], [jnp.zeros(shape=(1, n_dim)), jnp.ones((1, 1))]]
    )
    t3 = jnp.block([[I, -1 * u_0], [jnp.zeros(shape=(1, n_dim)), jnp.ones((1, 1))]])
    P = jnp.matmul(jnp.matmul(t1, t2), t3)
    pr = jnp.matmul(P, jnp.hstack([u, 1.0]))
    pr = lax.slice(pr, [0], [n_dim])
    return pr


if __name__ == "__main__":
    n_dim = 1000
    n = jnp.zeros(n_dim)
    # jax.ops.index_update was removed; .at[...].set(...) replaces it.
    n = n.at[-1].set(1.0)
    u_0 = 3.0 * jnp.ones(n_dim)
    u = 2.0 * jnp.ones(n_dim)
    start = datetime.now()
    # First call = trace + compile + run; block so the run is part of the timing.
    projection = projection_affine(n_dim, u, n, u_0).block_until_ready()
    print(f"duration {datetime.now() - start}")
    print(projection)
```

There are two main functions here; I am adding one of them. I should be able to learn from that one alone, and if required I will add the other step later.
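A likely cause of the slow compile, offered as a reading of the example rather than anything confirmed in this thread: the Python `for` loop executes during tracing, so `jit` records `n_dim` separate dot/divide/multiply steps (roughly 3000 ops for `n_dim = 1000`) plus a 1000-way stack, and XLA has to compile all of them. Row `k` of `p2` is `n[k] / ||n||^2 * n`, i.e. `p2 = outer(n, n) / ||n||^2`. Because `t3`, `t2`, `t1` leave the homogeneous coordinate at 1, `P @ [u, 1]` sliced to `n_dim` entries reduces to `u_0 + p2 @ (u - u_0)`. The sketch below assumes that reading; `projection_affine_fast`, `build_p2_outer` and `build_p2_vmap` are hypothetical names, and the `vmap` variant is included only because the question mentions trying `lax.scan` to replace the loop.

```python
import jax
import jax.numpy as jnp


@jax.jit
def projection_affine_fast(u, n, u_0):
    # Closed form of the affine map t1 @ t2 @ t3 applied to [u, 1]:
    # u_0 + (outer(n, n) / ||n||^2) @ (u - u_0), without building any matrix.
    return u_0 + n * (jnp.dot(n, u - u_0) / jnp.dot(n, n))


@jax.jit
def build_p2_outer(n):
    # Same matrix the Python loop builds, in one op: row k is n[k] / ||n||^2 * n.
    return jnp.outer(n, n) / jnp.dot(n, n)


@jax.jit
def build_p2_vmap(n):
    # Loop-free variant that mirrors the original loop body: vmap maps over the
    # rows of the identity instead of unrolling n_dim Python iterations.
    n_norm_sq = jnp.dot(n, n)
    eye = jnp.eye(n.shape[0], dtype=n.dtype)
    return jax.vmap(lambda e_k: (jnp.dot(n, e_k) / n_norm_sq) * n)(eye)


if __name__ == "__main__":
    n_dim = 1000
    n = jnp.zeros(n_dim).at[-1].set(1.0)
    u_0 = 3.0 * jnp.ones(n_dim)
    u = 2.0 * jnp.ones(n_dim)
    # n_dim is read from the array shapes, which are static under jit,
    # so no static_argnums is needed here.
    pr = projection_affine_fast(u, n, u_0).block_until_ready()
    print(pr[:3], pr[-1])
```

For the demo inputs, every coordinate except the last stays at 3.0 and the last becomes 2.0, matching what the original `P @ [u, 1]` computes. The traced program for `projection_affine_fast` has a fixed number of ops whatever `n_dim` is, so compile time should no longer grow with the dimension, and it also avoids allocating the `(n_dim + 1) x (n_dim + 1)` blocks. `build_p2_outer` or `build_p2_vmap` would only be worth keeping if the matrix `p2` itself is needed elsewhere.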
Reply: Hi, thanks for the question. I suspect that in its current form you will not get many answers. It's not clear, for example, how the class should be instantiated or what the arguments to that method should be. Would it be possible to provide a self-contained example that demonstrates the slow compilation you're seeing?