
Commit a33d774

Merge pull request #182 from tnixon/a10_v100_config
A10 & v100 config
2 parents 5021d94 + cd67192 commit a33d774

File tree

6 files changed: +171 -51 lines changed


README.md

Lines changed: 25 additions & 18 deletions
@@ -94,40 +94,47 @@ Otherwise, follow the steps above. The 12B param model may not function well in
### Training on Other Instances

A100 instance types are not available in all cloud regions, or can be hard to provision. Training is possible on other GPU instance types,
-for smaller Dolly model sizes, and with small modifications to reduce memory usage.
-These modifications are not optimal, but are simple to make.
+for smaller Dolly model sizes, and with small modifications to reduce memory usage. These modifications are not optimal, but are simple to make.
+
+Select your GPU family type from the `gpu_family` widget, enter the number of GPUs available in the `num_gpus` widget, and then run the rest of the code.
+A number of different options will be set for you to train the model for one of the following GPU types:
+- A100 (default)
+- A10
+- V100
+
+Details of the different configurations are below.
+
+#### A100 GPUs
+
+A100 GPUs are preferred for training all model sizes, and are the only GPUs that can train the 12B param model in a reasonable amount of time.
+As such, this is the default configuration, as set in the `a100_config.json` deepspeed config file.

#### A10 GPUs

Training the 12B param model is not recommended on A10s.

-To train the 6.9B param model on A10 instances (ex: `g5.24xlarge`, 4 x A10 24GB; `Standard_NV72ads_A10_v5`, 2 x A10), make the following changes:
+To train the 6.9B param model on A10 instances (ex: `g5.24xlarge`, 4 x A10 24GB; `Standard_NV72ads_A10_v5`, 2 x A10),
+simply select `a10` from the `gpu_family` widget and enter the number of GPUs available in the `num_gpus` widget, then run the rest of the code.
+This will use the `a10_config.json` deepspeed config file, which makes the following changes:

-- Set `per-device-train-batch-size` and `per-device-eval-batch-size` to 3 in the `train_dolly.py` invocation of `deepspeed`
-- Modify the deepspeed config file `ds_z3_bf16_config.json` to configure optimizer offload. Within the `"zero_optimization"` section, add:
+- `per-device-train-batch-size` and `per-device-eval-batch-size` are set to 3 in the `train_dolly.py` invocation of `deepspeed`
+- Within the `"zero_optimization"` section of the deepspeed config, we have added:
```
  "offload_optimizer": {
    "device": "cpu",
    "pin_memory": true
  },
```
-- Set the `num_gpus` widget in `train_dolly` to the number of GPUs in your instance, such as 2 or 4, before running
-
-To train the 2.8B param model:
-
-- Instead, only set `per-device-train-batch-size` and `per-device-eval-batch-size` to 3 in the `train_dolly.py` invocation of `deepspeed`

#### V100 GPUs

-To run on V100 instances with 32GB of GPU memory (ex: `p3dn.24xlarge` or `Standard_ND40rs_v2`), follow instructions above, and add:
+To run on V100 instances with 32GB of GPU memory (ex: `p3dn.24xlarge` or `Standard_ND40rs_v2`),
+simply select `v100` from the `gpu_family` widget and enter the number of GPUs available in the `num_gpus` widget, and then run the rest of the code.
+This will use the `v100_config.json` deepspeed config file, which makes the following changes:

-- Modify `training/trainer.py` to disable `bf16` and enable `fp16` in `TrainingArguments`:
-```
-...
-fp16=True,
-bf16=False,
-...
-```
+- It makes the changes described above for A10s
+- It enables fp16 floating point format
+- It sets the `per-device-train-batch-size` and `per-device-eval-batch-size` to 3

You may be able to slightly increase the batch size with 32GB instances, compared to what works above for 24GB A10s.

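For reference, the widget choice boils down to picking one of the three config files and a per-device batch size before launching `deepspeed`. Below is a minimal sketch of that mapping, mirroring the logic this commit adds to `train_dolly.py`; the `training_settings` helper and its `repo_root` argument are illustrative and not part of the repo.

```python
import os

def training_settings(gpu_family: str, repo_root: str = "."):
    """Map a GPU family ("a100", "a10", "v100") to its deepspeed config and batch size."""
    # Each family has a matching JSON file under config/.
    deepspeed_config = os.path.join(repo_root, "config", f"{gpu_family}_config.json")
    # Larger per-device batches on GPUs with more memory; 3 is the fallback (used for V100).
    batch_size = {"a100": 6, "a10": 4}.get(gpu_family, 3)
    return deepspeed_config, batch_size

print(training_settings("a10"))   # ('./config/a10_config.json', 4)
print(training_settings("v100"))  # ('./config/v100_config.json', 3)
```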
config/a10_config.json

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
{
  "bf16": {
    "enabled": "auto"
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    }
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}

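The memory savings on A10s come from ZeRO stage 3 combined with offloading optimizer state to pinned CPU memory. As a quick sanity check, the committed file can be inspected with the standard library; a minimal sketch, assuming it is run from the repo root so the relative path resolves:

```python
import json

# Load the A10 deepspeed config committed above.
with open("config/a10_config.json") as f:
    cfg = json.load(f)

zero = cfg["zero_optimization"]
print(zero["stage"])              # 3 -> parameters, gradients, and optimizer state are partitioned
print(zero["offload_optimizer"])  # {'device': 'cpu', 'pin_memory': True} -> optimizer state lives in CPU RAM
print(cfg["bf16"])                # {'enabled': 'auto'} -> precision is left to the trainer arguments
```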
config/v100_config.json

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
{
  "fp16": {
    "enabled": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    }
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}

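Compared with `a10_config.json`, the only change is the precision block: V100s predate bfloat16 support, so fp16 is enabled explicitly instead of leaving bf16 on auto. A small sketch that makes the difference explicit, again assuming paths relative to the repo root:

```python
import json

with open("config/a10_config.json") as f:
    a10 = json.load(f)
with open("config/v100_config.json") as f:
    v100 = json.load(f)

# The only top-level keys that differ are the precision sections.
print(set(a10) - set(v100))  # {'bf16'}
print(set(v100) - set(a10))  # {'fp16'}

# Everything the two configs share (optimizer, scheduler, ZeRO-3 + CPU offload) is identical.
print(all(a10[k] == v100[k] for k in set(a10) & set(v100)))  # True
```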
train_dolly.py

Lines changed: 53 additions & 31 deletions
@@ -50,10 +50,10 @@

# COMMAND ----------

-#!wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/libcusparse-dev-11-7_11.7.3.50-1_amd64.deb -O /tmp/libcusparse-dev-11-7_11.7.3.50-1_amd64.deb && \
-# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/libcublas-dev-11-7_11.10.1.25-1_amd64.deb -O /tmp/libcublas-dev-11-7_11.10.1.25-1_amd64.deb && \
-# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/libcusolver-dev-11-7_11.4.0.1-1_amd64.deb -O /tmp/libcusolver-dev-11-7_11.4.0.1-1_amd64.deb && \
-# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/libcurand-dev-11-7_10.2.10.91-1_amd64.deb -O /tmp/libcurand-dev-11-7_10.2.10.91-1_amd64.deb && \
+#!wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/libcusparse-dev-11-7_11.7.3.50-1_amd64.deb -O /tmp/libcusparse-dev-11-7_11.7.3.50-1_amd64.deb && \
+# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/libcublas-dev-11-7_11.10.1.25-1_amd64.deb -O /tmp/libcublas-dev-11-7_11.10.1.25-1_amd64.deb && \
+# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/libcusolver-dev-11-7_11.4.0.1-1_amd64.deb -O /tmp/libcusolver-dev-11-7_11.4.0.1-1_amd64.deb && \
+# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/libcurand-dev-11-7_10.2.10.91-1_amd64.deb -O /tmp/libcurand-dev-11-7_10.2.10.91-1_amd64.deb && \
# dpkg -i /tmp/libcusparse-dev-11-7_11.7.3.50-1_amd64.deb && \
# dpkg -i /tmp/libcublas-dev-11-7_11.10.1.25-1_amd64.deb && \
# dpkg -i /tmp/libcusolver-dev-11-7_11.4.0.1-1_amd64.deb && \
@@ -91,6 +91,7 @@
dbutils.widgets.text("local_training_root", "", "local_training_root")
dbutils.widgets.text("dbfs_output_root", "", "dbfs_output_root")
dbutils.widgets.text("experiment_id", "", "experiment_id")
+dbutils.widgets.combobox("gpu_family", "a100", ["v100", "a10", "a100"])

# COMMAND ----------

@@ -112,9 +113,6 @@

checkpoint_dir_name = f"{model_name}__{timestamp}"

-root_path = os.getcwd()
-deepspeed_config = os.path.join(root_path, "config/ds_z3_bf16_config.json")
-
dolly_training_dir_name = "dolly_training"

# Use the local training root path if it was provided. Otherwise try to find a sensible default.
@@ -136,19 +134,32 @@

local_output_dir = os.path.join(local_training_root, checkpoint_dir_name)
dbfs_output_dir = os.path.join(dbfs_output_root, checkpoint_dir_name)
+tensorboard_display_dir = f"{local_output_dir}/runs"
+
+print(f"Local Output Dir: {local_output_dir}")
+print(f"DBFS Output Dir: {dbfs_output_dir}")
+print(f"Tensorboard Display Dir: {tensorboard_display_dir}")
+
+# pick an appropriate config file
+gpu_family = dbutils.widgets.get("gpu_family")
+config_file_name = f"{gpu_family}_config.json"
+deepspeed_config = os.path.join(os.getcwd(), "config", config_file_name)
+print(f"Deepspeed config file: {deepspeed_config}")
+
+# configure the batch_size
+batch_size = 3
+if gpu_family == "a10":
+    batch_size = 4
+elif gpu_family == "a100":
+    batch_size = 6

+# configure num_gpus, if specified
num_gpus_flag = ""
num_gpus = dbutils.widgets.get("num_gpus")
if num_gpus:
    num_gpus = int(num_gpus)
    num_gpus_flag = f"--num_gpus={num_gpus}"

-tensorboard_display_dir = f"{local_output_dir}/runs"
-
-print(f"Local Output Dir: {local_output_dir}")
-print(f"DBFS Output Dir: {dbfs_output_dir}")
-print(f"Tensorboard Display Dir: {tensorboard_display_dir}")
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# COMMAND ----------
@@ -158,28 +169,28 @@

# COMMAND ----------

-# MAGIC !deepspeed {num_gpus_flag} \
-# MAGIC --module training.trainer \
-# MAGIC --input-model {input_model} \
-# MAGIC --deepspeed {deepspeed_config} \
-# MAGIC --epochs 2 \
-# MAGIC --local-output-dir {local_output_dir} \
-# MAGIC --dbfs-output-dir {dbfs_output_dir} \
-# MAGIC --per-device-train-batch-size 6 \
-# MAGIC --per-device-eval-batch-size 6 \
-# MAGIC --logging-steps 10 \
-# MAGIC --save-steps 200 \
-# MAGIC --save-total-limit 20 \
-# MAGIC --eval-steps 50 \
-# MAGIC --warmup-steps 50 \
-# MAGIC --test-size 200 \
-# MAGIC --lr 5e-6
+!deepspeed {num_gpus_flag} \
+    --module training.trainer \
+    --input-model {input_model} \
+    --deepspeed {deepspeed_config} \
+    --epochs 2 \
+    --local-output-dir {local_output_dir} \
+    --dbfs-output-dir {dbfs_output_dir} \
+    --per-device-train-batch-size {batch_size} \
+    --per-device-eval-batch-size {batch_size} \
+    --logging-steps 10 \
+    --save-steps 200 \
+    --save-total-limit 20 \
+    --eval-steps 50 \
+    --warmup-steps 50 \
+    --test-size 200 \
+    --lr 5e-6

# COMMAND ----------

from training.generate import generate_response, load_model_tokenizer_for_generate

-model, tokenizer = load_model_tokenizer_for_generate(local_output_dir)
+model, tokenizer = load_model_tokenizer_for_generate(dbfs_output_dir)

# COMMAND ----------

@@ -192,8 +203,19 @@
    "Give me a list of 5 science fiction books I should read next.",
]

+# set some additional pipeline args
+pipeline_kwargs = {'torch_dtype': "auto"}
+if gpu_family == "v100":
+    pipeline_kwargs['torch_dtype'] = "float16"
+elif gpu_family == "a10" or gpu_family == "a100":
+    pipeline_kwargs['torch_dtype'] = "bfloat16"
+
# Use the model to generate responses for each of the instructions above.
for instruction in instructions:
-    response = generate_response(instruction, model=model, tokenizer=tokenizer)
+    response = generate_response(instruction, model=model, tokenizer=tokenizer, **pipeline_kwargs)
    if response:
        print(f"Instruction: {instruction}\n\n{response}\n\n-----------\n")
+
+# COMMAND ----------
+

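The `pipeline_kwargs` added above keep inference precision aligned with training precision: float16 on V100s, bfloat16 on A10s and A100s. Below is a minimal sketch of the same selection expressed as torch dtypes; the `generation_dtype` helper is illustrative, and it assumes, as the kwarg name suggests, that `generate_response` forwards `torch_dtype` to a Hugging Face pipeline.

```python
import torch

def generation_dtype(gpu_family):
    """Pick an inference dtype that matches how the model was trained (illustrative helper)."""
    if gpu_family == "v100":
        return torch.float16   # V100s have no bfloat16 support
    if gpu_family in ("a10", "a100"):
        return torch.bfloat16  # Ampere-class GPUs train and serve in bf16
    return "auto"              # let the library decide for anything else

print(generation_dtype("v100"))  # torch.float16
print(generation_dtype("a100"))  # torch.bfloat16
```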
training/trainer.py

Lines changed: 5 additions & 2 deletions
@@ -232,14 +232,17 @@ def train(
        tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
    )

+    # enable fp16 if not bf16
+    fp16 = not bf16
+
    if not dbfs_output_dir:
        logger.warn("Will NOT save to DBFS")

    training_args = TrainingArguments(
        output_dir=local_output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
-        fp16=False,
+        fp16=fp16,
        bf16=bf16,
        learning_rate=lr,
        num_train_epochs=epochs,
@@ -316,7 +319,7 @@ def train(
    default=True,
    help="Provided by deepspeed to identify which instance this process is when performing multi-GPU training.",
)
-@click.option("--bf16", type=bool, default=True, help="Whether to use bf16 (preferred on A100's).")
+@click.option("--bf16", type=bool, default=None, help="Whether to use bf16 (preferred on A100's).")
def main(**kwargs):
    train(**kwargs)

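The net effect of the trainer change is that exactly one half-precision mode is requested: fp16 is simply the complement of whatever `--bf16` resolves to. A minimal sketch (not the project's code) of the values that end up in `TrainingArguments` for each flag value:

```python
def precision_args(bf16):
    # Mirrors the new logic in training/trainer.py: enable fp16 whenever bf16 is off or unset.
    fp16 = not bf16
    return {"fp16": fp16, "bf16": bf16}

print(precision_args(True))   # --bf16 true (A100/A10): {'fp16': False, 'bf16': True}
print(precision_args(False))  # --bf16 false (V100):    {'fp16': True, 'bf16': False}
print(precision_args(None))   # new default (unset):    {'fp16': True, 'bf16': None}
```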