Replies: 2 comments
- cc @froystig
- Do you have a minimal code example that reproduces this? I'll convert this to an issue, since it seems like one. We can move the discussion there.
I was working with a transformer model in JAX and Haiku, and found that dropout greatly slows down data-parallel training. In the main training step (a rough sketch of the setup is given below):
- `sharding = jax.sharding.PositionalSharding(jax.devices())`, containing GPU:0 and GPU:1
- `train_key` is a `PRNGKeyArray`, not sharded
- `self._train_state` is a PyTree of params and opt_states, replicated with `jax.device_put(train_state, sharding.replicate())`
- `batch` is a PyTree of data and labels, sharded with `jax.device_put(batch, sharding)`
Every operation in this model (except the final loss reduction) is independent across samples in the batch, so this should be trivially data parallel.
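For concreteness, here is a minimal self-contained sketch of this kind of setup. The toy model, the placeholder shapes, and the optax optimizer are my own assumptions, not the original code:

```python
import jax
import jax.numpy as jnp
import haiku as hk
import optax

DROPOUT_RATE = 0.1

def forward(x, is_training):
    x = jax.nn.relu(hk.Linear(256)(x))
    if is_training:
        # each dropout call consumes a fresh key from hk.next_rng_key()
        x = hk.dropout(hk.next_rng_key(), DROPOUT_RATE, x)
    return hk.Linear(10)(x)

model = hk.transform(forward)
optimizer = optax.adam(1e-3)

devices = jax.devices()  # e.g. [GPU:0, GPU:1]
sharding = jax.sharding.PositionalSharding(devices)

# Placeholder data; params/opt state are replicated, the batch is sharded on axis 0.
x = jnp.ones((32, 128))
y = jnp.zeros((32,), dtype=jnp.int32)
train_key = jax.random.PRNGKey(0)

params = model.init(train_key, x, is_training=True)
opt_state = optimizer.init(params)
params, opt_state = jax.device_put((params, opt_state), sharding.replicate())
batch = {
    "x": jax.device_put(x, sharding.reshape(len(devices), 1)),
    "y": jax.device_put(y, sharding),
}

@jax.jit
def train_step(params, opt_state, key, batch):
    def loss_fn(p):
        logits = model.apply(p, key, batch["x"], is_training=True)
        labels = jax.nn.one_hot(batch["y"], 10)
        return jnp.mean(optax.softmax_cross_entropy(logits, labels))

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss

params, opt_state, loss = train_step(params, opt_state, train_key, batch)
```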
Without `x = hk.dropout(hk.next_rng_key(), self.dropout, x)` (which boils down to a `jax.random.split` and a `jax.random.bernoulli`), everything works well (single device: 4.2 it/s, two devices: 7.5 it/s). But when dropout is enabled (it is called 20 times), I got only 5.32 it/s on two devices with `jax.config.update('jax_threefry_partitionable', True)` (I was aware of the documentation on that flag), which is far from what I expected.
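Since each dropout call is just a key split plus a Bernoulli mask, one way to narrow this down is to benchmark that pattern in isolation. Below is a small sketch under my own assumptions (the stand-in `dropout` function and the placeholder shapes are not from the original code); it enables the flag mentioned above and draws a mask over a sharded array so the resulting sharding can be inspected:

```python
import jax
import jax.numpy as jnp

# The flag mentioned above; it is meant to make threefry key derivation
# partitionable so random bits for a sharded array can be produced shard-locally.
jax.config.update('jax_threefry_partitionable', True)

def dropout(rng, rate, x):
    # roughly what hk.dropout does: sample a keep mask and rescale
    keep = jax.random.bernoulli(rng, 1.0 - rate, shape=x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

devices = jax.devices()
sharding = jax.sharding.PositionalSharding(devices).reshape(len(devices), 1)
x = jax.device_put(jnp.ones((32, 128)), sharding)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)  # hk.next_rng_key() performs a split like this
y = jax.jit(dropout)(subkey, 0.1, x)
print(y.sharding)  # inspect how the dropout output ends up sharded
```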
Did I miss something? Could this performance be optimized?