Best Practice for Loading Data onto a GPU with JAX? #7433
Hello everyone, I hope you are all well. I have a question about the optimal way of handling a dataset for a model trained on a GPU. Suppose I have a very large dataset X which is too large to fit on the GPU, and in addition I have my model's parameters. Currently I would move a batch onto the GPU and then initiate a training step on that batch, i.e.

```python
for i in range(no_batches):
    # move batch onto the GPU by creating a jnp.array
    batch = jnp.array(X[i])
    train_step(params, batch)
```

However, this seems inefficient, as for each step in the loop the batch must first be transferred to the GPU before the training step can run. What is the optimal way to go about moving data onto the GPU for training with JAX? Many thanks for your help.
Replies: 1 comment 1 reply
Two comments:

1. Due to JAX's asynchronous dispatch, neither `jnp.array` nor your `train_step` will block unless you hit a blocking operation like `print`; operations just get queued for later execution in the background as their inputs become available.
2. You might try `jnp.asarray` instead of `jnp.array`; it sometimes prevents unnecessary copying.
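A small sketch illustrating both comments; the array sizes and the use of `jnp.dot` are assumptions chosen for demonstration, not from the thread:

```python
# Illustrates (1) asynchronous dispatch and (2) jnp.asarray vs jnp.array.
# All shapes and the dot-product workload are illustrative assumptions.
import jax.numpy as jnp
import numpy as np

x = np.ones((1000, 1000), dtype=np.float32)
y = jnp.asarray(x)

# Comment 1: jnp.dot returns almost immediately -- the result is a
# "future-like" array and the computation runs in the background.
z = jnp.dot(y, y)
z.block_until_ready()  # explicitly wait for the computation to finish

# Comment 2: jnp.asarray may reuse an existing device buffer when the
# input is already a JAX array, whereas jnp.array always makes a copy.
a = jnp.arange(4.0)
b = jnp.asarray(a)  # may avoid a copy
c = jnp.array(a)    # guaranteed fresh copy

print(float(z[0, 0]))  # 1000.0 (row of 1000 ones dotted with a column of ones)
```

The practical consequence of comment 1 is that a `jnp.array` call in the loop does not necessarily stall training: transfers and compute are queued and can overlap unless something forces synchronization.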