Skip to content

Commit ae379b0

Browse files
author
tnixon
committed
set torch_dtype as string for different fp architectures
1 parent bd1a3a9 commit ae379b0

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

train_dolly.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,6 @@
192192

193193
# COMMAND ----------
194194

195-
import torch
196-
197195
# Examples from https://www.databricks.com/blog/2023/03/24/hello-dolly-democratizing-magic-chatgpt-open-models.html
198196
instructions = [
199197
"Write a love letter to Edgar Allan Poe.",
@@ -204,10 +202,11 @@
204202
]
205203

206204
# set some additional pipeline args
205+
pipeline_kwargs = {'torch_dtype': "auto"}
207206
if gpu_family == "v100":
208-
pipeline_kwargs = {'torch_dtype': torch.float16}
209-
else:
210-
pipeline_kwargs = {}
207+
pipeline_kwargs['torch_dtype'] = "float16"
208+
elif gpu_family == "a10" or gpu_family == "a100":
209+
pipeline_kwargs['torch_dtype'] = "bfloat16"
211210

212211
# Use the model to generate responses for each of the instructions above.
213212
for instruction in instructions:

0 commit comments

Comments
 (0)