random crops in JAX #9213
Hello everyone, I have run into this problem too many times: quickly extracting patches from an image, usually in a dataloader context. I still don't know how to do this properly with JAX. For instance, in the following code I load an image and extract random patches of fixed size from it:

    import glob
    import time

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt

    # SEED, N and PATCH_SIZE are constants defined elsewhere.

    def get_bboxes(skey0, skey1, n, h, w, patch_size):
        # Sample the top-left corner of each patch.
        coords0 = jax.random.randint(skey0, shape=(n,), minval=0, maxval=h-1-patch_size)
        coords1 = jax.random.randint(skey1, shape=(n,), minval=0, maxval=w-1-patch_size)
        # Each row is (row_start, row_end, col_start, col_end).
        bboxes = jnp.vstack((coords0, coords0+patch_size, coords1, coords1+patch_size)).T
        return bboxes

    def get_patches(jgimg, n, bboxes, patch_size):
        patches = jnp.empty((n, patch_size, patch_size))
        for i in range(n):
            bbox = bboxes.at[i].get()
            patches = patches.at[i].set(jgimg.at[bbox[0]:bbox[1], bbox[2]:bbox[3]].get())
        return patches

    split_3 = jax.jit(lambda key: jax.random.split(key, 3))

    def get_n_patches(key, n, jgimg, patch_size):
        h, w = jgimg.shape
        skeys = split_3(key)
        bboxes = get_bboxes(skeys[1], skeys[2], n, h, w, patch_size)
        patches = get_patches(jgimg, n, bboxes, patch_size)
        return skeys[0], patches

    # initialize X as a list
    X = []
    key = jax.random.PRNGKey(SEED)
    img_paths = sorted(glob.glob('dataset/train/*.jpg'))
    for img_n, img_path in enumerate(img_paths):  # iterate over dataset
        t0 = time.time()
        jgimg = jnp.array(plt.imread(img_path))  # jax array gray image
        t1 = time.time()
        key, patches = get_n_patches(key, N, jgimg, PATCH_SIZE)
        X.append(patches)
        t2 = time.time()
        print('---')
        print('img_n:', img_n)
        print(t1 - t0)
        print(t2 - t1)

Could anyone make this faster?
Replies: 1 comment
To make the computation jittable, you might want to check out lax.dynamic_slice.

General suggestion: instead of using a for loop, you could vmap a function over some pre-generated arrays of random coordinates. Something like

    from jax import jit, lax, vmap

    def get_patch(img, i, j, psize=PATCH_SIZE):
        # Start indices (i, j) may be traced; the slice size must be static.
        return lax.dynamic_slice(img, (i, j), (psize, psize))

    # Map over the coordinate arrays only; the image is shared by all patches.
    get_patches = jit(vmap(get_patch, in_axes=(None, 0, 0)))
    I, J = random(...), random(...)
    patches = get_patches(image, I, J)
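For reference, here is one way the suggestion above could be fleshed out into a self-contained script. This is only a sketch under a few assumptions: the image is a 2D (grayscale) array, and the names `random_patches` and `random_patches_jit`, the synthetic `img`, and the sizes used are all made up for illustration rather than taken from the original post. The coordinates are drawn with `jax.random.randint`, whose `maxval` is exclusive.

```python
import jax
import jax.numpy as jnp
from jax import lax


def random_patches(key, img, n, patch_size):
    """Return n random (patch_size, patch_size) crops of a 2D image."""
    h, w = img.shape
    key_i, key_j = jax.random.split(key)
    # maxval is exclusive, so the largest valid start index is h - patch_size.
    rows = jax.random.randint(key_i, (n,), 0, h - patch_size + 1)
    cols = jax.random.randint(key_j, (n,), 0, w - patch_size + 1)

    def slice_one(i, j):
        # Start indices may be traced values; the slice size must be static.
        return lax.dynamic_slice(img, (i, j), (patch_size, patch_size))

    # Vectorize the single-patch slice over the coordinate arrays.
    return jax.vmap(slice_one)(rows, cols)


# n and patch_size determine output shapes, so they are marked static.
random_patches_jit = jax.jit(random_patches, static_argnames=("n", "patch_size"))

# Hypothetical stand-in for a loaded grayscale image.
key = jax.random.PRNGKey(0)
img = jnp.arange(64 * 48, dtype=jnp.float32).reshape(64, 48)
patches = random_patches_jit(key, img, n=16, patch_size=8)
print(patches.shape)  # (16, 8, 8)
```

Because `n` and `patch_size` are static, the function is recompiled for each new combination of those values and for each new image shape, so this pays off most when the images in the dataset share a size. Note also that `lax.dynamic_slice` clamps out-of-range start indices so the slice stays inside the array, which means a bad coordinate is silently shifted rather than raising an error.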