I see three ways to speed up your code: jit, vmap, or both.

(1) jit: You can simply jax.jit your function

import jax

f = jax.jit(lambda line_list: jax.tree_util.tree_map(Line_sensitivity, line_list))
f(line_list)  # the first call is slow because it triggers compilation
f(line_list)  # later calls with the same input structure reuse the compiled code

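It will be fast, but the potential problem with this approach is that the compilation time grows with the length of the list: jit traces the Python list element by element, so the compiled program contains one copy of the computation per line.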
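One way to see why compilation gets slower for longer lists is to count the equations in the traced program. This is only a minimal sketch: `length`, `line_sensitivity`, `over_list`, and `make_lines` are hypothetical stand-ins (a line is just a pair of endpoint arrays here), not your actual `Line_sensitivity`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for Line_sensitivity: gradient of a segment's
# length with respect to its two endpoints.
def length(line):
    start, end = line
    return jnp.sqrt(jnp.sum((end - start) ** 2))

line_sensitivity = jax.grad(length)

def over_list(line_list):
    # A plain Python loop over the list: tracing visits every element.
    return [line_sensitivity(line) for line in line_list]

def make_lines(n):
    # n lines from the origin to distinct, non-degenerate endpoints.
    return [(jnp.zeros(2), jnp.ones(2) * (i + 1)) for i in range(n)]

def n_equations(n):
    # Number of top-level equations in the program traced for n lines.
    return len(jax.make_jaxpr(over_list)(make_lines(n)).jaxpr.eqns)

small, large = n_equations(2), n_equations(20)
print(small, large)  # the second number is larger: more lines, bigger program
```

The traced program is what gets handed to the compiler, so a program that grows with the list means compile time that grows with the list.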

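(2) vmap: jax.vmap only works along an axis of an array, so you will have to change the format of your input. Instead of a Python list of line objects, build a single line object whose arrays carry the lines along a leading batch axis.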

import numpy as np

pts_coord = np.random.rand(5000, 2)  # 5000 lines
lines = line(point(pts_coord[:, 0]), point(pts_coord[:, 1]))  # one line object with batched arrays
jax.vmap(Line_sensitivity)(lines)
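If you already have a Python list of line objects, you can convert it into the batched format by stacking matching leaves. The sketch below assumes your classes are pytrees; `Point`, `Line`, and `length` are hypothetical NamedTuple-based stand-ins for your own `point`, `line`, and per-line function.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class Point(NamedTuple):  # hypothetical stand-in for the `point` class
    xy: jnp.ndarray

class Line(NamedTuple):  # hypothetical stand-in for the `line` class
    start: Point
    end: Point

def length(line):
    # Per-line function written for a single, unbatched line.
    return jnp.sqrt(jnp.sum((line.end.xy - line.start.xy) ** 2))

# A list of 4 separate line objects, each with endpoints of shape (2,).
line_list = [
    Line(Point(jnp.zeros(2)), Point(jnp.array([3.0, 4.0]) * (i + 1)))
    for i in range(4)
]

# Stack matching leaves of every line along a new leading axis, giving
# one Line whose arrays have shape (4, 2).
batched = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *line_list)

# vmap maps over the leading axis of every leaf in the batched pytree.
lengths = jax.vmap(length)(batched)
```

The stacking only works when every element of the list has the same tree structure and the same leaf shapes.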

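(3) both: wrap the vmapped function in jit. The batched function is traced once for a given input shape, so the compilation time no longer grows with the number of lines.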

f = jax.jit(jax.vmap(Line_sensitivity))
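For a complete example of the combined version, here is a minimal sketch. It assumes each line is given as two endpoint arrays rather than your `line`/`point` objects, and `length` and `sensitivity` are hypothetical stand-ins for your own function.

```python
import jax
import jax.numpy as jnp

def length(start, end):
    # Hypothetical per-line function: Euclidean length of one segment.
    return jnp.sqrt(jnp.sum((end - start) ** 2))

# Sensitivity of the length with respect to both endpoints of one line.
sensitivity = jax.grad(length, argnums=(0, 1))

# Batch over the leading axis, then compile the batched function.
f = jax.jit(jax.vmap(sensitivity))

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
starts = jax.random.uniform(key_a, (5000, 2))
ends = jax.random.uniform(key_b, (5000, 2)) + 2.0  # offset avoids zero-length lines

d_start, d_end = f(starts, ends)  # first call traces and compiles
d_start, d_end = f(starts, ends)  # same shapes, so the compiled code is reused

# Spot check against the unbatched function on the first line.
ref_start, ref_end = sensitivity(starts[0], ends[0])
```

Calling `f` with a different number of lines triggers a new compilation, since the input shape changes, but the size of the compiled program does not depend on that number.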

Answer selected by GaoyuanWu