
Commit 7b5f175

not yet working script
1 parent 2ce8bb4 commit 7b5f175

1 file changed: 153 additions & 0 deletions
# usage:
# deepspeed --num_gpus 1 bloom-inference.py --name bigscience/bloom-350m
#

#import glob
from argparse import ArgumentParser
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from transformers.deepspeed import HfDeepSpeedConfig
from transformers.models.bloom.modeling_bloom import BloomBlock as BloomBlock
import deepspeed
import io
import json
import os
import torch
import torch.distributed as dist
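
# The `deepspeed` launcher passes --local_rank and sets the LOCAL_RANK / WORLD_SIZE env
# vars; the defaults below also allow a plain single-process run.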
parser = ArgumentParser()

parser.add_argument("--name", required=True, type=str)
parser.add_argument("--local_rank", required=False, type=int)
parser.add_argument("--deepspeed", action="store_true")
args = parser.parse_args()

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

def get_checkpoint_files(pretrained_model_name_or_path):
    # XXX: I just hacked this one together to automatically handle the fetching of the model file or
    # shards into cache and returning the cached entries - note that I removed most arguments

    # EntryNotFoundError and get_checkpoint_shard_files were missing from the original import;
    # they are assumed to be exported by transformers.utils in the same transformers version.
    from transformers.utils import (
        WEIGHTS_NAME,
        WEIGHTS_INDEX_NAME,
        EntryNotFoundError,
        cached_path,
        get_checkpoint_shard_files,
        hf_bucket_url,
    )

    cache_dir = None
    is_sharded = False
    filename = WEIGHTS_NAME
    archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=filename)

    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        return [resolved_archive_file]

    except EntryNotFoundError:
        if filename == WEIGHTS_NAME:
            # Maybe the checkpoint is sharded, we try to grab the index name in this case.
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=WEIGHTS_INDEX_NAME,
            )
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
            )
            is_sharded = True

    if is_sharded:
        # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
            pretrained_model_name_or_path,
            resolved_archive_file,
            cache_dir=cache_dir,
        )

    return resolved_archive_file

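# The tokenizer comes from the hub as usual; the model itself is built from its config only,
# so no pretrained weights are materialized here - the idea is that the real checkpoint is
# loaded later by deepspeed.init_inference via checkpoints.json.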
model_name = args.name

tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
model_hidden_size = config.hidden_size
train_batch_size = 1 * world_size
model = AutoModelForCausalLM.from_config(config)

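# ZeRO stage 3 config for inference: parameters are partitioned and offloaded to CPU and
# streamed to the GPU on demand. Bucket sizes are scaled with hidden_size as suggested by
# the HF DeepSpeed integration docs; the train_* entries only satisfy the config schema,
# no training is done.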
ds_config = {
    "fp16": {
        "enabled": model.dtype == torch.float16,
    },
    "bf16": {
        "enabled": model.dtype == torch.bfloat16,
    },
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "cpu",
            "pin_memory": True
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size,
        "stage3_param_persistence_threshold": 0
    },
    "steps_per_print": 2000,
    "train_batch_size": train_batch_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}

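# Keeping the HfDeepSpeedConfig object alive signals the transformers integration that
# ZeRO-3 is in use; note that for zero.Init to wrap model construction it normally has to
# be created *before* the model, which is not the case in this work-in-progress script.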
dschf = HfDeepSpeedConfig(ds_config)

model = model.eval()
ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()
model = ds_engine.module

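# Write a checkpoints.json that points DeepSpeed at the (possibly sharded) HF checkpoint
# files, so that deepspeed.init_inference can load the weights itself.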
checkpoints_json = "checkpoints.json"
with io.open(checkpoints_json, 'w', encoding='utf-8') as f:

    #checkpoint_files = glob.glob(f"args.checkpoint_dir/*bin")
    checkpoint_files = get_checkpoint_files(model_name)

    print("Checkpoint files:", checkpoint_files)

    data = {
        "type": "BLOOM-176B",
        "checkpoints": checkpoint_files,
        "version": 1.0
    }
    json.dump(data, f)

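# Hand the model over to DeepSpeed-Inference: replace_with_kernel_inject=True swaps the
# transformer blocks for fused inference kernels (the commented-out injection_policy is
# the manual alternative for BloomBlock); mp_size=1 means no tensor parallelism.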
model = deepspeed.init_inference(model,
                                 mp_size=1,
                                 dtype=torch.half,
                                 checkpoint=checkpoints_json,
                                 #injection_policy={BloomBlock: ('self_attention.dense', 'mlp.dense_4h_to_h')}
                                 replace_with_kernel_inject=True
                                 )
model = model.module

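# Smoke test: tokenize a prompt, move it to the current GPU and greedily generate 50 tokens.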
text_in = 'DeepSpeed is'

tokens = tokenizer(text_in, return_tensors="pt")

for t in tokens:
    if torch.is_tensor(tokens[t]):
        tokens[t] = tokens[t].to(torch.cuda.current_device())

with torch.no_grad():
    gen_tokens = model.generate(
        **tokens,
        min_length=50,
        max_length=50,
        do_sample=False,
    )

text_out = tokenizer.batch_decode(gen_tokens)[0]

print(f"in={text_in}\nout={text_out}")
