
Commit 778e31e

Fix checkpoint saving (meta-llama#650)

1 parent 04766dc

File tree: 10 files changed (+123, -127 lines changed)


recipes/quickstart/finetuning/quickstart_peft_finetuning.ipynb

Lines changed: 26 additions & 49 deletions
@@ -65,7 +65,7 @@
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "c7963d43806d432aaa3d00e2055e355c",
+"model_id": "68838a4f42f84545912e95b339a31034",
 "version_major": 2,
 "version_minor": 0
 },
@@ -75,13 +75,6 @@
 },
 "metadata": {},
 "output_type": "display_data"
-},
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
-]
 }
 ],
 "source": [
@@ -101,6 +94,7 @@
 "train_config.context_length = 1024 if torch.cuda.get_device_properties(0).total_memory < 16e9 else 2048 # T4 16GB or A10 24GB\n",
 "train_config.batching_strategy = \"packing\"\n",
 "train_config.output_dir = \"meta-llama-samsum\"\n",
+"train_config.use_peft = True\n",
 "\n",
 "from transformers import BitsAndBytesConfig\n",
 "config = BitsAndBytesConfig(\n",
@@ -205,7 +199,7 @@
 "model_input = tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n",
 "\n",
 "model.eval()\n",
-"with torch.no_grad():\n",
+"with torch.inference_mode():\n",
 " print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))"
 ]
 },
@@ -230,34 +224,20 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/datasets/load.py:1486: FutureWarning: The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum\n",
-"You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
-"Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
-" warnings.warn(\n",
-"Preprocessing dataset: 100%|██████████| 14732/14732 [00:02<00:00, 6124.69it/s]\n"
+"/home/ubuntu/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead\n",
+" from torch.distributed._shard.checkpoint import (\n",
+"Preprocessing dataset: 100%|██████████| 14732/14732 [00:02<00:00, 5872.02it/s]\n"
 ]
 }
 ],
 "source": [
 "from llama_recipes.configs.datasets import samsum_dataset\n",
-"from llama_recipes.data.concatenator import ConcatDataset\n",
-"from llama_recipes.utils.config_utils import get_dataloader_kwargs\n",
-"from llama_recipes.utils.dataset_utils import get_preprocessed_dataset\n",
-"\n",
-"train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')\n",
-"\n",
-"train_dl_kwargs = get_dataloader_kwargs(train_config, train_dataset, tokenizer, \"train\")\n",
+"from llama_recipes.utils.dataset_utils import get_dataloader\n",
 "\n",
-"if train_config.batching_strategy == \"packing\":\n",
-" train_dataset = ConcatDataset(train_dataset, chunk_size=train_config.context_length)\n",
+"samsum_dataset.trust_remote_code = True\n",
 "\n",
-"# Create DataLoaders for the training and validation dataset\n",
-"train_dataloader = torch.utils.data.DataLoader(\n",
-" train_dataset,\n",
-" num_workers=train_config.num_workers_dataloader,\n",
-" pin_memory=True,\n",
-" **train_dl_kwargs,\n",
-")"
+"train_dataloader = get_dataloader(tokenizer, samsum_dataset, train_config)\n",
+"eval_dataloader = get_dataloader(tokenizer, samsum_dataset, train_config, \"val\")"
 ]
 },
 {
@@ -310,17 +290,23 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n",
+"/home/ubuntu/llama-recipes/src/llama_recipes/utils/train_utils.py:92: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
+" scaler = torch.cuda.amp.GradScaler()\n",
+"/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n",
 " warnings.warn(\n",
 "Training Epoch: 1: 0%|\u001b[34m \u001b[0m| 0/319 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
 "To disable this warning, you can either:\n",
 "\t- Avoid using `tokenizers` before the fork if possible\n",
 "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
-"/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
-" warnings.warn(\n",
+"/home/ubuntu/llama-recipes/src/llama_recipes/utils/train_utils.py:151: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
+" with autocast():\n",
+"/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
+" return fn(*args, **kwargs)\n",
 "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
 " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
-"Training Epoch: 1/1, step 1278/1279 completed (loss: 0.27870458364486694): : 320it [2:07:09, 23.84s/it] 3.94s/it] \n"
+"/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
+" with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n",
+"Training Epoch: 1/1, step 1278/1279 completed (loss: 0.28094857931137085): : 320it [2:08:50, 24.16s/it] 4.21s/it] \n"
 ]
 },
 {
@@ -332,7 +318,7 @@
 "Peak active CUDA memory was 15 GB\n",
 "CUDA Malloc retries : 0\n",
 "CPU Total Peak Memory consumed during the train (max): 2 GB\n",
-"Epoch 1: train_perplexity=1.3403, train_epoch_loss=0.2929, epoch time 7630.169942979002s\n"
+"Epoch 1: train_perplexity=1.3404, train_epoch_loss=0.2930, epoch time 7730.981359725998s\n"
 ]
 }
 ],
@@ -354,7 +340,7 @@
 "results = train(\n",
 " model,\n",
 " train_dataloader,\n",
-" None,\n",
+" eval_dataloader,\n",
 " tokenizer,\n",
 " optimizer,\n",
 " scheduler,\n",
@@ -380,16 +366,7 @@
 "cell_type": "code",
 "execution_count": 7,
 "metadata": {},
-"outputs": [
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
-" warnings.warn(\n"
-]
-}
-],
+"outputs": [],
 "source": [
 "model.save_pretrained(train_config.output_dir)"
 ]
@@ -440,13 +417,13 @@
 "A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-)))\n",
 "---\n",
 "Summary:\n",
-"A wants to get a puppy for her son. She will take him to the animal shelter tomorrow. B is not sure if he can go with her, but he's willing to.\n"
+"A wants to get a puppy for his son. A took him to the animal shelter last Monday and he showed A one he really liked. A wants to get him one of those little dogs. A and B agree that raising a dog is a tough issue.\n"
 ]
 }
 ],
 "source": [
 "model.eval()\n",
-"with torch.no_grad():\n",
+"with torch.inference_mode():\n",
 " print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))\n"
 ]
 }
@@ -467,7 +444,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.14"
+"version": "3.11.9"
 },
 "vscode": {
 "interpreter": {

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ tabulate
 evaluate
 rouge_score
 pyyaml==6.0.1
-faiss-gpu
+faiss-gpu; python_version < '3.11'
 unstructured[pdf]
 langchain_openai
 langchain

src/llama_recipes/configs/fsdp.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ class fsdp_config:
     hsdp : bool =False # Require HYBRID_SHARD to be set. This flag can extend the HYBRID_SHARD by allowing sharding a model on customized number of GPUs (Sharding_group) and Replicas over Sharding_group.
     sharding_group_size : int=0 # requires hsdp to be set. This specifies the sharding group size, number of GPUs that you model can fit into to form a replica of a model.
     replica_group_size: int=0 #requires hsdp to be set. This specifies the replica group size, which is world_size/sharding_group_size.
-    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
+    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively FULL_STATE_DICT can be used. SHARDED_STATE_DICT saves one file with sharded weights per rank while FULL_STATE_DICT will collect all weights on rank 0 and save them in a single file.
     fsdp_activation_checkpointing: bool=True
     fsdp_cpu_offload: bool=False
     pure_bf16: bool = False
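The updated comment spells out the trade-off between the two formats. Purely as an illustration (not part of the commit), a run could opt into the single consolidated file like this, using the config class above and the StateDictType enum imported elsewhere in this commit:

# Illustrative override: collect all weights on rank 0 into one file instead of
# writing one sharded file per rank (the default shown above).
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

from llama_recipes.configs.fsdp import fsdp_config

cfg = fsdp_config()
cfg.checkpoint_type = StateDictType.FULL_STATE_DICT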

src/llama_recipes/datasets/__init__.py

Lines changed: 12 additions & 1 deletion
@@ -1,7 +1,18 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from functools import partial
+
 from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
 from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
+from llama_recipes.datasets.custom_dataset import get_custom_dataset
 from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
-from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset
+from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset
+
+DATASET_PREPROC = {
+    "alpaca_dataset": partial(get_alpaca_dataset),
+    "grammar_dataset": get_grammar_dataset,
+    "samsum_dataset": get_samsum_dataset,
+    "custom_dataset": get_custom_dataset,
+    "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
+}
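Moving DATASET_PREPROC into the datasets package lets other modules import the registry directly. A small lookup sketch (not part of the commit; it assumes a tokenizer is already constructed and reuses the samsum config from the notebook diff):

# Resolve a dataset's preprocessing function through the relocated registry.
from llama_recipes.configs.datasets import samsum_dataset
from llama_recipes.datasets import DATASET_PREPROC

preprocess = DATASET_PREPROC["samsum_dataset"]                # -> get_samsum_dataset
train_split = preprocess(samsum_dataset, tokenizer, "train")  # tokenizer assumed in scope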

src/llama_recipes/finetuning.py

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,7 @@
     generate_peft_config,
     generate_dataset_config,
     get_dataloader_kwargs,
+    check_fsdp_config,
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
 
@@ -162,6 +163,8 @@ def main(**kwargs):
 
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
+        check_fsdp_config(fsdp_config)
+
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 

src/llama_recipes/model_checkpointing/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -3,8 +3,9 @@
 
 from llama_recipes.model_checkpointing.checkpoint_handler import (
     load_model_checkpoint,
-    save_model_checkpoint,
+    save_fsdp_model_checkpoint_full,
     save_peft_checkpoint,
+    save_model_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,

src/llama_recipes/model_checkpointing/checkpoint_handler.py

Lines changed: 19 additions & 5 deletions
@@ -123,7 +123,7 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
     print(
         f"Checkpoint Time = {t1-t0:.4f}\n"
     )
-def save_model_checkpoint(
+def save_fsdp_model_checkpoint_full(
     model,
     optimizer,
     rank,
@@ -152,7 +152,7 @@ def save_model_checkpoint(
     )
     save_dir = Path.cwd() / folder_name
     save_dir.mkdir(parents=True, exist_ok=True)
-    save_name = cfg.model_name + "-" + str(epoch) + ".pt"
+    save_name = cfg.model_name.replace("/","--") + "-" + str(epoch) + ".pt"
     save_full_path = str(save_dir) + "/" + save_name
 
     # save model
@@ -271,6 +271,20 @@ def save_peft_checkpoint(model, model_path):
     """save_pretrained peft model"""
 
     options = StateDictOptions(full_state_dict=True, cpu_offload=True)
-
-    state_dict = get_model_state_dict(model, options=options)
-    model.save_pretrained(model_path, state_dict=state_dict)
+
+    if isinstance(model, FSDP):
+        state_dict = get_model_state_dict(model, options=options)
+        model.save_pretrained(model_path, state_dict=state_dict)
+    else:
+        model.save_pretrained(model_path)
+
+
+def save_model_checkpoint(model, output_dir):
+    """save model when not peft and on single device"""
+
+    output_file = Path(output_dir) / "model.pt"
+
+    state_dict = model.state_dict()
+
+    torch.save(state_dict, output_file)
+
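The rename keeps the FULL_STATE_DICT path for FSDP runs under save_fsdp_model_checkpoint_full, while the new save_model_checkpoint covers plain single-device, non-PEFT training. A usage sketch (model and train_config are assumed to exist; restoring the file is ordinary PyTorch and not part of this commit):

import torch

from llama_recipes.model_checkpointing import save_model_checkpoint

# Writes <output_dir>/model.pt for a single-device, non-PEFT model.
save_model_checkpoint(model, train_config.output_dir)

# Restoring later:
state_dict = torch.load(f"{train_config.output_dir}/model.pt", map_location="cpu")
model.load_state_dict(state_dict)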

src/llama_recipes/utils/config_utils.py

Lines changed: 16 additions & 0 deletions
@@ -5,6 +5,7 @@
 from dataclasses import asdict
 
 import torch.distributed as dist
+from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 from torch.utils.data import DistributedSampler
 from peft import (
     LoraConfig,
@@ -106,3 +107,18 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
         raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
 
     return kwargs
+
+
+def check_fsdp_config(fsdp_config):
+    VALID_TYPES = (StateDictType.SHARDED_STATE_DICT, StateDictType.FULL_STATE_DICT)
+    if isinstance(fsdp_config.checkpoint_type, str):
+        str_to_obj = {
+            "StateDictType.SHARDED_STATE_DICT": StateDictType.SHARDED_STATE_DICT,
+            "StateDictType.FULL_STATE_DICT": StateDictType.FULL_STATE_DICT,
+        }
+        if fsdp_config.checkpoint_type in str_to_obj:
+            fsdp_config.checkpoint_type = str_to_obj[fsdp_config.checkpoint_type]
+
+    if not fsdp_config.checkpoint_type in VALID_TYPES:
+        raise ValueError(f"Invalid checkpoint_type {fsdp_config.checkpoint_type}")
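check_fsdp_config normalizes a checkpoint_type that arrives as a string (for example from a command-line override) back into a StateDictType member and rejects anything else. A quick illustration of both paths, built only from the code added above:

from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

from llama_recipes.configs.fsdp import fsdp_config
from llama_recipes.utils.config_utils import check_fsdp_config

cfg = fsdp_config()
cfg.checkpoint_type = "StateDictType.FULL_STATE_DICT"  # string form, e.g. from the command line
check_fsdp_config(cfg)                                  # converted back to the enum member
assert cfg.checkpoint_type is StateDictType.FULL_STATE_DICT

cfg.checkpoint_type = "LOCAL_STATE_DICT"                # unsupported value
try:
    check_fsdp_config(cfg)                              # raises ValueError
except ValueError as err:
    print(err)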

src/llama_recipes/utils/dataset_utils.py

Lines changed: 21 additions & 55 deletions
@@ -1,63 +1,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import importlib
-from functools import partial
-from pathlib import Path
-
 import torch
 
-from llama_recipes.datasets import (
-    get_grammar_dataset,
-    get_alpaca_dataset,
-    get_samsum_dataset,
-    get_llamaguard_toxicchat_dataset,
-)
-
-
-def load_module_from_py_file(py_file: str) -> object:
-    """
-    This method loads a module from a py file which is not in the Python path
-    """
-    module_name = Path(py_file).name
-    loader = importlib.machinery.SourceFileLoader(module_name, py_file)
-    spec = importlib.util.spec_from_loader(module_name, loader)
-    module = importlib.util.module_from_spec(spec)
-
-    loader.exec_module(module)
-
-    return module
-
-
-def get_custom_dataset(dataset_config, tokenizer, split: str):
-    if ":" in dataset_config.file:
-        module_path, func_name = dataset_config.file.split(":")
-    else:
-        module_path, func_name = dataset_config.file, "get_custom_dataset"
-
-    if not module_path.endswith(".py"):
-        raise ValueError(f"Dataset file {module_path} is not a .py file.")
-
-    module_path = Path(module_path)
-    if not module_path.is_file():
-        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
-
-    module = load_module_from_py_file(module_path.as_posix())
-    try:
-        return getattr(module, func_name)(dataset_config, tokenizer, split)
-    except AttributeError as e:
-        print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
-        raise e
-
-
-DATASET_PREPROC = {
-    "alpaca_dataset": partial(get_alpaca_dataset),
-    "grammar_dataset": get_grammar_dataset,
-    "samsum_dataset": get_samsum_dataset,
-    "custom_dataset": get_custom_dataset,
-    "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
-
-}
+from llama_recipes.data.concatenator import ConcatDataset
+from llama_recipes.datasets import DATASET_PREPROC, get_custom_dataset
+from llama_recipes.utils.config_utils import get_dataloader_kwargs
 
 
 def get_preprocessed_dataset(
@@ -78,3 +26,21 @@ def get_split():
         tokenizer,
         get_split(),
     )
+
+
+def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
+    dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)
+    dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)
+
+    if split == "train" and train_config.batching_strategy == "packing":
+        dataset = ConcatDataset(dataset, chunk_size=train_config.context_length)
+
+    # Create data loader
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=train_config.num_workers_dataloader,
+        pin_memory=True,
+        **dl_kwargs,
+    )
+    return dataloader
+
