
Commit 4e1466c

Implement H2O for long context inference on summarization tasks (meta-llama#411)
2 parents a32e919 + 3511426 commit 4e1466c

File tree

11 files changed: +3542 -0 lines changed


.github/scripts/spellcheck_conf/wordlist.txt

Lines changed: 7 additions & 0 deletions

@@ -1351,6 +1351,13 @@ Weaviate
 MediaGen
 SDXL
 SVD
+KV
+KVs
+XSUM
+contrains
+knowlege
+kv
+prefilling
 DataFrame
 DuckDB
 Groq
Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@

## Run Llama with H2O for long context inference

### Overview:

Heavy-Hitter Oracle (H2O) is an efficient inference framework for LLMs. During the generative inference of transformers, the size of the KV cache grows linearly with the sequence length (prompt length + generation length). For long context generation, the KV cache is usually significantly larger than the model parameters and constrains the inference throughput. H2O identifies the critical KV pairs and evicts the unnecessary ones, maintaining a small cache size and thus improving throughput.
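
To make the eviction policy concrete, here is a minimal, illustrative sketch of the idea (a toy, not the repository's `HHCache` imported from `utils.cache`; the function name and shapes are hypothetical): keep the tokens that have received the largest accumulated attention ("heavy hitters") plus the most recent tokens, and evict the rest once the cache exceeds its budget.

```
import torch

def select_kv_to_keep(attn_scores: torch.Tensor, num_heavy: int, num_recent: int) -> torch.Tensor:
    """Toy H2O-style selection for a single attention head.

    attn_scores: shape (seq_len,), accumulated attention each cached token has received.
    Returns a boolean mask over cached positions: True = keep, False = evict.
    Assumes 1 <= num_recent < seq_len.
    """
    seq_len = attn_scores.shape[0]
    keep = torch.zeros(seq_len, dtype=torch.bool)
    # Always keep the most recent ("local") tokens.
    keep[-num_recent:] = True
    # Among the older tokens, keep those with the largest accumulated scores.
    candidates = attn_scores.clone()
    candidates[-num_recent:] = float("-inf")
    k = max(0, min(num_heavy, seq_len - num_recent))
    keep[torch.topk(candidates, k=k).indices] = True
    return keep

# Example: a 12-token cache with a budget of 3 heavy hitters + 3 recent tokens.
scores = torch.tensor([5., 0.1, 4., 0.2, 0.1, 3., 0.1, 0.2, 0.1, 0.3, 0.2, 0.1])
print(select_kv_to_keep(scores, num_heavy=3, num_recent=3))
```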

Besides, LLMs usually generalize poorly to long sequences during inference. H2O handles this issue by maintaining only the heavy-hitter tokens and the most recent tokens. Combined with the positional rolling strategy (reassigning the position of each KV according to its slot in the KV cache instead of its position in the original sequence), H2O can process sequences much longer than the pretrained context window. Unlike other approaches, such as [Positional Interpolation](https://arxiv.org/abs/2306.15595), H2O is a KV cache policy and does not involve any training for long context processing.
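
A minimal sketch of the positional rolling idea (illustrative only; the function name is hypothetical, and the real mechanism is enabled via `config.enable_position_rolling` and handled by `H2OLlamaForCausalLM` from `utils.llama`): position ids are derived from a token's slot in the bounded cache rather than from the ever-growing original sequence index, so they never exceed the cache size.

```
import torch

def rolled_position_ids(num_cached_kvs: int, num_new_tokens: int) -> torch.Tensor:
    """Position ids for incoming tokens when only `num_cached_kvs` KVs are kept.

    Without rolling, positions track the original sequence and can exceed the
    pretrained context window; with rolling, they track the cache layout.
    """
    return torch.arange(num_cached_kvs, num_cached_kvs + num_new_tokens)

# Example: 4096 tokens have been processed so far, but only 256 KVs are cached.
# The next token is assigned position 256 instead of 4096.
print(rolled_position_ids(256, 1))  # tensor([256])
```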

The current implementation supports Llama-1/2/3, from 7B to 70B. Since H2O only maintains the most important KV pairs, it might miss some important information in the middle of the context for knowledge-intensive tasks.

For more details, please refer to the paper: **https://arxiv.org/pdf/2306.14048**; blog: **https://allenz.work/?p=11**.

**Note: this implementation is tested with transformers == 4.39.0**

### Evaluation on Summarization Tasks

The following example runs inference of Llama-2-7b and Meta-Llama-3-8B on the XSUM summarization task. We use `--enable_h2o_generation` to enable the H2O algorithm, which keeps only the heavy-hitter and local KV pairs. Use `--num_window_length` to set the KV cache size. The number of local and the number of heavy-hitter KV pairs each default to half of `--num_window_length` (optionally, the number of heavy hitters can also be set with `--num_heavy_hitter_tokens`). Also, use `--enable_position_rolling` to enable positional rolling in the KV cache, which assigns positions based on the KV cache slots instead of the original sequence. Enabling positional rolling is important when the sequence length exceeds the pretrained context window, e.g., 8K in Llama-3.

```
python run_summarization.py \
    --input-path data/summarization/xsum.jsonl \
    --output-path summarization_output/xsum_h2o.jsonl \
    --model-name meta-llama/Meta-Llama-3-8B \
    --enable_h2o_generation
```
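
For a concrete cache budget: with the defaults in `run_summarization.py`, leaving `--num_heavy_hitter_tokens` unset assigns half of `--num_window_length` to heavy hitters and the remainder to the most recent tokens. A small sketch of that arithmetic (the helper name is illustrative):

```
def kv_budget(num_window_length: int, num_heavy_hitter_tokens: int = -1):
    """Mirror the default split used in run_summarization.py."""
    if num_heavy_hitter_tokens == -1:  # not given on the command line
        num_heavy_hitter_tokens = num_window_length // 2
    num_local_tokens = num_window_length - num_heavy_hitter_tokens
    return num_heavy_hitter_tokens, num_local_tokens

print(kv_budget(256))  # (128, 128): 128 heavy hitters + 128 recent KVs
```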

##### **Results**

Expected results on XSUM (Rouge-2 score, higher is better) from the above script on Llama-2/3 models. The input sequences are ~2k tokens long. Here we constrain the size of the KV cache, allowing only n KVs to be written/read after the prefilling stage, where n ranges from **64** to **full** (maintaining all KV pairs). With 128 KVs, performance matches the full baseline (~2k KVs), while performance degradation is observed with 64 KVs. Also, maintaining a smaller KV cache reduces the I/O cost of the KVs, so we can achieve better throughput.

| KV Cache Size | 64 | 128 | 256 | 512 | 1024 | Full |
| ------------- | ------ | ------ | ------ | ------ | ------ | ------ |
| Llama-2-7B | 0.0439 | 0.1127 | 0.1148 | 0.1182 | 0.1170 | 0.1164 |
| Llama-2-13B | 0.1180 | 0.1217 | 0.1243 | 0.1291 | 0.1302 | 0.1332 |
| Llama-3-8B | 0.1107 | 0.1189 | 0.1200 | 0.1347 | 0.1290 | 0.1311 |
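
For reference, the reported numbers are Rouge-2 F-measures averaged over the evaluated samples. A minimal sketch of that scoring step (an illustrative helper mirroring what `run_summarization.py` does with the `rouge` package):

```
import numpy as np
from rouge import Rouge

def average_rouge2(generations, references):
    """Average Rouge-2 F-measure over (generation, reference) pairs."""
    rouge = Rouge()
    scores = [rouge.get_scores(gen, ref)[0]["rouge-2"]["f"]
              for gen, ref in zip(generations, references)]
    return float(np.mean(scores))
```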

### One Demo on Streaming to "Infinite" Context Length

The following example demonstrates generation with "infinite" sequence length. We use MT-Bench data and generate the context sample by sample. The KV cache keeps the KV pairs from previous samples while maintaining a fixed size. Results can be found in the [Demo](https://allenz.work/?p=11) (Video 1).

```
# run with full cache
# expected results: 1) normal generation at the early stage; 2) performance collapse and slower generation at the middle stage, because the sequence length exceeds the context window and the I/O cost of the KV cache constrains the throughput; 3) OOM errors and the run stops.
bash src/streaming.sh full

# run with h2o
# expected results: normal generation at all stages.
# adjust the number of heavy-hitter tokens with --num_heavy_hitter_tokens and the size of the KV cache with --num_window_length in src/streaming.sh
bash src/streaming.sh h2o
```

recipes/experimental/long-context/H2O/data/summarization/cnn_dailymail.jsonl

Lines changed: 1000 additions & 0 deletions
Large diffs are not rendered by default.

recipes/experimental/long-context/H2O/data/summarization/xsum.jsonl

Lines changed: 1000 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@

transformers
rouge
xopen
needlehaystack

Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@

import torch
import argparse
import json
import os
import time
import re
import sys

from utils.streaming import load, download_url, load_jsonl, greedy_generate

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from utils.llama import H2OLlamaForCausalLM
from utils.cache import Cache, HHCache, StaticCache


@torch.no_grad()
def streaming_inference_h2o(model, tokenizer, config, prompts, max_gen_len=1000, enable_h2o_generation=False):
    # Stream the prompts one by one, reusing the KV cache across samples.
    past_key_values = None
    for idx, prompt in enumerate(prompts):
        prompt = "USER: " + prompt + "\n\nASSISTANT: "
        print("\n" + prompt, end="")
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(model.device)
        seq_len = input_ids.shape[1]

        past_key_values = greedy_generate(
            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
        )
        if enable_h2o_generation:
            # Convert to an HHCache, evict down to the H2O budget, then convert
            # back to the legacy format for the next sample.
            space_needed = seq_len + max_gen_len
            past_key_values = HHCache.from_legacy_cache(config.num_window_length, config.num_heavy_hitter_tokens, past_key_values)
            past_key_values.evict_for_space(space_needed)
            past_key_values = past_key_values.to_legacy_cache()


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--input-path", type=str, default="")
    parser.add_argument("--model-name", type=str, default="lmsys/vicuna-13b-v1.5")

    parser.add_argument("--enable_h2o_generation", action='store_true')
    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=128)
    parser.add_argument("--num_window_length", type=int, default=256)

    parser.add_argument("--enable_position_rolling", action='store_true')

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    args = parser.parse_args()

    model_name = args.model_name
    data_root = args.input_path

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    if args.enable_h2o_generation:
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
        config.num_window_length = args.num_window_length
        config.enable_position_rolling = args.enable_position_rolling
        model = H2OLlamaForCausalLM.from_pretrained(model_name,
                                                    torch_dtype=torch.float16,
                                                    device_map='auto',
                                                    low_cpu_mem_usage=True,
                                                    config=config)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name,
                                                     torch_dtype=torch.float16,
                                                     device_map='auto',
                                                     low_cpu_mem_usage=True,)

    test_filepath = os.path.join(data_root, "mt_bench.jsonl")
    print(f"Loading data from {test_filepath} ...")

    if not os.path.exists(test_filepath):
        download_url(
            "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl",
            data_root,
        )
        os.rename(os.path.join(data_root, "question.jsonl"), test_filepath)

    list_data = load_jsonl(test_filepath)
    prompts = []
    for sample in list_data:
        prompts += sample["turns"]

    streaming_inference_h2o(model, tokenizer, config, prompts, enable_h2o_generation=args.enable_h2o_generation)

if __name__ == "__main__":
    main()

Lines changed: 147 additions & 0 deletions

@@ -0,0 +1,147 @@

import os
import tqdm
import json
import copy
import math

import torch
import logging
import argparse

import numpy as np
from rouge import Rouge

import dataclasses
from xopen import xopen

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from utils.llama import H2OLlamaForCausalLM

def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument("--input-path", type=str, default="")
    parser.add_argument("--output-path", type=str, default="")

    parser.add_argument("--model-name", type=str, default="")

    parser.add_argument("--enable_h2o_generation", action='store_true')
    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=-1)
    parser.add_argument("--num_window_length", type=int, default=256)

    parser.add_argument("--enable_position_rolling", action='store_true')

    parser.add_argument("--sample_num", type=int, default=500)
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    args = parser.parse_args()

    set_seed(args)

    model_name = args.model_name
    input_path = args.input_path
    output_path = args.output_path
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    if args.num_heavy_hitter_tokens == -1:
        # Default: reserve half of the KV cache window for heavy hitters.
        print('number of heavy hitter tokens not assigned, using half of the cache size: {}'.format(args.num_window_length // 2))
        args.num_heavy_hitter_tokens = args.num_window_length // 2

    if args.enable_h2o_generation:
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
        config.num_window_length = args.num_window_length
        config.enable_position_rolling = args.enable_position_rolling
        model = H2OLlamaForCausalLM.from_pretrained(model_name,
                                                    torch_dtype=torch.float16,
                                                    device_map='auto',
                                                    low_cpu_mem_usage=True,
                                                    config=config)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name,
                                                     torch_dtype=torch.float16,
                                                     device_map='auto',
                                                     low_cpu_mem_usage=True,)

    # loading inference data
    requests = []
    with open(input_path, 'r') as f:
        for line in f:
            if line.strip() != '':
                requests.append(json.loads(line))

    if args.sample_num < len(requests):
        print('Sample {} Examples from {} samples'.format(args.sample_num, len(requests)))
    requests = requests[:args.sample_num]

    results = []
    rouge = Rouge()
    rouge1_score_list = []
    rouge2_score_list = []
    rougel_score_list = []

    # Generate a summary for each article and score it against the reference with ROUGE.
    with torch.no_grad():
        for request in tqdm.tqdm(requests):
            result = {'request': request, 'result': {}}
            prompt = request['article']
            label = request['summary_gt']
            temperature = request['temperature']
            stop = request['stop']

            input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids.to(model.device)

            output_sequences = model.generate(
                input_ids=input_ids,
                max_length=request['max_tokens'] + len(input_ids[0]),
                temperature=temperature,
                top_p=request['top_p'],
                do_sample=True,
                num_return_sequences=request['n'],
                return_dict_in_generate=True, output_scores=True,
                pad_token_id=tokenizer.eos_token_id
            )

            tokens = tokenizer.convert_ids_to_tokens(output_sequences['sequences'].squeeze(0))[len(input_ids[0]):]
            logprobs = [logits.log_softmax(dim=-1).max().item() for logits in output_sequences['scores']]
            top_logprobs = [{i: v for i, v in zip(tokens, logprobs)}]

            generate_text = tokenizer.decode(output_sequences['sequences'].squeeze(0)[len(input_ids[0]):])
            # Truncate the generation at the first stop string.
            generate_text = generate_text[: generate_text.find(stop[0])]

            scores = rouge.get_scores(generate_text, label)[0]
            rouge1_score_list.append(scores['rouge-1']['f'])
            rouge2_score_list.append(scores['rouge-2']['f'])
            rougel_score_list.append(scores['rouge-l']['f'])

            result['result'] = {
                "choices": [
                    {
                        "text": generate_text,
                        "logprobs": {
                            "tokens": tokens,
                            "token_logprobs": logprobs,
                            "top_logprobs": top_logprobs,
                            "text_offset": []
                        },
                        "finish_reason": "length"
                    }
                ],
                "request_time": {
                    "batch_time": 0,
                    "batch_size": 1}
            }

            results.append(result)

    print('Average Rouge1: {:.6f}, Rouge-2: {:.6f}, Rouge-l: {:.6f}'.format(np.mean(rouge1_score_list), np.mean(rouge2_score_list), np.mean(rougel_score_list)))
    with open(output_path, 'w') as f:
        for result in results:
            f.write(json.dumps(result) + '\n')

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@

method=$1
if [[ ${method} == 'h2o' ]]; then
    python -u run_streaming.py \
        --input-path data \
        --model-name lmsys/vicuna-13b-v1.5 \
        --enable_h2o_generation \
        --num_heavy_hitter_tokens 2048 \
        --num_window_length 4096 \
        --enable_position_rolling
elif [[ ${method} == 'full' ]]; then
    python -u run_streaming.py \
        --input-path data \
        --model-name lmsys/vicuna-13b-v1.5
else
    echo 'unknown argument for method'
fi

0 commit comments
