Suggest bfloat16 and add generation notes for A10, V100 (#95)
I propose we explicitly show loading in bf16 rather than fp32 in the generation example and in the generation code; this helps avoid out-of-memory errors in many setups.

I also add notes on getting generation to work on A10 and V100 GPUs with 8-bit weights.
**README.md** (25 additions, 14 deletions)
If you'd like to simply test the model without training, the model is available on Hugging Face as [databricks/dolly-v2-12b](https://huggingface.co/databricks/dolly-v2-12b).
To use the model with the `transformers` library on a machine with A100 GPUs:
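For example, a minimal sketch of loading in bf16 (the `trust_remote_code=True` and `device_map="auto"` arguments and the example prompt are assumptions, not copied from the PR):

```python
import torch
from transformers import pipeline

# bfloat16 roughly halves memory relative to fp32 and helps avoid OOM.
generate_text = pipeline(model="databricks/dolly-v2-12b",
                         torch_dtype=torch.bfloat16,
                         trust_remote_code=True,  # assumption: the model repo ships a custom pipeline
                         device_map="auto")

print(generate_text("Explain the difference between nuclear fission and fusion."))
```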
#### A10 GPUs

To generate using the 12B param model on A10s (ex: `g5.4xlarge`, 1 x A10 24GB), it's necessary to load and run generation using 8-bit weights, which impacts the results slightly:
- Also install `bitsandbytes`
- Add `model_kwargs={'load_in_8bit': True}` to the `pipeline()` command shown above (see the sketch after this list)
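Putting the two steps together, a sketch of the 8-bit variant (same assumptions as the bf16 example above):

```python
import torch
from transformers import pipeline

# 8-bit weights via bitsandbytes; results may differ slightly from bf16.
generate_text = pipeline(model="databricks/dolly-v2-12b",
                         torch_dtype=torch.bfloat16,
                         trust_remote_code=True,
                         device_map="auto",
                         model_kwargs={'load_in_8bit': True})
```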
#### V100 GPUs
When using V100s (ex: `p3.2xlarge`, 1 x V100 16GB, `NC6s_v3`), in all cases, set `torch_dtype=torch.float16` in `pipeline()` instead.
Otherwise, follow the steps above. The 12B param model may not function well in 8-bit on V100s.
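For example (same assumptions as above):

```python
import torch
from transformers import pipeline

# V100s predate Ampere and do not support bfloat16; use float16 instead.
generate_text = pipeline(model="databricks/dolly-v2-12b",
                         torch_dtype=torch.float16,
                         trust_remote_code=True,
                         device_map="auto")
```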
## Getting Started with Training
- Add the `dolly` repo to Databricks (under Repos click Add Repo, enter `https://github.com/databrickslabs/dolly.git`, then click Create Repo).
- Start a `12.2 LTS ML (includes Apache Spark 3.3.2, GPU, Scala 2.12)` single-node cluster with node type having 8 A100 GPUs (e.g. `Standard_ND96asr_v4` or `p4d.24xlarge`). Note that these instance types may not be available in all regions, or may be difficult to provision. In Databricks, note that you must select the GPU runtime first, and unselect "Use Photon", for these instance types to appear (where supported).
- Open the `train_dolly` notebook in the Repo (which is the `train_dolly.py` file in the Github `dolly` repo), attach to your GPU cluster, and run all cells. When training finishes, the notebook will save the model under `/dbfs/dolly_training`.
### Training on Other Instances
A100 instance types are not available in all cloud regions, or can be hard to provision. Training is possible on other GPU instance types, for smaller Dolly model sizes, and with small modifications to reduce memory usage. These modifications are not optimal, but are simple to make.
#### A10 GPUs
Training the 12B param model is not recommended on A10s.
To train the 2.8B param model:
- Instead, simply set `per-device-train-batch-size` and `per-device-eval-batch-size` to 2 in the `train_dolly.py` invocation of `deepspeed`
#### V100 GPUs
To run on V100 instances with 32GB of GPU memory (ex: `p3dn.24xlarge` or `Standard_ND40rs_v2`), follow instructions above, and add: