Commit 3ecb8d9

Update notes for training on A10/V100 (#94)
This is a quick pass at the simplistic tweaks needed to train the smaller Dolly sizes on A10 or V100 instance types.
1 parent 662d181 commit 3ecb8d9

File tree

1 file changed (+11 lines, -12 lines)


README.md

Lines changed: 11 additions & 12 deletions
````diff
@@ -76,20 +76,21 @@ instruct_pipeline = pipeline(model="databricks/dolly-v2-12b", torch_dtype=torch.
 
 ## Getting Started with Training
 
-The following instructions refer to Dolly v1 and still need to be updated for v2 training.
-
 * Add the `dolly` repo to Databricks (under Repos click Add Repo, enter `https://github.com/databrickslabs/dolly.git`, then click Create Repo).
 * Start a `12.2 LTS ML (includes Apache Spark 3.3.2, GPU, Scala 2.12)` single-node cluster with node type having 8 A100 GPUs (e.g. `Standard_ND96asr_v4` or `p4d.24xlarge`). Note that these instance types may not be available in all regions, or may be difficult to provision. In Databricks, note that you must select the GPU runtime first, and unselect "Use Photon", for these instance types to appear (where supported).
 * Open the `train_dolly` notebook in the Repo (which is the `train_dolly.py` file in the Github `dolly` repo), attach to your GPU cluster, and run all cells. When training finishes, the notebook will save the model under `/dbfs/dolly_training`.
 
 ## Training on Other Instances
 
-A100 instance types are not available in all cloud regions, or can be hard to provision. Training is possible on other GPU instance types, with small modifications to reduce memory usage.
-Training will take longer on these instances. These modifications are not necessarily optimal, but are simple to make.
+A100 instance types are not available in all cloud regions, or can be hard to provision. Training is possible on other GPU instance types,
+for smaller Dolly model sizes, and with small modifications to reduce memory usage.
+These modifications are not optimal, but are simple to make.
 
 ### A10 GPUs
 
-To run on A10 instances (ex: `g5.24xlarge`, 4 x A10 24GB; `Standard_NV72ads_A10_v5`, 2 x A10), make the following changes:
+Training the 12B param model is not recommended on A10s.
+
+To train the 6.9B param model on A10 instances (ex: `g5.24xlarge`, 4 x A10 24GB; `Standard_NV72ads_A10_v5`, 2 x A10), make the following changes:
 
 - Modify the deepspeed config file `ds_z3_bf16_config.json` to configure optimizer offload. Within the `"zero_optimization"` section, add:
 ```
@@ -100,23 +101,21 @@ To run on A10 instances (ex: `g5.24xlarge`, 4 x A10 24GB; `Standard_NV72ads_A10_
 ```
 - Set the `num_gpus` widget in `train_dolly` to the number of GPUs in your instance, such as 2 or 4, before running
 
-With 4 A10s, an epoch completes in about 7 hours.
+To train the 2.8B param model:
+
+- Instead, simply set `per-device-train-batch-size` and `per-device-eval-batch-size` to 2 in the `train_dolly.py` invocation of `deepspeed`
 
 ### V100 GPUs
 
-To run on V100 instances with 32GB of GPU memory (ex: `p3dn.24xlarge` or `Standard_ND40rs_v2`), make the following changes:
+To run on V100 instances with 32GB of GPU memory (ex: `p3dn.24xlarge` or `Standard_ND40rs_v2`), follow instructions above, and add:
 
-- Modify the deepspeed config to enable optimizer offload, as above
-- Modify `trainer.py` to disable `bf16` and enable `fp16` in `TrainingArguments`:
+- Modify `training/trainer.py` to disable `bf16` and enable `fp16` in `TrainingArguments`:
 ```
 ...
 fp16=True,
 bf16=False,
 ...
 ```
-- Set the `num_gpus` widget in `train_dolly` to the number of GPUs in your instance, typically 8
-
-With 8 V100s, an epoch completes in about 3.5 hours. Note that the resulting model may be slightly different when trained with `fp16` versus `bf16`.
 
 ## Running Unit Tests Locally
 
````
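The diff view above elides the body of the offload stanza added to `ds_z3_bf16_config.json`. For reference, a typical DeepSpeed ZeRO-3 optimizer-offload block looks like the following; this is a sketch of the kind of setting the commit describes, not necessarily the exact lines it added:

```json
{
  "zero_optimization": {
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    }
  }
}
```

Offloading optimizer state to CPU memory trades training speed for GPU memory headroom, which is what makes the smaller A10/V100 cards workable.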
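The `fp16`/`bf16` swap for V100s matters because the two formats trade precision for range differently: fp16 spends more bits on mantissa but tops out at 65504, while bf16 keeps fp32's 8-bit exponent (max roughly 3.4e38). A minimal sketch of the fp16 side, assuming numpy is available (numpy has no bf16 type, so bf16 is only described in comments):

```python
import numpy as np

# fp16 (IEEE half): 5 exponent bits, 10 mantissa bits -> small dynamic range.
# bf16 keeps fp32's 8 exponent bits (max ~3.4e38), which is why it is the
# default config here and why switching V100s to fp16 can change results.
fp16_max = np.finfo(np.float16).max
print(fp16_max)             # 65504.0

# Values past the fp16 range overflow to inf, a loss-scaling concern
# that bf16 training does not hit.
print(np.float16(70000.0))  # inf
```

This is also why the commit's earlier note warned that a model trained with `fp16` may differ slightly from one trained with `bf16`.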
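For the 2.8B instruction, the two batch-size flags could be threaded into the `deepspeed` launch roughly as follows. This is a hypothetical sketch: only the two flag names come from the diff; the script path, `--num_gpus` value, and overall command shape are assumptions:

```python
import shlex

# Hypothetical: assemble the deepspeed command line that train_dolly.py
# would launch. Only the two batch-size flags are named in the diff.
num_gpus = 2  # assumed; matches the num_gpus widget described above
cmd = [
    "deepspeed",
    f"--num_gpus={num_gpus}",
    "training/trainer.py",
    "--per-device-train-batch-size", "2",
    "--per-device-eval-batch-size", "2",
]
print(shlex.join(cmd))
```

Lowering the per-device batch size reduces peak activation memory per GPU, at the cost of more steps (or smaller effective batch) per epoch.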
