Hello everyone,

I'm working on an application which requires a 'chunked' vmap operation, as described in #11319, distributed across multiple devices. I have a working example where I pad the arrays in the input PyTree, reshape the leading axis to have a size of `n_devices`, and apply a `pmap` with a nested chunked `vmap` operation. However, I have noticed that the post-processing reshape step (`x = x.reshape(-1, *x.shape[2:])`) takes a long time to compute compared to the overall function, making the speedup from using `pmap` far from ideal. I was wondering whether there is any way to refactor my code to avoid this operation or make it less expensive. I am fairly new to `pmap`, so I imagine I could be doing something wrong here.

Below is the relevant section of my code. Note that `np` stands for `jax.numpy` here and `onp` refers to vanilla NumPy.
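For context, the 'chunked' vmap pattern from #11319 trades peak memory for sequential execution by applying `jax.vmap` to one slice of the batch at a time. A minimal, self-contained sketch of the idea (the names and chunk handling are my own, and it assumes the leading axis divides evenly by `num_chunks`):

```python
import jax
import jax.numpy as jnp

def chunked_vmap(f, num_chunks):
    """Apply vmap to `f` one chunk at a time to bound peak memory."""
    def chunked_fn(*args):
        # Split each input along its leading axis into `num_chunks` pieces.
        chunks = [jnp.split(a, num_chunks) for a in args]
        # vmap over each chunk in turn, then stitch the results back together.
        outs = [jax.vmap(f)(*parts) for parts in zip(*chunks)]
        return jnp.concatenate(outs)
    return chunked_fn

f = lambda x: x ** 2
g = chunked_vmap(f, num_chunks=4)
y = g(jnp.arange(8.0))  # same result as jax.vmap(f)(jnp.arange(8.0))
```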
n_devices = jax.device_count()
# Total number of chunks - this is dictated by the available memory
n_chunks_total = 10
# Ensure n_chunks_total is at least n_devices and a multiple of n_devices
n_chunks_total = max(n_devices, (n_chunks_total // n_devices) * n_devices)

# Input PyTree
input_data = [
    cells_sol,
    self.shape_grads,
    self.JxW,
    self.v_grads_JxW,
    *kernel_vars,
]
n_cells = len(cells_sol)

if n_devices > 1:
    # Pad the data to be divisible by the number of devices
    padding_size = -n_cells % n_devices
    target_size = n_cells + padding_size
    n_cells_per_device = target_size // n_devices

    def _pad_and_reshape(x):
        # Zero-pad the leading axis, then split it into a device axis
        if padding_size:
            pad_width = [(0, 0)] * onp.ndim(x)
            pad_width[0] = (0, padding_size)
            x = onp.pad(x, pad_width)
        device_shape = (n_devices, n_cells_per_device)
        return x.reshape(device_shape + x.shape[1:])

    def _remove_pad_and_reshape(x):
        # Merge the device axis back into the leading axis
        x = x.reshape(-1, *x.shape[2:])
        # Slice off the zero padding
        if padding_size:
            x = x[:-padding_size]
        return x

    # Pad and reshape to distribute across devices
    input_data = jax.tree_map(_pad_and_reshape, input_data)

def _extract_data_chunk(input_data, chunk_id, chunk_size, num_chunks):
    start = chunk_id * chunk_size
    if chunk_id < num_chunks - 1:
        end = (chunk_id + 1) * chunk_size
    else:  # For the last chunk, take all remaining elements
        end = None
    return jax.tree_map(lambda x: x[start:end], input_data)

def chunked_vmap(f, num_chunks):
    def chunked_fn(input_data):
        # The chunk size is dictated by the leading axis of the first input
        n_elements = input_data[0].shape[0]
        chunk_size = n_elements // num_chunks
        values = []
        jacs = []
        for chunk_id in range(num_chunks):
            # Extract a chunk and apply the vmapped function to it
            data_chunk = _extract_data_chunk(input_data, chunk_id,
                                             chunk_size, num_chunks)
            if jac_flag:
                value, jac = jax.vmap(f)(*data_chunk)
                values.append(value)
                jacs.append(jac)
            else:
                values.append(jax.vmap(f)(*data_chunk))
        vals = np.vstack(values)
        if jac_flag:
            return vals, np.vstack(jacs)
        return vals
    return chunked_fn

n_chunks_per_device = n_chunks_total // n_devices
chunked_vmap_fn = chunked_vmap(fn, n_chunks_per_device)
apply_fn = jax.pmap(chunked_vmap_fn) if n_devices > 1 else chunked_vmap_fn

if jac_flag:
    values, jacs = apply_fn(input_data)
    if n_devices > 1:
        values = _remove_pad_and_reshape(values)
        jacs = _remove_pad_and_reshape(jacs)
    return values, jacs
else:
    values = apply_fn(input_data)
    if n_devices > 1:
        values = _remove_pad_and_reshape(values)
    return values
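A side note on the padding arithmetic above: in Python, `-n_cells % n_devices` is already the smallest non-negative pad that makes the total divisible by `n_devices`, so an extra trailing `% n_devices` would be a no-op. A quick check (the example values are arbitrary):

```python
n_devices = 4  # example value; jax.device_count() in the real code
for n_cells in (7, 8, 9, 12):
    # Smallest pad that brings n_cells up to a multiple of n_devices
    padding_size = -n_cells % n_devices
    assert 0 <= padding_size < n_devices
    assert (n_cells + padding_size) % n_devices == 0
    # A second modulo would change nothing
    assert padding_size == (-n_cells % n_devices) % n_devices
```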
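One thing worth ruling out before blaming the reshape itself: JAX dispatches computations asynchronously, so whichever downstream operation first forces the `pmap` result to materialize (here, the post-processing reshape that merges the device axis) can absorb the wait for the computation it depends on, and a naive timer will attribute that time to the reshape. A sketch of how the two stages could be timed separately (the `timed` helper and the toy function are my own; only `jax.block_until_ready` is the relevant API):

```python
import time
import jax
import jax.numpy as jnp

def timed(label, fn, *args):
    t0 = time.perf_counter()
    out = fn(*args)
    jax.block_until_ready(out)  # wait for the async computation to finish
    print(f"{label}: {time.perf_counter() - t0:.4f}s")
    return out

x = jnp.ones((jax.device_count(), 1024, 8))
f = jax.pmap(lambda x: jnp.sin(x) ** 2)
y = timed("pmap", f, x)  # times the computation itself, not just dispatch
z = timed("reshape", lambda a: a.reshape(-1, a.shape[-1]), y)
```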