Hello, everyone! I was using `jax.pmap` to start my training on four 1080 Ti GPUs. I got the training step function with `jax.pmap`, and then I combined the step function with `jax.lax.scan`.
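Roughly, the setup looked like the sketch below (a reconstruction, since the original snippet isn't shown here; the loss and update logic are elided):

```python
import jax

def step_fn(carry, batch):
    rng, train_state, model_state = carry
    B, H, W, C = batch.shape  # <- the line that later fails
    # ... compute the loss, take gradients, update train_state ...
    return (rng, train_state, model_state), None

def train_step(rng, train_state, model_state, batch):
    # Run step_fn over axis 0 of `batch`, threading the states through.
    carry, _ = jax.lax.scan(step_fn, (rng, train_state, model_state), batch)
    return carry

# Map axis 0 of every argument across the 4 GPUs.
p_train_step = jax.pmap(train_step)
```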
The problem appeared when I executed the `p_train_step` function. The shape of the batch was `(4, 8, 128, 128, 1)`, and I expected axis 0 to be mapped by `pmap`. However, when I actually ran it (with `next_rng` being 4 random keys generated by `jax.random.split`, and with `p_train_state` and `p_model_state` replicated by `flax.jax_utils.replicate`), the shape of the batch received by `step_fn` came out to be `(128, 128, 1)`, so the error was raised at the line `B, H, W, C = batch.shape`: the batch dimension was lost.
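My guess at what is going on: `pmap` consumes the leading device axis (the 4), and `lax.scan` then iterates over axis 0 of the remaining `(8, 128, 128, 1)` per-device array, so each call of `step_fn` receives a single `(128, 128, 1)` example. Here is a stripped-down trace of the shapes (hypothetical code, just to show the axis bookkeeping; it needs at least 4 devices to run):

```python
import jax
import jax.numpy as jnp

def step_fn(carry, batch):
    # Printed at trace time: (128, 128, 1) -- both leading axes are gone.
    print('step_fn sees batch of shape:', batch.shape)
    return carry, None

def show(rng, batch):
    carry, _ = jax.lax.scan(step_fn, rng, batch)
    return carry

batch = jnp.zeros((4, 8, 128, 128, 1))            # (devices, batch, H, W, C)
rng = jax.random.split(jax.random.PRNGKey(0), 4)  # one key per device
jax.pmap(show)(rng, batch)                        # needs >= 4 devices
```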
However, the problem disappeared when I ran the step in a different way, and then everything went well. How can I solve the problem?
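If that diagnosis is right, one possible fix (just a sketch; `n_scan_steps` and the sizes are placeholders I made up) would be to give `lax.scan` its own leading axis, so that `step_fn` still receives a full `(B, H, W, C)` batch:

```python
import jax.numpy as jnp

n_devices, n_scan_steps, per_device_batch = 4, 2, 8
# pmap consumes n_devices, lax.scan consumes n_scan_steps, and step_fn
# then sees (per_device_batch, 128, 128, 1), so B, H, W, C unpacks fine.
batch = jnp.zeros((n_devices, n_scan_steps, per_device_batch, 128, 128, 1))
```

Is that the intended way to combine `pmap` with `lax.scan`, or is there a better approach?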