
Commit 8bcfe36

[example] Update Inference Example (#5725)
* [example] update inference example
1 parent a8d459f commit 8bcfe36

File tree: 3 files changed, +75 -100 lines


colossalai/inference/spec/README.md

Lines changed: 0 additions & 96 deletions
This file was deleted.

examples/inference/llama/README.md

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
## Run Inference

The provided script `llama_generation.py` shows how to configure the engine, initialize it, and run inference with a given model. It uses `AutoModelForCausalLM` as the model class and `NoPaddingLlamaModelInferPolicy` as the policy class, so it is ready to run inference with Llama 3.
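
As a rough sketch of what the script sets up (the import paths and `InferenceConfig` arguments below are assumptions for illustration; refer to `llama_generation.py` for the exact code):
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed import paths for the engine, config, and policy classes; the example
# script may import them from slightly different modules.
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy

model = AutoModelForCausalLM.from_pretrained("PATH_MODEL")
tokenizer = AutoTokenizer.from_pretrained("PATH_MODEL")
tokenizer.pad_token = tokenizer.eos_token

# Engine-level settings; the keyword arguments here are illustrative
inference_config = InferenceConfig(dtype="fp16", tp_size=1)

# Build the engine from the model, tokenizer, config, and the Llama inference policy
engine = InferenceEngine(model, tokenizer, inference_config, model_policy=NoPaddingLlamaModelInferPolicy())
```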

For a basic setting, you could run the example by:
```bash
colossalai run --nproc_per_node 1 llama_generation.py -m PATH_MODEL --max_length 128
```

Run multi-GPU inference (Tensor Parallelism), as in the following example using 2 GPUs:
```bash
colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --max_length 128 --tp_size 2
```

## Run Speculative Decoding

Colossal-Inference supports speculative decoding using the inference engine, with optimized kernels and cache management for the main model.

Both a drafter model (a small model) and a main model (a large model) are used during speculative decoding. The drafter model generates a few tokens sequentially, and the main model then validates those candidate tokens in parallel and accepts the validated ones. Decoding is sped up because the latency of speculating multiple tokens with the drafter model is lower than that of the main model.
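
As a purely conceptual illustration of this draft-then-verify idea (toy code, not the Colossal-Inference implementation; `drafter` and `main_model` are stand-ins for any callable that returns the greedy next token for a token sequence):
```python
def speculative_decode_step(main_model, drafter, tokens, num_spec_tokens=5):
    """Toy draft-then-verify step over plain token lists."""
    # 1) The drafter proposes a few tokens sequentially (cheap, small model).
    draft = list(tokens)
    for _ in range(num_spec_tokens):
        draft.append(drafter(draft))

    # 2) Verify the candidates against the main model. In the real algorithm the
    #    main model checks all candidates with a single parallel forward pass;
    #    here they are checked one by one purely for readability.
    accepted = list(tokens)
    for i in range(len(tokens), len(draft)):
        main_next = main_model(draft[:i])
        if main_next == draft[i]:
            accepted.append(draft[i])       # candidate matches: accept it
        else:
            accepted.append(main_next)      # mismatch: take the main model's token and stop
            break
    return accepted
```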

Moreover, Colossal-Inference also supports GLIDE, a modified drafter model architecture that reuses key and value caches from the main model, which improves the acceptance rate and increases the speed-up ratio. Details can be found in the paper "GLIDE with a CAPE: A Low-Hassle Method to Accelerate Speculative Decoding" on [arXiv](https://arxiv.org/pdf/2402.02082.pdf).

Right now, Colossal-Inference offers a GLIDE model compatible with vicuna-7B (https://huggingface.co/lmsys/vicuna-7b-v1.5). You can find the fine-tuned GLIDE drafter model `cxdu/glide-vicuna7b` on the HuggingFace Hub: https://huggingface.co/cxdu/glide-vicuna7b.

Benchmarked on the GSM8K and MT-Bench datasets with batch size 1 on an H800 GPU, speculative decoding gives a speed-up of around 1.28x, and speculative decoding with the GLIDE model as the drafter gives around 1.5x.

## Usage

For the main model, you might want to use the model card `lmsys/vicuna-7b-v1.5` on the [HuggingFace Hub](https://huggingface.co/lmsys/vicuna-7b-v1.5).
For a regular drafter model, you might want to use the model card `JackFram/llama-68m` on the [HuggingFace Hub](https://huggingface.co/JackFram/llama-68m).
For the GLIDE drafter model, you could use the model card `cxdu/glide-vicuna7b` on the [HuggingFace Hub](https://huggingface.co/cxdu/glide-vicuna7b).

You could run speculative decoding by:
```bash
colossalai run --nproc_per_node 1 llama_generation.py -m PATH_MODEL --drafter_model PATH_DRAFTER_MODEL --max_length 128
```

Run multi-GPU inference (Tensor Parallelism), as in the following example using 2 GPUs:
```bash
colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --drafter_model PATH_DRAFTER_MODEL --max_length 128 --tp_size 2
```

If you want to try the GLIDE model (`glide-vicuna7b`) as the drafter model with vicuna-7B, provide the GLIDE model path or model card as the drafter model and enable the feature with:
```python
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
```
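
Putting it together, the speculative-decoding portion of `llama_generation.py` boils down to the flow below (assuming `engine`, `generation_config`, and the prompt have been set up as in the script):
```python
from transformers import AutoModelForCausalLM

# Load the drafter model, e.g. JackFram/llama-68m or cxdu/glide-vicuna7b
drafter_model = AutoModelForCausalLM.from_pretrained("PATH_DRAFTER_MODEL")

# Turn on speculative decoding; pass use_glide_drafter=True for a GLIDE drafter
engine.enable_spec_dec(drafter_model)

out = engine.generate(prompts=["Introduce some landmarks in the United Kingdom, such as"],
                      generation_config=generation_config)
print(out)

# Turn speculative decoding back off afterwards
engine.disable_spec_dec()
```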

examples/inference/llama/llama_generation.py

Lines changed: 28 additions & 4 deletions
```diff
@@ -27,7 +27,7 @@ def infer(args):
     model = MODEL_CLS.from_pretrained(model_path_or_name)
     tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
     tokenizer.pad_token = tokenizer.eos_token
-    coordinator.print_on_master(f"Model Config:\n{model.config}")
+    # coordinator.print_on_master(f"Model Config:\n{model.config}")

     # ==============================
     # Initialize InferenceEngine
@@ -52,20 +52,39 @@ def infer(args):
         pad_token_id=tokenizer.eos_token_id,
         eos_token_id=tokenizer.eos_token_id,
         max_length=args.max_length,
-        do_sample=True,
+        do_sample=args.do_sample,
+        temperature=args.temperature,
+        top_k=args.top_k,
+        top_p=args.top_p,
     )
     coordinator.print_on_master(f"Generating...")
     out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
-    coordinator.print_on_master(out[0])
+    coordinator.print_on_master(out)
+
+    # ==============================
+    # Optionally, load drafter model and proceed speculative decoding
+    # ==============================
+    drafter_model_path_or_name = args.drafter_model
+    if drafter_model_path_or_name is not None:
+        drafter_model = AutoModelForCausalLM.from_pretrained(drafter_model_path_or_name)
+        # turn on speculative decoding with the drafter model
+        engine.enable_spec_dec(drafter_model)
+        coordinator.print_on_master(f"Generating...")
+        out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
+        coordinator.print_on_master(out)
+
+        engine.disable_spec_dec()


 # colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH
+# colossalai run --nproc_per_node 2 llama_generation.py -m MODEL_PATH --tp_size 2
 if __name__ == "__main__":
     # ==============================
     # Parse Arguments
     # ==============================
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
+    parser.add_argument("--drafter_model", type=str, help="Path to the drafter model or model name")
     parser.add_argument(
         "-p", "--prompt", type=str, default="Introduce some landmarks in the United Kingdom, such as", help="Prompt"
     )
@@ -75,7 +94,12 @@ def infer(args):
     parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
     parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
     parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
-    parser.add_argument("--max_length", type=int, default=32, help="Max length for generation")
+    # Generation configs
+    parser.add_argument("--max_length", type=int, default=64, help="Max length for generation")
+    parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation")
+    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
+    parser.add_argument("--top_k", type=int, default=50, help="Top k for generation")
+    parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation")
     args = parser.parse_args()

     infer(args)
```
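
The new sampling flags feed straight into the generation config that the script passes to `engine.generate`. A minimal sketch of the equivalent construction, assuming the config object is a `transformers.GenerationConfig` (which matches the field names in the hunk above) and that `tokenizer` and `engine` are already set up:
```python
from transformers import GenerationConfig

generation_config = GenerationConfig(
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_length=64,      # --max_length (new default)
    do_sample=True,     # --do_sample
    temperature=1.0,    # --temperature
    top_k=50,           # --top_k
    top_p=1.0,          # --top_p
)
out = engine.generate(prompts=["Introduce some landmarks in the United Kingdom, such as"],
                      generation_config=generation_config)
```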
