lax.map resulting in large memory usage #20458
Hello, it would be helpful to have a minimal reproducible example. But from your code, I can say this:

```python
import jax
import jax.numpy as jnp


def apply_to_all_indexed_triplets(arr, func):
    # Innermost vmap maps arg 0 (first copy of arr) and arg 3 (first index),
    # the middle one maps args 1 and 4, the outer one maps args 2 and 5.
    v_func = jax.vmap(
        jax.vmap(
            jax.vmap(func, in_axes=(0, None, None, 0, None, None)),
            in_axes=(None, 0, None, None, 0, None),
        ),
        in_axes=(None, None, 0, None, None, 0),
    )
    indices = jnp.arange(arr.shape[0])
    return v_func(arr, arr, arr, indices, indices, indices)
```

You have a triple nested loop, with your array used three times in the input.

Increased memory usage: your algorithm has nested loops, which means it is at least O(n^3). The three vmaps evaluate `func` on all n × n × n combinations and produce an output with that many entries.

Now, if you explain the purpose of this function, I might be able to help you better. But I can provide this snippet, which should reduce the overhead. Instead of three nested `vmap`s, you should use a single one:

```python
from itertools import permutations

import jax
import jax.numpy as jnp
import numpy as np


def is_valid_triplet(labels, indices):
    # `indices` is one row (i, j, k); only these three label rows are read.
    index_i, index_j, index_k = indices[0], indices[1], indices[2]
    label_i = labels[index_i]
    label_j = labels[index_j]
    label_k = labels[index_k]
    # The three indices must all differ.
    distinct_indices = (
        (index_i != index_j) & (index_i != index_k) & (index_j != index_k)
    )
    # i and j share a label, i and k do not.
    li_lj_equal = jnp.all(label_i == label_j)
    li_lk_not_equal = jnp.any(label_i != label_k)
    valid_labels = li_lj_equal & li_lk_not_equal
    return distinct_indices & valid_labels


def apply_to_all_indexed_triplets(arr, func):
    # A single vmap over the triplet axis; `arr` itself is not mapped.
    v_func = jax.vmap(func, in_axes=(None, 0))
    n = arr.shape[0]
    # All ordered triplets of distinct indices, shape (n*(n-1)*(n-2), 3).
    # permutations rather than combinations, because the validity rule is
    # order-dependent (i/j positive pair, k negative).
    triplets = np.array(list(permutations(range(n), 3)), dtype=np.int32).reshape(-1, 3)
    return v_func(arr, jnp.asarray(triplets))
```

Note: this is not exactly what you should do; take inspiration from it and find what works best for you. This way you have a single `vmap`, and the only thing duplicated is the indices, not the labels.

I am having a hard time understanding what you want without examples, so provide some and I will look further. In case you need the triplets themselves, you might want to consider using a data loader such as `tf.data` or the PyTorch `DataLoader`. There is a tutorial specific to JAX here for `tf.data` and PyTorch.
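To make the O(n^3) point concrete, here is a minimal sketch, assuming the goal is only to count valid triplets. The names `triplet_mask`, `count_nested_vmap` and `count_lax_map` are hypothetical, and the validity rule simply mirrors `is_valid_triplet` above. Three nested `vmap`s build an (n, n, n) boolean array before summing, whereas `lax.map` over the anchor index `i` evaluates one (n, n) slice at a time and reduces it to a scalar, so peak memory should scale with n^2 instead of n^3.

```python
import jax
import jax.numpy as jnp


def triplet_mask(labels, i, j, k):
    # Hypothetical rule mirroring is_valid_triplet: distinct indices,
    # i/j share a label, i/k do not.
    distinct = (i != j) & (i != k) & (j != k)
    same_ij = jnp.all(labels[i] == labels[j])
    diff_ik = jnp.any(labels[i] != labels[k])
    return distinct & same_ij & diff_ik


def count_nested_vmap(labels):
    n = labels.shape[0]
    idx = jnp.arange(n)
    f = lambda i, j, k: triplet_mask(labels, i, j, k)
    # Inner vmap maps k, middle maps j, outer maps i: the result has
    # shape (n, n, n), i.e. n**3 booleans live at once before the sum.
    v = jax.vmap(
        jax.vmap(jax.vmap(f, (None, None, 0)), (None, 0, None)),
        (0, None, None),
    )
    return jnp.sum(v(idx, idx, idx))


def count_lax_map(labels):
    n = labels.shape[0]
    idx = jnp.arange(n)

    def per_anchor(i):
        # For one anchor i, evaluate all (j, k) pairs as an (n, n) slice
        # and reduce it to a scalar before the next anchor is processed.
        f = lambda j, k: triplet_mask(labels, i, j, k)
        pairs = jax.vmap(jax.vmap(f, (None, 0)), (0, None))(idx, idx)
        return jnp.sum(pairs)

    # lax.map runs per_anchor sequentially over the anchors.
    return jnp.sum(jax.lax.map(per_anchor, idx))
```

Both return the same count; the `lax.map` version gives up some parallelism in exchange for bounded memory per step.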
I have a problem using `lax.map`: my memory usage increases over time until the process crashes, whereas I would expect it to stay constant. When I run a single instance of `get_triplet_count` on its own, I have no problems. I have also tried a normal Python loop, which works fine and does not increase memory usage over time, but ideally I want something faster, as my data is very large.

I have also tried `vmap` and `scan`, but those also ran out of memory.

I expect there is a memory leak somewhere, but I am having a hard time identifying it. Any help would be greatly appreciated.
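Since the plain Python loop is reported to keep memory flat, one possible middle ground is sketched below, assuming the goal is a count of valid triplets: stream fixed-size batches of index triplets from the host and evaluate each batch with one jitted `vmap`. The helpers `triplet_batches`, `count_batch` and `get_triplet_count_batched` are hypothetical, and `batch_size` is an assumed tuning knob. Padding the last batch keeps every call the same shape, so `jax.jit` compiles `count_batch` once per batch size rather than once per distinct shape.

```python
from itertools import islice, permutations

import jax
import jax.numpy as jnp
import numpy as np


def triplet_batches(n, batch_size):
    # Hypothetical helper: yields (batch_size, 3) index arrays plus a
    # validity mask; the last batch is padded to keep shapes fixed.
    it = permutations(range(n), 3)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        pad = batch_size - len(chunk)
        idx = np.array(chunk + [(0, 0, 0)] * pad, dtype=np.int32)
        mask = np.arange(batch_size) < len(chunk)
        yield idx, mask


@jax.jit
def count_batch(labels, idx, mask):
    def valid(t):
        i, j, k = t[0], t[1], t[2]
        same_ij = jnp.all(labels[i] == labels[j])
        diff_ik = jnp.any(labels[i] != labels[k])
        return same_ij & diff_ik

    # permutations already yields distinct indices, so only the label
    # rule is checked here; padded rows are removed by the mask.
    return jnp.sum(jax.vmap(valid)(idx) & mask)


def get_triplet_count_batched(labels, batch_size=4096):
    total = 0
    for idx, mask in triplet_batches(labels.shape[0], batch_size):
        total += int(count_batch(labels, idx, mask))
    return total
```

Device memory per step is bounded by `batch_size`; enumerating the triplets on the host is still O(n^3) work overall, so this should mainly cut per-call overhead compared with one triplet per loop iteration rather than change the asymptotics.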