Replies: 2 comments 3 replies
-
If by chance you're using HuggingFace, their Flax models actually initialize a copy of the params prior to loading the weights, so you end up with two copies existing at the same time. Also, if you're the Michael Mozer I think you were on my senior thesis committee a while back, so thanks for that!
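(Related to the two-copies issue: for from-scratch code, one way to avoid ever materializing a throwaway init copy is `jax.eval_shape`, which traces the init function abstractly without allocating device memory. A minimal sketch with a hypothetical two-matrix init function — the shapes here are made up for illustration:)

```python
import jax

def init_params(key):
    # Hypothetical init: two dense layers' worth of parameters.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (512, 1024)),
        "w2": jax.random.normal(k2, (1024, 10)),
    }

# jax.eval_shape traces init_params without running it on device:
# the result is a pytree of ShapeDtypeStruct, not real arrays, so
# no 512x1024 buffer is ever allocated.
shapes = jax.eval_shape(init_params, jax.random.PRNGKey(0))
```

You can then use `shapes` to validate a checkpoint's structure before pushing it to the TPU, without the init copy ever touching HBM.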
-
Haha, I am that Mike. I remember your thesis, circa 2016. :-) I'm using my own from-scratch code, not HuggingFace. I was similarly concerned that having two copies of the parameters around might be responsible, but the experiment I reported should control for the number of copies: I have in memory both the model-initialized parameters and the parameters I've read, but then I plug just one set or the other into the weight updates. The re-initialized parameters run fine; the saved parameters bomb out with the error. (I've also tried
-
I have jax code that runs without a problem when I reinitialize the model, but when I replace the initialized model parameters with parameters I've read from a checkpoint, I get a Resource Exhausted error.
I am running in a colab with TPUs. I restart the runtime, read in the saved parameters, and then either use the parameters that come from model.init or the saved parameters. With the saved parameters, I get the error:
XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 1.65G. That was not possible. There are 1.42G free.; (0x0x0_HBM0)
I'd be so appreciative if anyone could help me to understand how I could be screwing up the saved parameters to make this error occur. The particularly confusing thing is that this error does not seem to occur the first few times I save parameters. That is, if I repeatedly save and then read the parameters (each time restarting the runtime before reading the parameters), the first few checkpoints are read correctly and the model resumes. So there is no gross bug in storing or reading parameters.
Some more detail:
(1) I have set
jax_disable_jit=True
to facilitate debugging, but the error occurs with or without jit.
(2) I literally change one line of my code from
params = initialized_params
to
params = params_read_from_file
(3) I've verified that the parameters read from the file have the right structure and contain no NaNs or Infs.
(4) I'm saving the parameters with joblib.dump and restoring with joblib.load.
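(One thing that may be worth checking with the joblib route: whether the leaves being dumped are live device arrays or plain host numpy arrays. A sketch — assuming `params` is an ordinary dict pytree, with toy shapes standing in for the real model — of pulling everything to host before dumping, then validating structure and finiteness and pushing back to the device in one explicit step after loading:)

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import joblib
import numpy as np

params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}  # toy stand-in

# Pull every leaf to host numpy so the checkpoint holds no device state.
host_params = jax.device_get(params)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "params.joblib")
    joblib.dump(host_params, path)
    loaded = joblib.load(path)

# Sanity checks after loading: same tree structure, all leaves finite.
assert jax.tree_util.tree_structure(loaded) == jax.tree_util.tree_structure(params)
assert all(np.isfinite(leaf).all() for leaf in jax.tree_util.tree_leaves(loaded))

# Push the whole pytree back to the device in one explicit step,
# rather than letting the first jitted call transfer leaves implicitly.
loaded = jax.device_put(loaded)
```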
I have the naive theory that somehow the parameters remember where they were stored on TPU in their previous incarnation, and that TPU memory becomes fragmented as a result. It doesn't seem possible, but it is consistent with my evidence.
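(On the fragmentation angle: an allocator can fail a large contiguous request even when total free memory looks sufficient, so it may help to make sure the init-time copy is fully gone before the checkpoint is transferred. A hypothetical sketch of the swap, with stand-in shapes, using `jax.Array.delete()` to free buffers eagerly rather than waiting for garbage collection:)

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-ins: the init-time copy (on device) and the checkpoint (host side).
initialized_params = {"w": jnp.ones((1024, 1024))}
params_read_from_file = {"w": np.zeros((1024, 1024), dtype=np.float32)}

# Explicitly free the init copy's device buffers before transferring the
# checkpoint, so the two large copies never coexist in device memory.
for leaf in jax.tree_util.tree_leaves(initialized_params):
    leaf.delete()
del initialized_params

# Only now move the loaded parameters onto the device.
params = jax.device_put(params_read_from_file)
```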