
Commit f9402d0

Add generation server scripts using HF accelerate and DS-inference (#328)

* first step towards making libs
* HF accelerate model
* refactor accelerate
* refactor DS inference
* refactor DS ZeRO
* make inference library
* cli
* server
* request
* remove MaxTokensError
* fix batch size error with DS inference server
* type fix
* add latency
* add latency
* add min_length to default kwargs
* str kwargs
* str kwargs
* fix comma
* add old scripts back
* move scripts
* drop data
* minor changes + add README
* update README
* drop nccl
* fix
* default values
* resolve issues
* handle keyboard interrupt
* remove caching
* use snapshot_download
* make server class
* fix snapshot download

Co-authored-by: Mayank Mishra <[email protected]>
1 parent c1139c7 commit f9402d0

File tree

21 files changed: +1436 additions, 0 deletions

File renamed without changes.
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
## Inference solutions for BLOOM 176B

We support HuggingFace accelerate and DeepSpeed Inference for generation.

Install required packages:

```shell
pip install fastapi uvicorn accelerate "huggingface_hub>=0.9.0"
```

To install [DeepSpeed](https://github.com/microsoft/DeepSpeed):
```shell
git clone https://github.com/microsoft/DeepSpeed
cd DeepSpeed
CFLAGS="-I$CONDA_PREFIX/include/" LDFLAGS="-L$CONDA_PREFIX/lib/" TORCH_CUDA_ARCH_LIST="7.0" DS_BUILD_CPU_ADAM=1 DS_BUILD_AIO=1 DS_BUILD_UTILS=1 pip install -e . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check
```
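After building, you can optionally sanity-check the compiled extensions. This verification step is not part of the original instructions, just a quick way to confirm the install:

```shell
# ds_report summarizes which DeepSpeed ops are installed and compatible with your environment.
ds_report
python -c "import deepspeed; print(deepspeed.__version__)"
```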
To install [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII):
```shell
git clone https://github.com/microsoft/DeepSpeed-MII
cd DeepSpeed-MII
pip install .
```
All the provided scripts have been tested with BLOOM 176B on 8x A100 80GB GPUs; they might not work for other models or a different number of GPUs without modification.

DS inference supports only fp16 for the CLI and server applications, while the benchmark script supports both fp16 and bf16. Full bf16 support will be added once DeepSpeed provides the required CUDA kernels.
DS inference is deployed using the DeepSpeed MII library, which requires checkpoints resharded for 8-way tensor parallelism. The HuggingFace checkpoints can be resharded and cached using the following command:
```shell
deepspeed --num_gpus 8 scripts/bloom-inference-server/cache_ds_checkpoints.py --model_name bigscience/bloom --dtype fp16 --save_mp_checkpoint_path <PATH TO DS CACHED MODEL>
```
Note: Running the above script will consume ~350 GB of disk space and will take some time (~30 minutes), depending on the speed of your GPUs and storage.

Note: Sometimes GPU memory is not freed when the DS inference deployment is shut down. You can free this memory by running:
```python
import mii
mii.terminate("ds_inference_grpc_server")
```
or, alternatively, by running `killall python` in a terminal.
#### BLOOM inference via command-line
The CLI asks for generate_kwargs on every request.
Example: generate_kwargs =
```json
{"min_length": 100, "max_new_tokens": 100, "do_sample": false}
```
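These keyword arguments are standard HuggingFace generation parameters. As a rough illustration only (not part of these scripts), the same settings passed directly to a locally loaded model would look like the sketch below, which uses the small `bigscience/bloom-560m` checkpoint so it runs on a single device:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative only: a small BLOOM checkpoint instead of the 176B model.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("DeepSpeed is a machine learning framework", return_tensors="pt")
# The generate_kwargs from the JSON above are forwarded as keyword arguments.
outputs = model.generate(**inputs, min_length=100, max_new_tokens=100, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```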
1. using HF accelerate
```shell
python scripts/bloom-inference-server/cli.py --model_name bigscience/bloom --dtype bf16 --deployment_framework hf_accelerate --generate_kwargs '{"min_length": 100, "max_new_tokens": 100, "do_sample": false}'
```

2. using DS inference
```shell
python scripts/bloom-inference-server/cli.py --model_name bigscience/bloom --dtype fp16 --deployment_framework ds_inference --save_mp_checkpoint_path <PATH TO DS CACHED MODEL> --generate_kwargs '{"min_length": 100, "max_new_tokens": 100, "do_sample": false}'
```
#### BLOOM server deployment
1. using HF accelerate
```shell
python scripts/bloom-inference-server/server.py --model_name bigscience/bloom --dtype bf16 --deployment_framework hf_accelerate --host <HOST ADDRESS> --port <PORT> --allowed_max_new_tokens 100
```

2. using DS inference
```shell
python scripts/bloom-inference-server/server.py --model_name bigscience/bloom --dtype fp16 --deployment_framework ds_inference --save_mp_checkpoint_path <PATH TO DS CACHED MODEL> --host <HOST ADDRESS> --port <PORT> --allowed_max_new_tokens 100
```
An example [script](examples/server_request.py) for querying the BLOOM server is provided. To run it:
```shell
python scripts/bloom-inference-server/examples/server_request.py --host <HOST ADDRESS> --port <PORT>
```
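If you prefer to query the server directly, the sketch below shows the general shape of such a request using the `requests` library. The endpoint path and payload fields here are assumptions for illustration only; check `server.py` and `examples/server_request.py` for the exact API.

```python
import requests

# Hypothetical endpoint and payload layout -- verify against server.py / examples/server_request.py.
url = "http://<HOST ADDRESS>:<PORT>/generate/"
payload = {
    "text": ["DeepSpeed is a machine learning framework"],
    "min_length": 100,
    "max_new_tokens": 100,
    "do_sample": False,
}

response = requests.post(url, json=payload)
print(response.json())
```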
#### Benchmark system for BLOOM inference
1. using HF accelerate
```shell
python scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype bf16 --deployment_framework hf_accelerate --benchmark_cycles 5
```

2. using DS inference
```shell
deepspeed --num_gpus 8 scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype fp16 --deployment_framework ds_inference --save_mp_checkpoint_path <PATH TO DS CACHED MODEL> --benchmark_cycles 5
```

3. using DS ZeRO
```shell
deepspeed --num_gpus 8 scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype bf16 --deployment_framework ds_zero --benchmark_cycles 5
```
Alternatively, the following shell script benchmarks the model at different batch sizes.
```shell
mkdir -p logs

for bs in {1,2,4,8,16,32,64,128}
do
    python scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype bf16 --deployment_framework hf_accelerate --benchmark_cycles 5 --batch_size $bs 2>&1 | tee logs/hf-$bs.log

    deepspeed --num_gpus 8 scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype fp16 --deployment_framework ds_inference --save_mp_checkpoint_path <PATH TO DS CACHED MODEL> --benchmark_cycles 5 --batch_size $bs 2>&1 | tee logs/ds-$bs.log

    deepspeed --num_gpus 8 scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype bf16 --deployment_framework ds_zero --benchmark_cycles 5 --batch_size $bs 2>&1 | tee logs/ds-zero-$bs.log
done
```

The following loop benchmarks sequence lengths at batch size 1 with DS inference.
```shell
for sq in {1,10,50,100,200,300,400,500,600,700,800,900,1000,1500,2000,2500,3000,3500,4000,4500,5000}
do
    deepspeed --num_gpus 8 scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype fp16 --batch_size 1 --benchmark_cycles 5 --deployment_framework ds_inference --generate_kwargs '{"do_sample": false, "min_length": '$sq', "max_new_tokens": '$sq'}' 2>&1 | tee logs/ds_$sq.log
done
```
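Once these loops finish, the throughput lines printed by `benchmark.py` can be pulled out of the logs for a quick comparison. A minimal sketch, assuming the `logs/` directory above and the "tokens/sec" line format used by `benchmark.py` in this commit:

```shell
# Print the tokens/sec line from every benchmark log, prefixed with its file name.
grep -H "tokens/sec" logs/*.log | sort
```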
Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
import argparse
import gc
import os

import deepspeed
import torch

import utils
from ds_inference import DSInferenceModel
from ds_zero import DSZeROModel
from hf_accelerate import HFAccelerateModel
from utils import (
    BENCHMARK,
    DS_INFERENCE,
    DS_ZERO,
    HF_ACCELERATE,
    GenerateRequest,
    Model,
    get_argument_parser,
    get_dummy_batch,
    parse_generate_kwargs,
    print_rank_n,
    run_and_log_time
)


def benchmark_generation(model: Model,
                         request: GenerateRequest,
                         cycles: int = 5):
    total_new_tokens_generated = 0
    for _ in range(cycles):
        response = model.generate(request)
        total_new_tokens_generated += sum(
            new_tokens for new_tokens in response.num_generated_tokens)
    return total_new_tokens_generated


def get_benchmark_results(benchmark_time: float,
                          initialization_time: float,
                          total_new_tokens_generated: int,
                          batch_size: int,
                          cycles: int) -> str:
    throughput = total_new_tokens_generated / benchmark_time
    latency = benchmark_time / cycles
    return f"""
*** Performance stats:
Throughput (including tokenization) = {throughput:.2f} tokens/sec
Throughput (including tokenization) = {1000 / throughput:.2f} msecs/token
Model loading time = {initialization_time:.2f} secs
Total tokens generated = {total_new_tokens_generated} with batch size = {batch_size}
Latency = {latency:.2f} secs
Model loading time + generation time per batch = {initialization_time + latency:.2f} secs
"""


def benchmark_end_to_end(args: argparse.Namespace,
                         model_class: Model,
                         zero_activated: bool = False) -> None:
    model, initialization_time = run_and_log_time(
        (model_class, {"args": args})
    )

    request = parse_generate_kwargs(
        get_dummy_batch(args.batch_size),
        args.generate_kwargs
    )

    print_rank_n(f"generate_kwargs = {args.generate_kwargs}")
    print_rank_n(f"batch_size = {args.batch_size}")

    # warmup is a must if measuring speed as it's when all the optimizations are performed
    # e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
    response = model.generate(request)

    for i, (o, _) in zip(request.text, zip(response.text, response.num_generated_tokens)):
        print_rank_n(f"{'-' * 60}\nin = {i}\nout = {o}\n")

    if (args.benchmark_cycles > 0):
        print_rank_n(f"*** Running benchmark")

        torch.cuda.empty_cache()
        gc.collect()

        # warm up
        model.generate(request)
        torch.cuda.synchronize()

        # benchmark
        total_new_tokens_generated, benchmark_time = run_and_log_time(
            (
                benchmark_generation,
                {
                    "model": model,
                    "request": request,
                    "cycles": args.benchmark_cycles
                }
            )
        )

        # with ZeRO every GPU is generating batch_size * sequence_length tokens
        if (zero_activated):
            world_size = int(os.getenv('WORLD_SIZE', '1'))
            total_new_tokens_generated *= world_size

        print_rank_n(
            get_benchmark_results(
                benchmark_time,
                initialization_time,
                total_new_tokens_generated,
                args.batch_size,
                args.benchmark_cycles
            )
        )


def get_args() -> argparse.Namespace:
    parser = get_argument_parser()

    group = parser.add_argument_group(title="launch config")
    group.add_argument("--benchmark_cycles", type=int,
                       default=0, help="additionally run benchmark")
    group.add_argument("--local_rank", required=False,
                       type=int, help="used by dist launchers")
    group.add_argument("--batch_size", default=1, type=int, help="batch size")
    group.add_argument("--cpu_offload", action="store_true",
                       help="whether to activate CPU offload for DS ZeRO")

    args = utils.get_args(parser, BENCHMARK)

    launched_with_deepspeed = args.deployment_framework in [
        DS_INFERENCE, DS_ZERO]

    if (not launched_with_deepspeed):
        assert args.local_rank is None, "local_rank must be None if not launched with DeepSpeed"

    if (args.cpu_offload):
        assert args.deployment_framework == DS_ZERO, "cpu_offload only works with DS_ZeRO"

    return args


def main() -> None:
    args = get_args()

    if (args.deployment_framework == HF_ACCELERATE):
        benchmark_end_to_end(args, HFAccelerateModel)
    elif (args.deployment_framework == DS_INFERENCE):
        deepspeed.init_distributed("nccl")
        benchmark_end_to_end(args, DSInferenceModel)
    elif (args.deployment_framework == DS_ZERO):
        deepspeed.init_distributed("nccl")
        benchmark_end_to_end(args, DSZeROModel, zero_activated=True)
    else:
        raise ValueError(
            f"Unknown deployment framework {args.deployment_framework}")


if (__name__ == "__main__"):
    main()
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
import argparse
import json
import sys

import utils
from ds_inference import DSInferenceGRPCServer
from hf_accelerate import HFAccelerateModel
from utils import CLI, DS_INFERENCE, HF_ACCELERATE, get_argument_parser, parse_generate_kwargs, print_rank_n


def get_args() -> argparse.Namespace:
    parser = get_argument_parser()

    group = parser.add_argument_group(title="launch config")
    group.add_argument("--shutdown_command", required=False,
                       type=str, default="__shutdown__", help="This string will exit the script")

    args = utils.get_args(parser, CLI)

    return args


def main() -> None:
    args = get_args()

    if (args.deployment_framework == HF_ACCELERATE):
        model = HFAccelerateModel(args)
    elif (args.deployment_framework == DS_INFERENCE):
        model = DSInferenceGRPCServer(args)
    else:
        raise ValueError(
            f"Unknown deployment framework {args.deployment_framework}")

    generate_kwargs = args.generate_kwargs

    while (True):
        try:
            input_text = input("Input text: ")

            if (input_text == args.shutdown_command):
                model.shutdown()

            if (input("change generate_kwargs? [y/n] ") == "y"):
                while (True):
                    try:
                        generate_kwargs = json.loads(
                            input("Generate kwargs: "))
                        break
                    except KeyboardInterrupt:
                        model.shutdown()
                    except Exception as e:
                        e_type, e_message, _ = sys.exc_info()
                        print("error =", e_type.__name__)
                        print("message =", e_message)
                        continue

            request = parse_generate_kwargs([input_text], generate_kwargs)
            response = model.generate(request)

            print_rank_n("Output text:", response.text[0])
            print_rank_n("Generated tokens:", response.num_generated_tokens[0])
        except KeyboardInterrupt:
            model.shutdown()


if (__name__ == "__main__"):
    main()
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .grpc_server import DSInferenceGRPCServer
from .model import DSInferenceModel
