JAX/FLAX scales up worse than Tensorflow #16473
giorgiofranceschelli asked this question in Q&A (unanswered)
Hello,
I started learning JAX/FLAX with a very simple VAE model on MNIST, and when I saw its better performance compared to TensorFlow I decided to move my current project to JAX/FLAX. However, with bigger and more complex architectures, I got worse performance than the original TF implementation. So I went back to the VAE on MNIST and checked whether it scales up correctly, but that does not seem to be the case.
In particular, with this implementation:
and by varying the image dimension, latent size, batch size, or number of convolutional filters, I obtained the following results (see also the sketch after the list below):
with z_dim = 64, bs = 64, image_dim = (28, 28, 1), conv_filters = 32:
with z_dim = 64, bs = 64, image_dim = (64, 64, 1), conv_filters = 32:
with z_dim = 256, bs = 64, image_dim = (28, 28, 1), conv_filters = 32:
with z_dim = 64, bs = 256, image_dim = (28, 28, 1), conv_filters = 32:
with z_dim = 64, bs = 64, image_dim = (28, 28, 1), conv_filters = 128:
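For reference, a minimal sketch of this kind of Flax VAE setup looks like the following. This is a simplified illustration rather than my exact code; the layer layout and variable names are just examples, with z_dim, conv_filters, and image_dim as the knobs varied in the configurations above:

```python
# Illustrative sketch only, not the original implementation.
import jax
import jax.numpy as jnp
import flax.linen as nn

Z_DIM = 64           # latent size (base experiment value)
CONV_FILTERS = 32    # convolutional filters (base experiment value)
IMAGE_DIM = (28, 28, 1)

class Encoder(nn.Module):
    z_dim: int
    filters: int

    @nn.compact
    def __call__(self, x):
        # Two strided convolutions downsample the image by a factor of 4.
        x = nn.relu(nn.Conv(self.filters, (3, 3), strides=(2, 2))(x))
        x = nn.relu(nn.Conv(self.filters * 2, (3, 3), strides=(2, 2))(x))
        x = x.reshape((x.shape[0], -1))
        mean = nn.Dense(self.z_dim)(x)
        logvar = nn.Dense(self.z_dim)(x)
        return mean, logvar

class Decoder(nn.Module):
    filters: int
    image_dim: tuple

    @nn.compact
    def __call__(self, z):
        h, w, c = self.image_dim
        x = nn.relu(nn.Dense((h // 4) * (w // 4) * self.filters * 2)(z))
        x = x.reshape((z.shape[0], h // 4, w // 4, self.filters * 2))
        # Two transposed convolutions upsample back to the original size.
        x = nn.relu(nn.ConvTranspose(self.filters, (3, 3), strides=(2, 2))(x))
        return nn.ConvTranspose(c, (3, 3), strides=(2, 2))(x)  # logits

class VAE(nn.Module):
    z_dim: int = Z_DIM
    filters: int = CONV_FILTERS
    image_dim: tuple = IMAGE_DIM

    @nn.compact
    def __call__(self, x, rng):
        mean, logvar = Encoder(self.z_dim, self.filters)(x)
        # Reparameterization trick: z = mean + std * eps.
        eps = jax.random.normal(rng, mean.shape)
        z = mean + jnp.exp(0.5 * logvar) * eps
        logits = Decoder(self.filters, self.image_dim)(z)
        return logits, mean, logvar

# Initialise parameters on a dummy batch, just to show the expected shapes.
rng = jax.random.PRNGKey(0)
model = VAE()
params = model.init(rng, jnp.ones((64, *IMAGE_DIM)), rng)
```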
I ran everything on GPU with Google Colab.
As you can see, JAX/FLAX is much faster for the base experiment, but falls behind TensorFlow as z_dim, batch_size, image_dim, or conv_filters increase.
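For context, a fair way to time a jitted JAX step on GPU looks roughly like the generic sketch below (not my actual benchmark code, just an illustration with a dummy function): the first call is kept out of the timed loop because it triggers compilation, and block_until_ready() is called before reading the clock because JAX dispatches work asynchronously.

```python
# Generic timing sketch with a dummy jitted function (assumption, not my real benchmark).
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Stand-in for one training step; any jitted computation behaves the same way.
    return jnp.tanh(x @ x.T).sum()

x = jax.random.normal(jax.random.PRNGKey(0), (256, 256))

# The first call compiles the function, so it is excluded from the timed loop.
step(x).block_until_ready()

start = time.perf_counter()
for _ in range(100):
    out = step(x)
# JAX dispatches asynchronously: wait for the device before reading the clock.
out.block_until_ready()
print("mean step time (s):", (time.perf_counter() - start) / 100)
```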
What am I doing wrong? Any help would be appreciated.