To make the computation jittable, you might want to check out `lax.dynamic_slice`: its start indices can be traced values, as long as the slice sizes are static.
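
A minimal sketch of why this helps, using a 1-D array and a window size of 3 chosen only for illustration (`window` is a hypothetical helper, not part of JAX). Under `jit`, `start` is a traced value, which `lax.dynamic_slice` accepts because the slice size is a static constant. Note that JAX clamps the start index so the window stays in bounds instead of raising an error.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def window(x, start):
    # `start` is traced under jit; only the slice size (3) must be static
    return lax.dynamic_slice(x, (start,), (3,))

x = jnp.arange(10)
print(window(x, 2))  # [2 3 4]
print(window(x, 8))  # [7 8 9]: start is clamped to 7 so the window fits
```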

General suggestion: instead of using a for loop, you could vmap a function over some pre-generated arrays of random coordinates. Something like:

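```python
from jax import jit, lax, vmap

def get_patch(img, i, j, psize=PATCH_SIZE):
  # Cut a (psize, psize) window whose top-left corner is (i, j)
  return lax.dynamic_slice(img, (i, j), (psize, psize))

# Share the image across calls (None); map over the coordinate arrays (0, 0)
get_patches = jit(vmap(get_patch, in_axes=(None, 0, 0)))

I, J = random(...), random(...)  # 1-D integer arrays of row/column offsets
patches = get_patches(image, I, J)  # shape (len(I), PATCH_SIZE, PATCH_SIZE)
```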
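
For a self-contained version, here is one way to fill in the coordinate generation, assuming a 2-D grayscale image and using `jax.random.randint` for the offsets. `PATCH_SIZE`, `NUM_PATCHES`, the 32x32 test image, and the helper `random_patches` are hypothetical choices for illustration. The upper bound `h - PATCH_SIZE + 1` is exclusive in `randint`, so every window lies fully inside the image and `dynamic_slice` never needs to clamp.

```python
import jax
import jax.numpy as jnp
from jax import jit, lax, vmap

PATCH_SIZE = 4    # hypothetical patch size
NUM_PATCHES = 16  # hypothetical number of patches to draw

def get_patch(img, i, j, psize=PATCH_SIZE):
    # (psize, psize) window with top-left corner at (i, j)
    return lax.dynamic_slice(img, (i, j), (psize, psize))

# Image is shared (None); row and column offsets are mapped over (0, 0)
get_patches = jit(vmap(get_patch, in_axes=(None, 0, 0)))

def random_patches(key, image, n=NUM_PATCHES):
    h, w = image.shape
    key_i, key_j = jax.random.split(key)
    # maxval is exclusive, so each window fits inside the image
    I = jax.random.randint(key_i, (n,), 0, h - PATCH_SIZE + 1)
    J = jax.random.randint(key_j, (n,), 0, w - PATCH_SIZE + 1)
    return get_patches(image, I, J), I, J

image = jnp.arange(32 * 32, dtype=jnp.float32).reshape(32, 32)
patches, I, J = random_patches(jax.random.PRNGKey(0), image)
print(patches.shape)  # (16, 4, 4)
```

Each `patches[k]` equals the ordinary slice `image[I[k]:I[k]+4, J[k]:J[k]+4]`; the difference is that the vmapped version runs as a single compiled call instead of a Python loop.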

Answer selected by jakevdp