Replies: 6 comments 54 replies
-
I notice that you misunderstand the usage of `lax.scan` here:

```python
import jax
import jax.numpy as jnp

num_batches = 100
batch_size = 512
assert jax.lax.scan(lambda n_batches, _: (n_batches + 1, None), 0,
                    jnp.split(jnp.zeros(num_batches * batch_size), num_batches))[0] == batch_size
```

Namely, your scan steps over the leading axis of each split (length `batch_size`), not over the list of `num_batches` batches.
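For contrast, a minimal sketch of the behavior the assert demonstrates versus scanning over stacked batches; the stacked variant is an assumption about the intended usage, not code from the thread:

```python
import jax
import jax.numpy as jnp

num_batches, batch_size = 100, 512
data = jnp.zeros(num_batches * batch_size)

def count(n, _):
    return n + 1, None

# A Python list of arrays is a pytree: scan steps over the leading axis of
# each leaf (length batch_size), so the counter ends at 512.
assert jax.lax.scan(count, 0, jnp.split(data, num_batches))[0] == batch_size

# Stacking into a single (num_batches, batch_size) array makes scan step
# over batches instead, so the counter ends at 100.
assert jax.lax.scan(count, 0, data.reshape(num_batches, batch_size))[0] == num_batches
```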
-
Can you take a profile using …?
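A minimal sketch of capturing a trace with JAX's built-in profiler, assuming that is the kind of profile being asked for; the `/tmp/jax-trace` output path is arbitrary:

```python
import jax
import jax.numpy as jnp

# Write a trace that can be inspected in TensorBoard's profile plugin
# (or Perfetto) to see where the time is actually going.
with jax.profiler.trace("/tmp/jax-trace"):
    x = jnp.ones((4096, 4096))
    y = (x @ x).block_until_ready()  # block so the work lands inside the trace
```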
-
I don't think that augmentation or overheads related to scan are problematic here; the profile suggested that virtually all your time is spent in the device step. But none of us on the JAX team are experts in cuDNN, and it really feels like the issue is something like cuDNN (via XLA) selecting slow kernels.
On Mon, Mar 7, 2022 at 8:18 PM Anselm Levskaya wrote:

@samuela

- Do you have any sense of how much time is spent in your on-device augmentation vs. the actual NN code? In all "production" code using JAX we tend to use tf.data to set up the dataset augmentation pipeline, since tf.data is great at utilizing CPUs. Shoving augmentation computations onto the device is only going to hurt performance, since you're leaving a lot of compute on the table.
- I think a few of us are worried about scan potentially introducing some inefficiencies here. In practice, we never scan a training loop. It may be harmless, but it's a bit odd. Given JAX's async dispatch you don't really gain much by doing this.
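As a rough illustration of the tf.data approach described above, here is a minimal sketch of a CPU-side augmentation pipeline handing NumPy batches to a JAX step; the specific augmentations, split, and batch size are assumptions for the example, not taken from the benchmark code:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

def augment(image, label):
    # CPU-side augmentation: reflection-pad, random crop, horizontal flip.
    image = tf.pad(image, [[4, 4], [4, 4], [0, 0]], mode="REFLECT")
    image = tf.image.random_crop(image, [32, 32, 3])
    image = tf.image.random_flip_left_right(image)
    return image, label

ds = (
    tfds.load("cifar10", split="train", as_supervised=True)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(50_000)
    .batch(512, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

for images, labels in tfds.as_numpy(ds):
    ...  # hand the NumPy batch to a jitted train_step
```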
-
@samuela, I want to make absolutely sure that we're comparing apples-to-apples here. The first thing I'm wondering about is the padding in the convolutions. The PyTorch version uses … Regardless, the discrepancy between the two paddings is a valid difference, right?
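To illustrate the kind of padding discrepancy presumably at issue: in JAX, `padding="SAME"` lets XLA choose the padding, which can be asymmetric for strided convolutions, whereas an explicit `[(1, 1), (1, 1)]` matches PyTorch's `Conv2d(..., padding=1)`. A minimal sketch with made-up shapes:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 3, 32, 32))   # NCHW input
w = jnp.ones((64, 3, 3, 3))    # OIHW 3x3 kernel

# "SAME" lets XLA pick the padding (possibly asymmetric for strided convs);
# explicit [(1, 1), (1, 1)] reproduces PyTorch's Conv2d(padding=1).
y_same = lax.conv_general_dilated(x, w, window_strides=(2, 2), padding="SAME")
y_expl = lax.conv_general_dilated(x, w, window_strides=(2, 2), padding=[(1, 1), (1, 1)])

print(y_same.shape, y_expl.shape)  # both (1, 64, 16, 16), but border values can differ
```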
-
The current #8 record holder on the DAWNBench CIFAR-10 benchmark is a PyTorch ResNet by @davidcpage running on an EC2 p3.2xlarge instance (single V100 GPU). Record holders #1-#7 either use multiple GPUs or run on more esoteric cloud providers, so I'm not worrying about them for now.
My goal is to at least match the performance of this PyTorch implementation in JAX in terms of `seconds/epoch`. However, I've found that even my somewhat carefully written JAX version is a shocking order of magnitude slower than the PyTorch version. On a p3.2xlarge instance, my JAX version (attached) clocks in at about 24.2 s/epoch. The PyTorch version reports completing 24 epochs in 72 s, which comes out to 3 s/epoch.

Some notes:

- The training loop is implemented with `lax.scan` (see the sketch at the end of this post).

Original PyTorch implementation: https://github.com/davidcpage/cifar10-fast
JAX implementation (including my shell.nix for reproducibility): https://gist.github.com/samuela/78a3f0bbac759833a0464048aa499c98
What am I doing wrong here? What do I have to do to get competitive performance out of JAX?
cc @sharadmv
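For reference, a minimal sketch of the scan-over-an-epoch pattern mentioned in the notes; the toy `train_step`, parameter shapes, and batch shapes are made up for illustration and are not the code in the gist:

```python
import jax
import jax.numpy as jnp

def train_step(params, batch):
    # Stand-in for a real SGD step: nudge params toward the batch mean.
    grads = jnp.mean(batch) * jnp.ones_like(params)
    return params - 1e-3 * grads, None

@jax.jit
def train_epoch(params, batches):
    # batches has shape (num_batches, batch_size, ...); scan steps over axis 0,
    # so the whole epoch runs as a single XLA computation.
    params, _ = jax.lax.scan(train_step, params, batches)
    return params

params = jnp.zeros((10,))
batches = jnp.ones((100, 512, 32 * 32 * 3))
params = train_epoch(params, batches)
```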