pjit requires too much memory #16659
Replies: 1 comment

- This is replicated in #16679.
Suppose I have a large `state` stored on CPU, and a tree of `shardings` matching the state structure. Let's assume the partition is fairly uniform, so that about the same amount of memory should be allocated per device after sharding.

On my multi-GPU setup, the following code goes out of memory:
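Concretely, it is along these lines (the toy pytree, shapes, and mesh below are only placeholders for illustration; the real `state` and `shardings` come from GPT-J, as described under "Some context" below):

```python
import jax
import numpy as np
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Toy stand-in for the real setup: a pytree of CPU (numpy) arrays and a
# matching pytree of shardings that splits each leaf across all devices.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

state = {
    "w": np.ones((len(devices) * 1024, 4096), dtype=np.float16),  # lives on CPU
    "b": np.ones((len(devices) * 1024,), dtype=np.float16),
}
shardings = {
    "w": NamedSharding(mesh, PartitionSpec("x", None)),
    "b": NamedSharding(mesh, PartitionSpec("x")),
}

# The pjit path: an identity function whose outputs are constrained to the
# target shardings. With the real 36 GB state, this is the call that OOMs.
shard_state = pjit(lambda s: s, out_shardings=shardings)
sharded_state = shard_state(state)
```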
while the following does not:
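With the same toy `state` and `shardings` as above:

```python
# device_put transfers each leaf of the pytree directly to its target
# sharding, one shard per device, and does not run out of memory.
sharded_state = jax.device_put(state, shardings)
```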
Why? The result should be the same. In both cases, I should obtain a state sharded across my GPU devices.
Some context
The example above is a proof of concept. More specifically, I am trying to instantiate on GPU the state of GPT-J, an LLM with 6B parameters. The state mainly consists of the model parameters, plus two more replicas corresponding to the `mu` and `nu` parameters of an `optax.adam` optimizer, so 18B parameters in total. In half precision (float16, i.e. 2 bytes per parameter), this gives me 36 billion bytes, that is, 36 GB of memory.

My Amazon EC2 instance has 8 GPUs, each with 16 GiB of GPU memory. I have created a `sharding` tree that partitions the state fairly uniformly across the 8 devices. Since 36 GB / 8 = 4.5 GB per device, the sharded state should comfortably fit in GPU memory.
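Just to spell the arithmetic out:

```python
params = 6e9         # GPT-J parameter count
replicas = 3         # parameters + Adam's mu + Adam's nu
bytes_per_param = 2  # float16
n_devices = 8

total_gb = params * replicas * bytes_per_param / 1e9  # 36.0 GB
per_device_gb = total_gb / n_devices                  # 4.5 GB per GPU
```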
And yet, when initializing the state with `pjit`, I go out of memory. I then stripped my example down to the bone and discovered, as in the example above, that `device_put` works instead, as it should.

Do you have an explanation for why this is happening? Is there anything I can do to make `pjit` work? Eventually, I really want to use `pjit` rather than `device_put`, because I want my state to be sharded directly on GPU, rather than having to store it on CPU first and then put it on the devices.