
Commit 92389d0 (parent 7084c1b): modify args

File tree

3 files changed: +11 −4 lines

recipes/experimental/long-context/H2O/README.md
Lines changed: 2 additions & 2 deletions

````diff
@@ -14,13 +14,13 @@ More details please refer to Paper: **https://arxiv.org/pdf/2306.14048**; Blog:
 
 ### Evaluation on Summarization Tasks
 
-The following example runs inference of Llama-2-7b on XSUM summarization tasks. We're using `--enable_h2o_generation` to enable the H2O algorithm, which keeps only the heavy-hitter and the local KV pairs. Use `--num_heavy_hitter_tokens` to set the number of heavy-hitter KV pairs and `--num_window_length` for the KV cache size; the number of local KV pairs equals `num_window_length - num_heavy_hitter_tokens`. Also, use `--enable_position_rolling` to assign positions within the KV cache instead of the positions in the original sequence. Enabling positional rolling is important when the sequence length exceeds the pretrained context window, e.g., 4K in Llama-2.
+The following example runs inference of Llama-2-7b and Meta-Llama-3-8B on XSUM summarization tasks. We're using `--enable_h2o_generation` to enable the H2O algorithm, which keeps only the heavy-hitter and the local KV pairs. Use `--num_window_length` to set the KV cache size; by default, heavy-hitter and local KV pairs each take half of `--num_window_length` (optionally, the number of heavy hitters can be set explicitly with `--num_heavy_hitter_tokens`). Also, use `--enable_position_rolling` to assign positions within the KV cache instead of the positions in the original sequence. Enabling positional rolling is important when the sequence length exceeds the pretrained context window, e.g., 8K in Llama-3.
 
 ```
 python run_summarization.py \
 --input-path data/summarization/xsum.jsonl \
 --output-path summarization_output/xsum_h2o.jsonl \
---model-name meta-llama/Llama-2-7b-hf \
+--model-name meta-llama/Meta-Llama-3-8B \
 --enable_h2o_generation
 ```
````
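The README change above states that heavy-hitter and local KV pairs each take half of `--num_window_length` unless overridden. A minimal sketch of that budget split, with a hypothetical helper name (not part of the recipe), assuming the `-1` sentinel introduced by this commit:

```python
def split_kv_budget(num_window_length, num_heavy_hitter_tokens=-1):
    """Split a KV-cache budget into heavy-hitter and local slots.

    With the sentinel -1 (the new default), heavy hitters take half of
    the window, mirroring the behavior described in the README.
    """
    if num_heavy_hitter_tokens == -1:
        num_heavy_hitter_tokens = num_window_length // 2
    # Local (most recent) KV pairs fill the remainder of the window.
    num_local_tokens = num_window_length - num_heavy_hitter_tokens
    return num_heavy_hitter_tokens, num_local_tokens
```

With the default window of 256 this yields 128 heavy-hitter and 128 local slots; an explicit `--num_heavy_hitter_tokens` shifts the balance while keeping the total cache size fixed.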

recipes/experimental/long-context/H2O/run_needle_haystack_test.py
Lines changed: 5 additions & 1 deletion

```diff
@@ -30,7 +30,7 @@ def set_seed(args):
 parser.add_argument("--model-name", type=str, default="")
 
 parser.add_argument("--enable_h2o_generation", action='store_true')
-parser.add_argument("--num_heavy_hitter_tokens", type=int, default=128)
+parser.add_argument("--num_heavy_hitter_tokens", type=int, default=-1)
 parser.add_argument("--num_window_length", type=int, default=256)
 parser.add_argument("--num_chunk_size", type=int, default=2048)

@@ -53,6 +53,10 @@ def set_seed(args):
 config = AutoConfig.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
 
+if args.num_heavy_hitter_tokens == -1:
+    print('Number of heavy hitter tokens not assigned; using half of the cache size: {}'.format(args.num_window_length // 2))
+    args.num_heavy_hitter_tokens = args.num_window_length // 2
+
 if args.enable_h2o_generation:
     config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
     config.num_window_length = args.num_window_length
```
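The diff above swaps the fixed default of 128 for a `-1` sentinel that is resolved after parsing. A self-contained sketch of that argparse pattern, standalone rather than taken from the script (here parsed with an empty argv so the defaults apply):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_heavy_hitter_tokens", type=int, default=-1)
parser.add_argument("--num_window_length", type=int, default=256)

# Empty argv: both defaults are used, leaving the sentinel in place.
args = parser.parse_args([])

# Resolve the sentinel: fall back to half of the cache size.
if args.num_heavy_hitter_tokens == -1:
    args.num_heavy_hitter_tokens = args.num_window_length // 2
```

The advantage of a sentinel over a hard-coded default of 128 is that the fallback tracks whatever `--num_window_length` the user passes, instead of silently desynchronizing from it.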

recipes/experimental/long-context/H2O/run_summarization.py
Lines changed: 4 additions & 1 deletion

```diff
@@ -32,7 +32,7 @@ def set_seed(args):
 parser.add_argument("--model-name", type=str, default="")
 
 parser.add_argument("--enable_h2o_generation", action='store_true')
-parser.add_argument("--num_heavy_hitter_tokens", type=int, default=128)
+parser.add_argument("--num_heavy_hitter_tokens", type=int, default=-1)
 parser.add_argument("--num_window_length", type=int, default=256)
 
 parser.add_argument("--enable_position_rolling", action='store_true')

@@ -51,6 +51,9 @@ def set_seed(args):
 
 config = AutoConfig.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+if args.num_heavy_hitter_tokens == -1:
+    print('Number of heavy hitter tokens not assigned; using half of the cache size: {}'.format(args.num_window_length // 2))
+    args.num_heavy_hitter_tokens = args.num_window_length // 2
 
 if args.enable_h2o_generation:
     config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
```
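Both scripts then copy the resolved values onto the model config only when `--enable_h2o_generation` is set. A minimal mock of that plumbing, using a plain namespace as a hypothetical stand-in for the Hugging Face `AutoConfig` object:

```python
from types import SimpleNamespace

def apply_h2o_config(config, args):
    """Copy H2O cache settings onto the model config when enabled."""
    if args.enable_h2o_generation:
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
        config.num_window_length = args.num_window_length
    return config

args = SimpleNamespace(enable_h2o_generation=True,
                       num_heavy_hitter_tokens=128,
                       num_window_length=256)
config = apply_h2o_config(SimpleNamespace(), args)
```

Gating the assignment on the flag keeps the config untouched for the default (full-cache) generation path.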
