JAX doesn't have any mechanism to load only part of an array from disk: the code you wrote will attempt to load the entire dataset, and only then pull out the relevant subset. So, for example, when you call `X = jnp.array(data['images'], dtype=jnp.int8)`, I believe that JAX will convert [...] If you want to use h5py to load the data in batches, you'll have to keep [...]
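A minimal sketch of what batch-wise loading with h5py could look like, assuming the file has the `images` and `labels` datasets described in the question. The function name `sample_batch` and its parameters (`weights`, `rng`, and so on) are hypothetical and not taken from the original code. One deliberate simplification: h5py index lists must be in increasing order with no repeats, so this version samples without replacement and sorts the indices. Only the selected rows are read from disk, and only that small batch is converted to a JAX array.

```python
import h5py
import numpy as np
import jax.numpy as jnp


def sample_batch(h5_path, weights, batch_size, label_col, rng):
    """Draw a weighted batch, reading only the chosen rows from disk.

    Hypothetical helper: the name and signature are illustrative only.
    """
    # Normalize the per-example weights into a probability vector.
    p = np.asarray(weights, dtype=np.float64)
    p = p / p.sum()

    # h5py index lists must be strictly increasing (no duplicates), so this
    # sketch samples without replacement and sorts the result.
    idx = np.sort(rng.choice(p.shape[0], size=batch_size, replace=False, p=p))

    with h5py.File(h5_path, "r") as f:
        # Indexing an h5py Dataset reads just these rows, not the whole array.
        images = f["images"][idx]
        labels = f["labels"][idx][:, label_col]

    # Only the small batch is handed to JAX.
    return jnp.asarray(images), jnp.asarray(labels)
```

With the real file this would be called along the lines of `sample_batch("data/3dshapes.h5", weights, 8, label_col, np.random.default_rng(0))`, where `weights` has one entry per example. If you need sampling with replacement, you would have to deduplicate the indices before reading and re-expand the rows afterwards in NumPy.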
I would like to sample from the 3d-shapes dataset with some weight on each individual specimen. Here is the full source code, and you need to download the 255.2 MB dataset from Google Cloud Storage and place the file in `data/3dshapes.h5`.

In the dataset, `data['images']` and `data['labels']` have size `(480000, 64, 64, 3)` and `(480000, 6)` respectively, but I am only interested in one of the labels, hence `Y = jnp.array(data['labels'][:, label_col])`.

I tried to load one batch to see if it's working.

This turns out to be extremely slow, and I got the following error after several minutes. Note that each batch merely has size `(8, 64, 64, 3)`, and I am on a v3-8 TPU VM with 335 GB RAM, so the available memory should be more than enough.