import jax

def unreplicate(tree):
    """Returns a single instance of a replicated array."""
    # Take the first device's copy of every leaf in the pytree.
    return jax.tree_util.tree_map(lambda x: x[0], tree)
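Here is a minimal round trip, written in pure JAX. It assumes that `jax.device_put_replicated` produces the same layout as `flax.jax_utils.replicate`, which, as far as I know, wraps that call. The parameter pytree is a hypothetical placeholder. Replication adds a leading axis of size `jax.local_device_count()`, and `unreplicate` removes it by keeping device 0's copy.

```python
import jax
import jax.numpy as jnp


def unreplicate(tree):
    """Returns a single instance of a replicated array."""
    # Keep only the copy held by the first device for every leaf.
    return jax.tree_util.tree_map(lambda x: x[0], tree)


# Hypothetical parameter pytree, for illustration only.
params = {"w": jnp.ones((3, 4)), "b": jnp.zeros((4,))}

# Place one copy on every local device, stacked along a new leading axis
# (assumed to match what flax.jax_utils.replicate produces).
replicated = jax.device_put_replicated(params, jax.local_devices())
print(replicated["w"].shape)  # leading axis == jax.local_device_count()

single = unreplicate(replicated)
print(single["w"].shape)  # (3, 4)
```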
I am one of the participants in the Hugging Face Community Week. I have a question about `pmap` and the training loop and would really appreciate it if someone could offer me some advice:
I have a memory- and time-expensive step that needs to run every N iterations, e.g.:
Questions
1. Will `run_expensive_step` create a big bottleneck on the `pmap_train` call, or will it not block (maybe because of the async nature of JAX)?
2. Is it possible to have `run_expensive_step` called only on the host? (My host has 300+ GB of RAM, and `run_expensive_step` is super memory hungry.)
3. `model` has been passed through `flax.jax_utils.replicate` and `batch` has been passed through `shard`, but `clip_model` hasn't been passed through `flax.jax_utils.replicate`. Do I need to reduce the `model` and `batch` dimensions before each `run_expensive_step` call?
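Here is one hedged sketch for the third question. The helpers and dummy arrays are illustrative and do not come from the thread. A leaf passed through `flax.jax_utils.replicate` carries an extra leading device axis. A batch passed through `shard` has shape `(n_devices, per_device_batch, ...)`. A model that was never replicated, such as `clip_model`, expects neither layout, so both leading axes are collapsed first. `shard` is re-implemented below on the assumption that it matches `flax.training.common_utils.shard`. `unshard` is a hypothetical inverse of it.

```python
import jax
import numpy as np


def unreplicate(tree):
    """Keep device 0's copy of every leaf in a replicated pytree."""
    return jax.tree_util.tree_map(lambda x: x[0], tree)


def shard(tree):
    """Split the leading batch axis into (n_devices, per_device, ...).

    Assumed to mirror flax.training.common_utils.shard; written out here
    so the example runs without Flax.
    """
    n = jax.local_device_count()
    return jax.tree_util.tree_map(lambda x: x.reshape((n, -1) + x.shape[1:]), tree)


def unshard(tree):
    """Hypothetical inverse of shard(): merge the first two axes back together."""
    return jax.tree_util.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), tree)


n_dev = jax.local_device_count()
# Dummy batch and replicated params, for illustration only.
batch = {"pixel_values": np.random.rand(4 * n_dev, 3, 8, 8).astype(np.float32)}
replicated_params = {"w": np.ones((n_dev, 16, 4), np.float32)}

sharded_batch = shard(batch)  # pixel_values: (n_dev, 4, 3, 8, 8)

# Before calling an unreplicated model (e.g. clip_model) on the host:
host_params = jax.device_get(unreplicate(replicated_params))  # w: (16, 4)
host_batch = jax.device_get(unshard(sharded_batch))  # pixel_values: (4 * n_dev, 3, 8, 8)
```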
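Here is a minimal sketch of the loop for the first two questions. The toy `train_step` and the placeholder `run_expensive_step` are hypothetical and are not the poster's code. `jax.device_put_replicated` stands in for `flax.jax_utils.replicate`.

Because `pmap` dispatches asynchronously, the `pmap_train` call returns immediately. However, `jax.device_get` waits for that step's results before copying them to host RAM, and no new training step is dispatched while `run_expensive_step` runs. In this form, therefore, the expensive step does stall training. To keep the memory-hungry work off the accelerators, the sketch does two things: it operates on the `device_get` output, which is plain NumPy arrays in host memory, and it wraps any remaining JAX ops in `jax.default_device(jax.devices("cpu")[0])`.

```python
import jax
import jax.numpy as jnp

N = 3  # hypothetical: run the expensive step every N iterations
NUM_STEPS = 7


def unreplicate(tree):
    """Keep device 0's copy of every leaf in a replicated pytree."""
    return jax.tree_util.tree_map(lambda x: x[0], tree)


def train_step(params, batch):
    """Toy stand-in for the real step: one SGD update on a quadratic loss."""
    def loss_fn(p):
        return jnp.mean((batch @ p["w"]) ** 2)

    grads = jax.grad(loss_fn)(params)
    # Average gradients across devices, as a data-parallel step would.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


pmap_train = jax.pmap(train_step, axis_name="devices")


def run_expensive_step(host_params, step):
    """Hypothetical placeholder for the memory-hungry host computation."""
    w = jnp.asarray(host_params["w"])  # lands on the current default device
    return float(jnp.sum(w ** 2))


n_dev = jax.local_device_count()
params = jax.device_put_replicated({"w": jnp.ones((4, 2))}, jax.local_devices())
cpu = jax.devices("cpu")[0]
results = []

for step in range(1, NUM_STEPS + 1):
    batch = jnp.ones((n_dev, 8, 4))  # dummy sharded batch: (n_dev, per_device, features)
    # Asynchronous dispatch: this returns before the devices finish computing.
    params = pmap_train(params, batch)
    if step % N == 0:
        # device_get blocks until this step's params are ready, then copies
        # them into host NumPy arrays, so this is a synchronization point.
        host_params = jax.device_get(unreplicate(params))
        # JAX ops created inside this context are placed on the host CPU.
        with jax.default_device(cpu):
            results.append(run_expensive_step(host_params, step))
```

If the stall matters, one option that may help (not verified here) is to dispatch the next `pmap_train` call before starting the host work. The devices would then keep computing while `run_expensive_step` runs.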