Replies: 1 comment
I've made some progress: after looking at a trace of the inference code for a small model with the TensorBoard memory profiler, I've discovered that XLA decided to transpose all the matrices of the fully connected layers, which constitute the majority of the parameters, thus effectively doubling memory consumption. It might have to do with the structure of the model: my dense params are of shape … What is the right way to control this?
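For anyone trying to reproduce that kind of trace: below is a minimal sketch of capturing a profile that TensorBoard's memory profiler can open. The `f` and `params` here are toy placeholders, not the actual model from this thread.

```python
import jax
import jax.numpy as jnp

# Placeholder model: a small stack of fully connected layers.
# The real `f` and `params` from this thread are assumed, not shown.
@jax.jit
def f(params, x):
    for w in params:
        x = jnp.tanh(x @ w)
    return x

params = [jnp.ones((1024, 1024), dtype=jnp.float32) for _ in range(4)]
x = jnp.ones((32, 1024), dtype=jnp.float32)

# Write a profiler trace that the TensorBoard profiler plugin
# (including its memory viewer) can load from this log directory.
with jax.profiler.trace("/tmp/jax-trace"):
    f(params, x).block_until_ready()

# Inspect with: tensorboard --logdir /tmp/jax-trace
```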
I'm currently facing the following problem:

I have a jitted JAX function, say `f(params, input) -> output`, which performs inference for a large model. The model is so large that it's only possible to store one copy of `params` in GPU memory.

According to my calculations, both the parameters and the intermediate tensors in `f` (assuming XLA discards them as soon as they aren't needed anymore) should fit into the memory of a GPU. However, when I actually run `f` on an input, XLA tries to allocate a huge chunk of memory (suspiciously similar in size to `params`, as if `f` were making a copy of the `params` argument), which it cannot do because there isn't enough memory.

I'm trying to figure out why my computation would require this chunk of memory. I've looked at the generated code and checked some obvious suspects, but none of them explains the allocation. So I need to dig deeper. Is there some way to figure out what XLA allocates this tensor for? I've tried finding something that can answer this question in the generated code (`lower(f).compile().as_text()`), but from what I can tell there is no information about allocations there.

Any suggestions appreciated.
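Not a full answer, but a sketch of two ways to get closer to allocation information than `as_text()` alone, again using a placeholder model. Note the assumptions: `memory_analysis()` on the compiled object is only available in newer JAX releases and may return `None` on some backends, and the device memory profile needs `pprof` to view.

```python
import jax
import jax.numpy as jnp

# Placeholder inference function and parameters (not the real model).
def f(params, x):
    for w in params:
        x = jnp.tanh(x @ w)
    return x

params = [jnp.ones((1024, 1024), dtype=jnp.float32) for _ in range(4)]
x = jnp.ones((32, 1024), dtype=jnp.float32)

jitted = jax.jit(f)
compiled = jitted.lower(params, x).compile()

# Optimized HLO, as already mentioned in the question:
# print(compiled.as_text())

# Aggregate memory statistics for the executable (argument, output
# and temporary buffer sizes), where the backend supports it.
print(compiled.memory_analysis())

# Snapshot of live device buffers after running the computation,
# in pprof format (view with `pprof --web /tmp/memory.prof`).
out = jitted(params, x)
out.block_until_ready()
jax.profiler.save_device_memory_profile("/tmp/memory.prof")
```

Separately, setting XLA's `--xla_dump_to` flag via the `XLA_FLAGS` environment variable dumps the optimized HLO modules to files, which can sometimes make it easier to grep for a buffer whose size matches the suspicious allocation.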