Cannot resume trainer from checkpoint #42

@YerayL

Description

Hello, I ran the script `train_phoenix_7b.sh` and it saved a checkpoint. When I run the script a second time to resume from that checkpoint, the following error is raised, and I don't know how to solve it:

```
Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.01s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.82s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.77s/it]
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /home/ficina/lyy/LLMZoo/train.py:71 in <module>                            │
│                                                                              │
│   68                                                                         │
│   69                                                                         │
│   70 if __name__ == "__main__":                                              │
│ ❱ 71 │   train()                                                             │
│   72                                                                         │
│                                                                              │
│ /home/ficina/lyy/LLMZoo/train.py:59 in train                               │
│                                                                              │
│   56 │   │   │   model = torch.compile(model)                                │
│   57 │                                                                       │
│   58 │   if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*") │
│ ❱ 59 │   │   trainer.train(resume_from_checkpoint=True)                      │
│   60 │   else:                                                               │
│   61 │   │   trainer.train()                                                 │
│   62                                                                         │
│                                                                              │
│ /home/ficina/anaconda3/envs/LLMZOO/lib/python3.8/site-packages/transformer │
│ s/trainer.py:1685 in train                                                   │
│                                                                              │
│   1682 │   │   │   │   raise ValueError(f"No valid checkpoint found in outpu │
│   1683 │   │                                                                 │
│   1684 │   │   if resume_from_checkpoint is not None and not is_sagemaker_mp │
│ ❱ 1685 │   │   │   self._load_from_checkpoint(resume_from_checkpoint)        │
│   1686 │   │                                                                 │
│   1687 │   │   # If model was re-initialized, put it on the right device and │
│   1688 │   │   if model_reloaded:                                            │
│                                                                              │
│ /home/ficina/anaconda3/envs/LLMZOO/lib/python3.8/site-packages/transformer │
│ s/trainer.py:2211 in _load_from_checkpoint                                   │
│                                                                              │
│   2208 │   │   │   │   self._issue_warnings_after_load(load_result)          │
│   2209 │   │   else:                                                         │
│   2210 │   │   │   # We load the sharded checkpoint                          │
│ ❱ 2211 │   │   │   load_result = load_sharded_checkpoint(                    │
│   2212 │   │   │   │   model, resume_from_checkpoint, strict=is_sagemaker_mp │
│   2213 │   │   │   )                                                         │
│   2214 │   │   │   if not is_sagemaker_mp_enabled():                         │
│                                                                              │
│ /home/ficina/anaconda3/envs/LLMZOO/lib/python3.8/site-packages/transformer │
│ s/modeling_utils.py:423 in load_sharded_checkpoint                           │
│                                                                              │
│    420 │   loader = safe_load_file if load_safe else partial(torch.load, map │
│    421 │                                                                     │
│    422 │   for shard_file in shard_files:                                    │
│ ❱  423 │   │   state_dict = loader(os.path.join(folder, shard_file))         │
│    424 │   │   model.load_state_dict(state_dict, strict=False)               │
│    425 │   │                                                                 │
│    426 │   │   # Make sure memory is freed before we load the next state dic │
│                                                                              │
│ /home/ficina/anaconda3/envs/LLMZOO/lib/python3.8/site-packages/torch/seria │
│ lization.py:809 in load                                                      │
│                                                                              │
│    806 │   │   │   │   │   │   return _load(opened_zipfile, map_location, _w │
│    807 │   │   │   │   │   except RuntimeError as e:                         │
│    808 │   │   │   │   │   │   raise pickle.UnpicklingError(UNSAFE_MESSAGE + │
│ ❱  809 │   │   │   │   return _load(opened_zipfile, map_location, pickle_mod │
│    810 │   │   if weights_only:                                              │
│    811 │   │   │   try:                                                      │
│    812 │   │   │   │   return _legacy_load(opened_file, map_location, _weigh │
│                                                                              │
│ /home/ficina/anaconda3/envs/LLMZOO/lib/python3.8/site-packages/torch/seria │
│ lization.py:1172 in _load                                                    │
│                                                                              │
│   1169 │                                                                     │
│   1170 │   unpickler = UnpicklerWrapper(data_file, **pickle_load_args)       │
│   1171 │   unpickler.persistent_load = persistent_load                       │
│ ❱ 1172 │   result = unpickler.load()                                         │
│   1173 │                                                                     │
│   1174 │   torch._utils._validate_loaded_sparse_tensors()                    │
│   1175                                                                       │
│                                                                              │
│ /home/ficina/anaconda3/envs/LLMZOO/lib/python3.8/site-packages/torch/_util │
│ s.py:169 in _rebuild_tensor_v2                                               │
│                                                                              │
│   166 def _rebuild_tensor_v2(                                                │
│   167 │   storage, storage_offset, size, stride, requires_grad, backward_hoo │
│   168 ):                                                                     │
│ ❱ 169 │   tensor = _rebuild_tensor(storage, storage_offset, size, stride)    │
│   170 │   tensor.requires_grad = requires_grad                               │
│   171 │   if metadata:                                                       │
│   172 │   │   set_tensor_metadata(tensor, metadata)                          │
│                                                                              │
│ /home/ficina/anaconda3/envs/LLMZOO/lib/python3.8/site-packages/torch/_util │
│ s.py:148 in _rebuild_tensor                                                  │
│                                                                              │
│   145 def _rebuild_tensor(storage, storage_offset, size, stride):            │
│   146 │   # first construct a tensor with the correct dtype/device           │
│   147 │   t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_ │
│ ❱ 148 │   return t.set_(storage._untyped_storage, storage_offset, size, stri │
│   149                                                                        │
│   150                                                                        │
│   151 def get_tensor_metadata(tensor):                                       │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Trying to resize storage that is not resizable
[identical traceback repeated by a second worker process, ending in the same RuntimeError]
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 1599332 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 1599333 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1599331) of binary: /home/ficina/anaconda3/envs/LLMZOO/bin/python
Traceback (most recent call last):
  File "/home/ficina/anaconda3/envs/LLMZOO/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/ficina/anaconda3/envs/LLMZOO/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/home/ficina/anaconda3/envs/LLMZOO/lib/python3.8/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/home/ficina/anaconda3/envs/LLMZOO/lib/python3.8/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/home/ficina/anaconda3/envs/LLMZOO/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/ficina/anaconda3/envs/LLMZOO/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-06-02_09:19:19
  host      : GPU-A40-116
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 1599334)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-06-02_09:19:19
  host      : GPU-A40-116
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 1599331)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
```
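A note for anyone debugging the same failure: the traceback shows `load_sharded_checkpoint` dying inside `torch.load` on one of the `pytorch_model-*.bin` shards in the checkpoint folder. A `RuntimeError: Trying to resize storage that is not resizable` at that point often means the shard file on disk is truncated or corrupted (for example, the previous run was killed mid-save or the disk filled up), rather than a problem with the Trainer resume logic itself. Below is a minimal diagnostic sketch, not taken from the issue; the checkpoint path is a hypothetical placeholder, so point it at whatever `checkpoint-*` directory `trainer.train(resume_from_checkpoint=True)` picks up.

```python
# Diagnostic sketch (assumption: the checkpoint directory path is a placeholder,
# not from the original report). Load each shard on CPU, one at a time, to find
# out whether a file is truncated or otherwise unreadable.
import glob
import os

import torch

ckpt_dir = "output/checkpoint-1000"  # hypothetical path, adjust to your run

for shard in sorted(glob.glob(os.path.join(ckpt_dir, "pytorch_model-*.bin"))):
    size_mb = os.path.getsize(shard) / 2**20
    try:
        # map_location="cpu" keeps the probe off the GPUs
        state_dict = torch.load(shard, map_location="cpu")
        print(f"{os.path.basename(shard)}: OK, {len(state_dict)} tensors, {size_mb:.0f} MiB")
        del state_dict
    except Exception as exc:  # a truncated shard typically fails right here
        print(f"{os.path.basename(shard)}: FAILED ({size_mb:.0f} MiB) -> {exc!r}")
```

If one shard fails on its own, or is visibly smaller than its siblings, the checkpoint was most likely written incompletely; re-saving it or resuming from an earlier `checkpoint-*` directory is the usual way out.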
