-
One key piece of information that would help: what hardware and software versions are you using? It's hard to say what's happening here without running it and looking at a profile. A TensorBoard profile would be interesting if you wanted to try grabbing one: https://jax.readthedocs.io/en/latest/profiling.html
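For reference, with a recent JAX a minimal way to capture such a trace might look like the sketch below (`my_conv_fn` and the log directory are placeholders, not anything from this thread):

```python
import jax

# Write a TensorBoard-readable trace to /tmp/jax-trace
# (view it with: tensorboard --logdir /tmp/jax-trace).
with jax.profiler.trace("/tmp/jax-trace"):
    out = my_conv_fn(x)        # placeholder for the computation of interest
    out.block_until_ready()    # ensure the async work is actually traced
```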
-
I am using jax version 0.2.7, jaxlib 0.1.57 and CUDA 11 on a GeForce RTX 2080 Ti. For some reason, I cannot get TensorBoard profiling to work, but attached is the hopefully relevant part of the XLA profile:
[attached: XLA profile excerpt]
-
I need to convolve 3D data with a decomposable kernel `K` of shape (35, 35, 6, 6, 6), with 35 being the channel dimension. I can decompose this kernel into 3 kernels `Ki` of shapes (35, 35, 6, 1, 1), (35, 35, 1, 6, 1) and (35, 35, 1, 1, 6) such that `K1 * K2 * K3 = K` with broadcasting. In principle, this should lead to a 12x reduction in FLOPs [(6*6*6)/(6*3) = 12], but it actually leads to a 3x slowdown. A sketch of what I am doing is below. I would really like to not pay the price for a full convolution when I am able to decompose it. Has anyone gotten something like this to work?
Also, `wrong_fac_conv` and `fac_conv` technically require the same amount of FLOPs, yet one is 8x faster than the other. Would writing my own `custom_call` be the only option to get the FLOP savings out of JAX?
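When comparing variants like these, it is worth timing with `block_until_ready()` so JAX's async dispatch doesn't hide the work. A minimal harness, assuming the `full_conv`/`fac_conv` sketch above:

```python
import timeit

# Run each variant once so timings measure execution, not compilation.
full_conv(x, k1, k2, k3).block_until_ready()
fac_conv(x, k1, k2, k3).block_until_ready()

for name, fn in [("full_conv", full_conv), ("fac_conv", fac_conv)]:
    t = timeit.timeit(lambda fn=fn: fn(x, k1, k2, k3).block_until_ready(),
                      number=20)
    print(f"{name}: {t / 20 * 1e3:.2f} ms per call")
```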