
Commit dcef926

Suggest bfloat16 and add generation notes for A10, V100 (#95)
I propose we explicitly show loading in bf16 rather than fp32 in the generation example and in the generation code, which helps avoid OOM in many usages. I also add notes on getting generation to work on A10 and V100 GPUs with 8-bit weights.
1 parent 3ecb8d9 commit dcef926
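For a rough sense of why the dtype choice matters for the 12B model, here is a back-of-the-envelope, weight-only estimate (my numbers, not part of the commit; the parameter count is approximate and activations, KV cache, and framework overhead are ignored):

```
# Rough, weight-only memory estimates for a ~12B-parameter model
# (illustrative, not measured).
params = 12e9
for dtype, bytes_per_param in [("fp32", 4), ("bf16", 2), ("int8", 1)]:
    print(f"{dtype}: ~{params * bytes_per_param / 1e9:.0f} GB of weights")
# fp32: ~48 GB  -> too large for a single 40 GB A100
# bf16: ~24 GB  -> fits a 40 GB A100, but not a 24 GB A10
# int8: ~12 GB  -> why 8-bit loading is suggested for a single A10
```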

File tree: 2 files changed, 27 additions and 15 deletions


README.md

Lines changed: 25 additions & 14 deletions
@@ -51,12 +51,13 @@ maximize the potential of all individuals and organizations.
 
 If you'd like to simply test the model without training, the model is available on Hugging Face as [databricks/dolly-v2-12b](https://huggingface.co/databricks/dolly-v2-12b).
 
-To use the model with the `transformers` library on a machine with GPUs:
+To use the model with the `transformers` library on a machine with A100 GPUs:
 
 ```
 from transformers import pipeline
+import torch
 
-instruct_pipeline = pipeline(model="databricks/dolly-v2-12b", trust_remote_code=True, device_map="auto")
+instruct_pipeline = pipeline(model="databricks/dolly-v2-12b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
 ```
 
 You can then use the pipeline to answer instructions:
@@ -65,28 +66,38 @@ You can then use the pipeline to answer instructions:
 instruct_pipeline("Explain to me the difference between nuclear fission and fusion.")
 ```
 
-To reduce memory usage you can load the model with `bfloat16`:
+### Generating on Other Instances
 
-```
-import torch
-from transformers import pipeline
+A100 instance types are not available in all cloud regions, or can be hard to provision. Inference is possible on other GPU instance types.
 
-instruct_pipeline = pipeline(model="databricks/dolly-v2-12b", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
-```
+#### A10 GPUs
+
+The 6.9B and 2.8B param models should work as-is.
+
+To generate using the 12B param model on A10s (ex: `g5.4xlarge`, 1 x A10 24GB), it's necessary to load and run generating using 8-bit weights, which impacts the results slightly:
+
+- Also install `bitsandbytes`
+- Add `model_kwargs={'load_in_8bit': True}` to the `pipeline()` command shown above
+
+#### V100 GPUs
+
+When using V100s (ex: `p3.2xlarge`, 1 x V100 16GB, `NC6s_v3`), in all cases, set `torch_dtype=torch.float16` in `pipeline()` instead.
+
+Otherwise, follow the steps above. The 12B param model may not function well in 8-bit on V100s.
 
 ## Getting Started with Training
 
-* Add the `dolly` repo to Databricks (under Repos click Add Repo, enter `https://github.com/databrickslabs/dolly.git`, then click Create Repo).
-* Start a `12.2 LTS ML (includes Apache Spark 3.3.2, GPU, Scala 2.12)` single-node cluster with node type having 8 A100 GPUs (e.g. `Standard_ND96asr_v4` or `p4d.24xlarge`). Note that these instance types may not be available in all regions, or may be difficult to provision. In Databricks, note that you must select the GPU runtime first, and unselect "Use Photon", for these instance types to appear (where supported).
-* Open the `train_dolly` notebook in the Repo (which is the `train_dolly.py` file in the Github `dolly` repo), attach to your GPU cluster, and run all cells. When training finishes, the notebook will save the model under `/dbfs/dolly_training`.
+- Add the `dolly` repo to Databricks (under Repos click Add Repo, enter `https://github.com/databrickslabs/dolly.git`, then click Create Repo).
+- Start a `12.2 LTS ML (includes Apache Spark 3.3.2, GPU, Scala 2.12)` single-node cluster with node type having 8 A100 GPUs (e.g. `Standard_ND96asr_v4` or `p4d.24xlarge`). Note that these instance types may not be available in all regions, or may be difficult to provision. In Databricks, note that you must select the GPU runtime first, and unselect "Use Photon", for these instance types to appear (where supported).
+- Open the `train_dolly` notebook in the Repo (which is the `train_dolly.py` file in the Github `dolly` repo), attach to your GPU cluster, and run all cells. When training finishes, the notebook will save the model under `/dbfs/dolly_training`.
 
-## Training on Other Instances
+### Training on Other Instances
 
 A100 instance types are not available in all cloud regions, or can be hard to provision. Training is possible on other GPU instance types,
 for smaller Dolly model sizes, and with small modifications to reduce memory usage.
 These modifications are not optimal, but are simple to make.
 
-### A10 GPUs
+#### A10 GPUs
 
 Training the 12B param model is not recommended on A10s.
 
@@ -105,7 +116,7 @@ To train the 2.8B param model:
 
 - Instead, simply set `per-device-train-batch-size` and `per-device-eval-batch-size` to 2 in the `train_dolly.py` invocation of `deepspeed`
 
-### V100 GPUs
+#### V100 GPUs
 
 To run on V100 instances with 32GB of GPU memory (ex: `p3dn.24xlarge` or `Standard_ND40rs_v2`), follow instructions above, and add:
 
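The A10 and V100 generation notes in the README changes above are prose-only; a minimal sketch of what the resulting `pipeline()` calls could look like, mirroring the README's description rather than additional tested code (assumes `bitsandbytes` is installed for the 8-bit path):

```
import torch
from transformers import pipeline

# A10 (ex: g5.4xlarge, 1 x A10 24GB): per the README note, load the 12B model
# in 8-bit; install bitsandbytes first.
instruct_pipeline = pipeline(
    model="databricks/dolly-v2-12b",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
    model_kwargs={"load_in_8bit": True},
)

# V100 (ex: p3.2xlarge, 1 x V100 16GB): V100s have no native bfloat16 support,
# so per the README pass torch_dtype=torch.float16 instead of bfloat16 above.
# The README also cautions that the 12B model may not function well in 8-bit
# on V100s, so the smaller Dolly models are the safer choice there.
```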

training/generate.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 import logging
 import re
 from typing import List, Tuple
+import torch
 
 import numpy as np
 from transformers import (
@@ -34,7 +35,7 @@ def load_model_tokenizer_for_generate(
     """
     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, padding_side="left")
     model = AutoModelForCausalLM.from_pretrained(
-        pretrained_model_name_or_path, device_map="auto", trust_remote_code=True
+        pretrained_model_name_or_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
     )
     return model, tokenizer
 
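A quick, hypothetical usage check for the updated loader (assumes the repo's `training` package is importable from the repo root; printing a parameter's dtype is just a way to confirm the bfloat16 load took effect):

```
import torch
from training.generate import load_model_tokenizer_for_generate  # repo-layout assumption

model, tokenizer = load_model_tokenizer_for_generate("databricks/dolly-v2-12b")

# With this change the weights should load in bfloat16 rather than float32.
print(next(model.parameters()).dtype)  # expected: torch.bfloat16
```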
