Skip to content

Commit 26e877f

Browse files
committed
changed readme, unified the context interface and added get_flops_per_sec()
1 parent d9558c1 commit 26e877f

File tree

6 files changed

+165
-130
lines changed

6 files changed

+165
-130
lines changed

docs/multi_gpu.md

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -115,35 +115,47 @@ torchrun --nnodes 1 --nproc_per_node 4 examples/finetuning.py --enable_fsdp --m
115115
It lets us specify the training settings for everything from `model_name` to `dataset_name`, `batch_size` and so on. Below is the list of supported settings:
116116

117117
```python
118-
119-
model_name: str="PATH/to/LLAMA 2/7B"
120-
enable_fsdp: bool= False
121-
run_validation: bool=True
122-
batch_size_training: int=4
123-
gradient_accumulation_steps: int=1
124-
num_epochs: int=3
125-
num_workers_dataloader: int=2
126-
lr: float=2e-4
127-
weight_decay: float=0.0
128-
gamma: float= 0.85
129-
use_fp16: bool=False
130-
mixed_precision: bool=True
131-
val_batch_size: int=4
132-
dataset = "samsum_dataset" # alpaca_dataset, grammar_dataset
133-
peft_method: str = "lora" # None , llama_adapter, prefix
134-
use_peft: bool=False
135-
output_dir: str = "./ft-output"
136-
freeze_layers: bool = False
137-
num_freeze_layers: int = 1
138-
quantization: bool = False
139-
save_model: bool = False
140-
dist_checkpoint_root_folder: str="model_checkpoints"
141-
dist_checkpoint_folder: str="fine-tuned"
142-
save_optimizer: bool=False
143-
flop_counter: bool=False # Enable FLOPS counter to measure model throughput, can not be used with pytorch profiler at the same time.
144-
flop_counter_start: int=3 # The step to start profiling, default is 3, which means after 3 steps of warm-up stage, the profiler will start to count FLOPS.
145-
use_profiler: bool=False # Enable pytorch profiler, can not be used with FLOPS counter at the same time.
146-
profiler_dir: str="PATH/to/save/profiler/results" # will be used if using profiler
118+
model_name: str="PATH/to/Model"
119+
tokenizer_name: str=None
120+
enable_fsdp: bool=False
121+
low_cpu_fsdp: bool=False
122+
run_validation: bool=True
123+
batch_size_training: int=4
124+
batching_strategy: str="packing" #alternative: padding
125+
context_length: int=4096
126+
gradient_accumulation_steps: int=1
127+
gradient_clipping: bool = False
128+
gradient_clipping_threshold: float = 1.0
129+
num_epochs: int=3
130+
max_train_step: int=0
131+
max_eval_step: int=0
132+
num_workers_dataloader: int=1
133+
lr: float=1e-4
134+
weight_decay: float=0.0
135+
gamma: float= 0.85
136+
seed: int=42
137+
use_fp16: bool=False
138+
mixed_precision: bool=True
139+
val_batch_size: int=1
140+
dataset = "samsum_dataset"
141+
peft_method: str = "lora" # None,llama_adapter, prefix
142+
use_peft: bool=False
143+
output_dir: str = "PATH/to/save/PEFT/model"
144+
freeze_layers: bool = False
145+
num_freeze_layers: int = 1
146+
quantization: bool = False
147+
one_gpu: bool = False
148+
save_model: bool = True
149+
dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
150+
dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
151+
save_optimizer: bool=False # will be used if using FSDP
152+
use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
153+
use_wandb: bool = False # Enable wandb for experient tracking
154+
save_metrics: bool = False # saves training metrics to a json file for later plotting
155+
flop_counter: bool = False # Enable flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
156+
flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
157+
use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
158+
profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
147159
```
148160

149161
* [Datasets config file](../src/llama_recipes/configs/datasets.py) provides the available options for datasets.

docs/single_gpu.md

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -71,36 +71,47 @@ python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization
7171
It let us specify the training settings, everything from `model_name` to `dataset_name`, `batch_size` etc. can be set here. Below is the list of supported settings:
7272

7373
```python
74-
75-
model_name: str="PATH/to/LLAMA 2/7B"
76-
enable_fsdp: bool= False
77-
run_validation: bool=True
78-
batch_size_training: int=4
79-
gradient_accumulation_steps: int=1
80-
num_epochs: int=3
81-
num_workers_dataloader: int=2
82-
lr: float=2e-4
83-
weight_decay: float=0.0
84-
gamma: float= 0.85
85-
use_fp16: bool=False
86-
mixed_precision: bool=True
87-
val_batch_size: int=4
88-
dataset = "samsum_dataset" # alpaca_dataset,grammar_dataset
89-
peft_method: str = "lora" # None , llama_adapter, prefix
90-
use_peft: bool=False
91-
output_dir: str = "./ft-output"
92-
freeze_layers: bool = False
93-
num_freeze_layers: int = 1
94-
quantization: bool = False
95-
one_gpu: bool = False
96-
save_model: bool = False
97-
dist_checkpoint_root_folder: str="model_checkpoints"
98-
dist_checkpoint_folder: str="fine-tuned"
99-
save_optimizer: bool=False
100-
flop_counter: bool=False # Enable FLOPS counter to measure model throughput, can not be used with pytorch profiler at the same time.
101-
flop_counter_start: int=3 # The step to start profiling, default is 3, which means after 3 steps of warm-up stage, the profiler will start to count FLOPS.
102-
use_profiler: bool=False # Enable pytorch profiler, can not be used with FLOPS counter at the same time.
103-
profiler_dir: str="PATH/to/save/profiler/results" # will be used if using profiler
74+
model_name: str="PATH/to/Model"
75+
tokenizer_name: str=None
76+
enable_fsdp: bool=False
77+
low_cpu_fsdp: bool=False
78+
run_validation: bool=True
79+
batch_size_training: int=4
80+
batching_strategy: str="packing" #alternative: padding
81+
context_length: int=4096
82+
gradient_accumulation_steps: int=1
83+
gradient_clipping: bool = False
84+
gradient_clipping_threshold: float = 1.0
85+
num_epochs: int=3
86+
max_train_step: int=0
87+
max_eval_step: int=0
88+
num_workers_dataloader: int=1
89+
lr: float=1e-4
90+
weight_decay: float=0.0
91+
gamma: float= 0.85
92+
seed: int=42
93+
use_fp16: bool=False
94+
mixed_precision: bool=True
95+
val_batch_size: int=1
96+
dataset = "samsum_dataset"
97+
peft_method: str = "lora" # None,llama_adapter, prefix
98+
use_peft: bool=False
99+
output_dir: str = "PATH/to/save/PEFT/model"
100+
freeze_layers: bool = False
101+
num_freeze_layers: int = 1
102+
quantization: bool = False
103+
one_gpu: bool = False
104+
save_model: bool = True
105+
dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
106+
dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
107+
save_optimizer: bool=False # will be used if using FSDP
108+
use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
109+
use_wandb: bool = False # Enable wandb for experient tracking
110+
save_metrics: bool = False # saves training metrics to a json file for later plotting
111+
flop_counter: bool = False # Enable flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
112+
flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
113+
use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
114+
profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
104115
```
105116

106117
* [Datasets config file](../src/llama_recipes/configs/datasets.py) provides the available options for datasets.

recipes/finetuning/README.md

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,37 +23,47 @@ If you are new to fine-tuning techniques, check out an overview: [](./LLM_finetu
2323
It lets us specify the training settings for everything from `model_name` to `dataset_name`, `batch_size` and so on. Below is the list of supported settings:
2424

2525
```python
26-
27-
model_name: str="PATH/to/LLAMA 2/7B"
28-
enable_fsdp: bool=False
29-
run_validation: bool=True
30-
batch_size_training: int=4
31-
gradient_accumulation_steps: int=1
32-
max_train_step: int=0
33-
max_eval_step: int=0
34-
num_epochs: int=3
35-
num_workers_dataloader: int=2
36-
lr: float=2e-4
37-
weight_decay: float=0.0
38-
gamma: float=0.85
39-
use_fp16: bool=False
40-
mixed_precision: bool=True
41-
val_batch_size: int=4
42-
dataset = "samsum_dataset" # alpaca_dataset, grammar_dataset
43-
peft_method: str="lora" # None , llama_adapter, prefix
44-
use_peft: bool=False
45-
output_dir: str="./ft-output"
46-
freeze_layers: bool = False
47-
num_freeze_layers: int = 1
48-
quantization: bool = False
49-
save_model: bool = False
50-
dist_checkpoint_root_folder: str="model_checkpoints"
51-
dist_checkpoint_folder: str="fine-tuned"
52-
save_optimizer: bool=False
53-
flop_counter: bool=False # Enable FLOPS counter to measure model throughput, can not be used with pytorch profiler at the same time.
54-
flop_counter_start: int=3 # The step to start profiling, default is 3, which means after 3 steps of warm-up stage, the profiler will start to count FLOPS.
55-
use_profiler: bool=False # Enable pytorch profiler, can not be used with FLOPS counter at the same time.
56-
profiler_dir: str="PATH/to/save/profiler/results" # will be used if using profiler
26+
model_name: str="PATH/to/Model"
27+
tokenizer_name: str=None
28+
enable_fsdp: bool=False
29+
low_cpu_fsdp: bool=False
30+
run_validation: bool=True
31+
batch_size_training: int=4
32+
batching_strategy: str="packing" #alternative: padding
33+
context_length: int=4096
34+
gradient_accumulation_steps: int=1
35+
gradient_clipping: bool = False
36+
gradient_clipping_threshold: float = 1.0
37+
num_epochs: int=3
38+
max_train_step: int=0
39+
max_eval_step: int=0
40+
num_workers_dataloader: int=1
41+
lr: float=1e-4
42+
weight_decay: float=0.0
43+
gamma: float= 0.85
44+
seed: int=42
45+
use_fp16: bool=False
46+
mixed_precision: bool=True
47+
val_batch_size: int=1
48+
dataset = "samsum_dataset"
49+
peft_method: str = "lora" # None,llama_adapter, prefix
50+
use_peft: bool=False
51+
output_dir: str = "PATH/to/save/PEFT/model"
52+
freeze_layers: bool = False
53+
num_freeze_layers: int = 1
54+
quantization: bool = False
55+
one_gpu: bool = False
56+
save_model: bool = True
57+
dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
58+
dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
59+
save_optimizer: bool=False # will be used if using FSDP
60+
use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
61+
use_wandb: bool = False # Enable wandb for experient tracking
62+
save_metrics: bool = False # saves training metrics to a json file for later plotting
63+
flop_counter: bool = False # Enable flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
64+
flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
65+
use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
66+
profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
5767
```
5868

5969
* [Datasets config file](../../src/llama_recipes/configs/datasets.py) provides the available options for datasets.

src/llama_recipes/configs/training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
@dataclass
88
class train_config:
9-
model_name: str="PATH/to/LLAMA/7B"
9+
model_name: str="PATH/to/Model"
1010
tokenizer_name: str=None
1111
enable_fsdp: bool=False
1212
low_cpu_fsdp: bool=False
@@ -29,7 +29,7 @@ class train_config:
2929
mixed_precision: bool=True
3030
val_batch_size: int=1
3131
dataset = "samsum_dataset"
32-
peft_method: str = "lora" # None , llama_adapter, prefix
32+
peft_method: str = "lora" # None,llama_adapter, prefix
3333
use_peft: bool=False
3434
output_dir: str = "PATH/to/save/PEFT/model"
3535
freeze_layers: bool = False

src/llama_recipes/utils/flop_utils.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Any, Dict, List, Optional, Union
2-
2+
import time
33
import torch
44
from torch.utils.flop_counter import FlopCounterMode
55

@@ -15,14 +15,12 @@ class FlopMeasure(FlopCounterMode):
1515
1616
.. code-block:: python
1717
18-
mod = ...
19-
flop_counter = FlopMeasure(mod)
18+
model = ...
19+
flop_counter = FlopMeasure(model,local_rank=0,warmup_step=3)
2020
for batch in enumerate(dataloader):
2121
with flop_counter:
22-
if step == 3:
23-
flop_counter.start_counting()
24-
mod(batch)
25-
flop_counter.stop_counting()
22+
model(batch)
23+
flop_counter.step()
2624
"""
2725

2826
def __init__(
@@ -32,50 +30,58 @@ def __init__(
3230
display: bool = True,
3331
custom_mapping: Dict[Any, Any] = None,
3432
rank=None,
33+
warmup_step: int = 3,
3534
):
3635
super().__init__(mods, depth, display, custom_mapping)
37-
self.ready = False
3836
self.rank = rank
37+
self.warmup_step = warmup_step
38+
self.start_time = 0
39+
self.end_time = 0
3940

41+
def step(self):
42+
# decrease the warmup step by 1 for every step, so that the flop counting will start when warmup_step =0. Stop decreasing when warm_up reaches -1.
43+
if self.warmup_step >= 0:
44+
self.warmup_step -= 1
45+
if self.warmup_step == 0 and self.start_time == 0:
46+
self.start_time = time.time()
47+
elif self.warmup_step == -1 and self.start_time != 0 and self.end_time == 0:
48+
self.end_time = time.time()
4049
def __enter__(self):
41-
self.ready = False
50+
if self.warmup_step == 0:
51+
self.start_time = time.time()
4252
super().__enter__()
4353
return self
44-
54+
def is_done(self):
55+
return self.warmup_step == -1
4556
def get_total_flops(self):
4657
return super().get_total_flops()
47-
58+
def get_flops_per_sec(self):
59+
if self.start_time == 0 or self.end_time == 0:
60+
print("Warning: flop count did not finish correctly")
61+
return 0
62+
return super().get_total_flops()/ (self.end_time - self.start_time)
4863
def get_table(self, depth=2):
4964
return super().get_table(depth)
5065

5166
def __exit__(self, *args):
52-
self.ready = False
5367
if self.get_total_flops() == 0:
5468
print(
5569
"Warning: did not record any flops this time. Skipping the flop report"
5670
)
5771
else:
58-
self.stop_counting()
5972
if self.display:
6073
if self.rank is None or self.rank == 0:
61-
print("self.flop_counts", self.get_total_flops())
74+
print("Total time used in this flop counting step is: {}".format(self.end_time - self.start_time))
75+
print("The total TFlop per second is: {}".format(self.get_flops_per_sec() / 1e12))
76+
print("The tflop_count table is below:")
6277
print(self.get_table(self.depth))
6378
# Disable the display feature so that we don't print the table again
6479
self.display = False
6580
super().__exit__(*args)
6681

67-
def start_counting(self):
68-
self.ready = True
69-
70-
def is_ready(self):
71-
return self.ready
72-
73-
def stop_counting(self):
74-
self.ready = False
75-
7682
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
77-
# return the original output if not ready
78-
if not self.ready:
79-
return func(*args, **kwargs)
80-
# otherwise, count the flops and return the original output
81-
return super().__torch_dispatch__(func, types, args, kwargs)
83+
# when warmup_step is 0, count the flops and return the original output
84+
if self.warmup_step == 0:
85+
return super().__torch_dispatch__(func, types, args, kwargs)
86+
# otherwise, just return the original output
87+
return func(*args, **kwargs)

src/llama_recipes/utils/train_utils.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ def profile(cfg, local_rank=None):
5959
) as torch_profiler:
6060
yield torch_profiler
6161
elif use_flop_counter:
62-
if cfg.max_train_step > 0 and cfg.max_train_step < cfg.flop_counter_start:
63-
raise ValueError(f"flop counter requires at least {cfg.flop_counter_start} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
64-
with FlopMeasure(rank=local_rank) as flop_counter:
62+
if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start:
63+
raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
64+
with FlopMeasure(rank=local_rank,warmup_step=cfg.flop_counter_start) as flop_counter:
6565
yield flop_counter
6666
else:
6767
torch_profiler = contextlib.nullcontext()
@@ -135,9 +135,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
135135
if not train_config.enable_fsdp or local_rank==0:
136136
print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
137137
break
138-
if train_config.flop_counter and total_train_steps == train_config.flop_counter_start:
139-
print("start flop counting at the step: ", total_train_steps)
140-
profile_context.start_counting()
141138
for key in batch.keys():
142139
if train_config.enable_fsdp:
143140
if is_xpu_available():
@@ -183,11 +180,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
183180
optimizer.step()
184181
optimizer.zero_grad()
185182
pbar.update(1)
186-
if train_config.use_profiler:
183+
if train_config.use_profiler or train_config.flop_counter:
187184
profile_context.step()
188-
if train_config.flop_counter and profile_context.is_ready():
189-
TFlops = profile_context.get_total_flops() / 1e12
190-
profile_context.stop_counting()
185+
if train_config.flop_counter and profile_context.is_done():
186+
TFlops = profile_context.get_flops_per_sec() / 1e12
191187
if wandb_run:
192188
if not train_config.enable_fsdp or rank==0:
193189
wandb_run.log({

0 commit comments

Comments
 (0)