Commit 0bd1b1e

[sharktank] Revert "[sharktank] Enable sample token generation for LLM" (#1277)

Reverts "[sharktank] Enable sample token generation for LLM" (#1144) because it regresses numerics for the fp8 Quark parity tests.
1 parent cb25984 · commit 0bd1b1e

9 files changed: +192 -189 lines


docs/model_cookbook.md

Lines changed: 9 additions & 9 deletions
@@ -121,7 +121,7 @@ python ~/llama.cpp/convert_hf_to_gguf.py --outtype f32 --outfile /tmp/mistral-7b
 python -m sharktank.examples.paged_llm_v1 \
   --gguf-file=/tmp/mistral-7b-v0.1-f32.gguf \
   --tokenizer-config-json=/tmp/mistral-7b/tokenizer_config.json \
-  --prompt "Prompt"
+  "Prompt"

 # Export as MLIR
 python -m sharktank.examples.export_paged_llm_v1 \

@@ -149,7 +149,7 @@ For example, to run the
 [SlyEcho/open_llama_3b_v2_gguf](https://huggingface.co/SlyEcho/open_llama_3b_v2_gguf):

 ```bash
-python -m sharktank.examples.paged_llm_v1 --hf-dataset=open_llama_3b_v2_q8_0_gguf --prompt "Prompt 1"
+python -m sharktank.examples.paged_llm_v1 --hf-dataset=open_llama_3b_v2_q8_0_gguf "Prompt 1"

 open-llama-3b-v2-q8_0.gguf: 100%|█████████████████████████████| 3.64G/3.64G [01:35<00:00, 38.3MB/s]
 tokenizer.model: 100%|███████████████████████████████████████████| 512k/512k [00:00<00:00, 128MB/s]

@@ -259,13 +259,13 @@ iree-run-module \

 [Instructions](../sharktank/sharktank/evaluate/README.md) to run perplexity test

-## Generate sample input tokens for IREE inference/tracy:
+## Generating data for llama models

 ```bash
-python -m sharktank.examples.paged_llm_v1 \
-  --hf-dataset=open_llama_3b_v2_f16_gguf \
-  --prompt-seq-len=128 \
-  --bs=4 \
-  --dump-decode-steps=1 \
-  --dump-path='/tmp'
+set TURBINE_DEBUG=log_level=info
+python -m sharktank.models.llama.tools.generate_data \
+  --tokenizer=openlm-research/open_llama_3b_v2 \
+  --config=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.json \
+  --output-dir=/tmp/open_llama_3b_v2/inputs \
+  --prompt="What is the meaning of life?"
 ```
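
A portability note on the restored snippet: `set TURBINE_DEBUG=log_level=info` is Windows cmd syntax. In the POSIX bash shell that the fence advertises, the equivalent would be:

```bash
# Equivalent of the restored `set ...` line for a POSIX shell (assumption:
# TURBINE_DEBUG is read from the environment, as the cmd form implies).
export TURBINE_DEBUG=log_level=info
```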

sharktank/README.md

Lines changed: 2 additions & 12 deletions
@@ -28,7 +28,8 @@ Note: Use `--device='cuda:0'` to run this inference on an AMD GPU.
 ```shell
 python -m sharktank.examples.paged_llm_v1 \
   --hf-dataset=open_llama_3b_v2_f16_gguf \
-  --prompt "Prompt 1" "Prompt 2" ...
+  "Prompt 1" \
+  "Prompt 2" ...
 ```

 ### Export an IREE compilable batched LLM for serving:

@@ -40,17 +41,6 @@ python -m sharktank.examples.export_paged_llm_v1 \
   --output-config=/tmp/open_llama_3b_v2_f16.json
 ```

-### Generate sample input tokens for IREE inference/tracy:
-
-```shell
-python -m sharktank.examples.paged_llm_v1 \
-  --hf-dataset=open_llama_3b_v2_f16_gguf \
-  --prompt-seq-len=128 \
-  --bs=4 \
-  --dump-decode-steps=1 \
-  --dump-path='/tmp'
-```
-
 ### Dump parsed information about a model from a gguf file:

 ```shell
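
Taken together with the `--dump-bins` switch restored in `sharktank/sharktank/utils/cli.py` below, the restored positional-prompt interface can be exercised in one invocation. A sketch reusing the dataset name from the hunk above (the flag/prompt combination itself is illustrative, not taken verbatim from the docs):

```shell
python -m sharktank.examples.paged_llm_v1 \
  --hf-dataset=open_llama_3b_v2_f16_gguf \
  --dump-bins \
  "Prompt 1" \
  "Prompt 2"
```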

sharktank/sharktank/examples/paged_llm_v1.py

Lines changed: 5 additions & 24 deletions
@@ -18,16 +18,7 @@


 def main():
-    """
-    Run LLM inference in torch/eager mode. Use --device='cuda:0' to run on AMD GPU
-    Args:
-        --prompt: list[str] - Custom space separated prompts
-        --prompt-seq-len: int - Generate random token ids for given seq len and bs and save prefill & first decode step input args as npy files
-        --dump-path: str - Path to save prefill and decode input args as npy files
-        --dump-decode-steps: int - Number of decode steps to dump decode args (defaults to 1 decode step)
-        --bs: int - batch size, for custom prompts, bs is number of given prompts (defaults to 4)
-        --save_intermediates_path: str - save module forward outputs to safetensors, ex: run_0 will save to run_0_prefill.savetensors"
-    """
+    from ..utils import cli

     parser = cli.create_parser()
     cli.add_input_dataset_options(parser)

@@ -38,13 +29,6 @@ def main():
     cli.add_save_tensor_options(parser)

     args = cli.parse(parser)
-
-    prompt_seq_len = args.prompt_seq_len
-
-    assert (
-        args.prompt or prompt_seq_len
-    ), "Pass --prompt for custom prompts or --prompt-seq-len and --bs to generate random token ids"
-
     device = torch.device(args.device) if args.device else None
     dataset = cli.get_input_dataset(args)
     tokenizer = cli.get_tokenizer(args)

@@ -74,15 +58,11 @@ def main():

     generator = TorchGenerator(model, tokenizer)

-    token_ids, seq_lens = generator.preprocess_prompts(
-        prompts=args.prompt, prompt_seq_len=prompt_seq_len, bs=args.bs
-    )
+    token_ids, seq_lens = generator.preprocess_prompts(prompts=args.prompt)
     batch = generator.begin_batch(
         token_ids=token_ids,
         seq_lens=seq_lens,
-        prompt_seq_len=prompt_seq_len,
-        dump_path=args.dump_path,
-        dump_decode_steps=args.dump_decode_steps,
+        dump_bins=args.dump_bins,
     )
     results = batch.prefill()
     batch.print_current_results()

@@ -101,7 +81,8 @@ def main():
             intermediates_saver.save_file(
                 args.save_intermediates_path + f"_step_{counter}.safetensors"
             )
-
+        print(f":: Result tokens: {batch.results}")
+        batch.print_current_results()
         counter += 1

         if len(batch.parent.free_pages) == 0:
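
After the revert, `preprocess_prompts` receives only the prompt list; the random-token path driven by `--prompt-seq-len`/`--bs` is gone. `TorchGenerator` itself is not part of this diff, so the following is only a hypothetical sketch of the `(token_ids, seq_lens)` contract the call satisfies; the helper name, padding scheme, and `pad_id` default are assumptions:

```python
# Hypothetical sketch of the (token_ids, seq_lens) shape contract behind
# preprocess_prompts(prompts=...); not sharktank's actual implementation.
from transformers import LlamaTokenizer


def preprocess_prompts_sketch(prompts: list[str], tokenizer, pad_id: int = 0):
    encoded = [tokenizer.encode(p) for p in prompts]
    seq_lens = [len(ids) for ids in encoded]
    max_len = max(seq_lens)
    # Right-pad every row to the longest prompt so the rows stack into a batch.
    token_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in encoded]
    return token_ids, seq_lens


tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_3b_v2")
token_ids, seq_lens = preprocess_prompts_sketch(["Prompt 1", "Prompt 2"], tokenizer)
```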
sharktank/sharktank/models/llama/tools/generate_data.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Generates data files for calling iree-run-module from a prompt and config.
+
+Usage:
+  $ python -m sharktank.models.llama.tools.generate_data \
+    --tokenizer=openlm-research/open_llama_3b_v2 \
+    --config=/tmp/open-llama-3b-v2-f16.json \
+    --output-dir=/tmp/inputs \
+    --prompt="What is the meaning of life?"
+
+  $ ls /tmp/inputs
+
+    arg0.bin
+    arg1.bin
+    arg2.bin
+    arg3.bin
+
+  $ iree-run-module \
+    --module=/tmp/open-llama-3b-v2-f16_cpu.vmfb \
+    --parameters=model=/tmp/open-llama-3b-v2-f16.gguf \
+    --function=prefill_bs4 \
+    --device=local-task \
+    --input=4x1xi64=@/tmp/inputs/arg0.bin \
+    --input=4xi64=@/tmp/inputs/arg1.bin \
+    --input=4x1xi64=@/tmp/inputs/arg2.bin \
+    --input=1x2662400xf16=@/tmp/inputs/arg3.bin
+
+# TODO(scotttodd): similar script to convert outputs to text via tokenizer
+# TODO(scotttodd): teach service_v1_cli to also dump its inputs/outputs?
+# TODO(scotttodd): generate expected outputs using reference model?
+"""
+
+from pathlib import Path
+import logging
+import sys
+import json
+import numpy as np
+
+from transformers import LlamaTokenizer  # type: ignore
+
+from ....utils.logging import get_logger
+from .data_utils import write_ndarray_to_bin
+
+logger = get_logger("sharktank.models.llama.tools.generate_data")
+
+
+def main(argv):
+    from ....utils import cli
+
+    parser = cli.create_parser()
+    parser.add_argument(
+        "--tokenizer", help="name of hugginface tokenizer to use", required=True
+    )
+    parser.add_argument(
+        "--config",
+        type=Path,
+        help="json config file with hyperparameters",
+        required=True,
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=Path,
+        help="Generate .bin files into this directory",
+        required=True,
+    )
+    parser.add_argument("--prompt", help="Prompt string", required=True)
+    # TODO(scotttodd): output path (directory to dump .bin/.npy files)
+    args = cli.parse(parser, args=argv)
+
+    # Load config hyperparameters.
+    with open(args.config) as f:
+        config = json.load(f)
+    logger.info("Loaded config with hyperparameters:")
+    logger.info(json.dumps(config, indent=4))
+
+    # Load tokenizer.
+    # TODO(scotttodd): Unify tokenizer flags across sharktank and shortfin?
+    #   cli.add_tokenizer_options(parser)
+    #   tokenizer = cli.get_tokenizer(args)
+    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer, legacy=False)
+
+    # TODO(scotttodd): loop over batch sizes (generate one dataset per batch size)
+    prefill_batch_size = config["prefill_batch_sizes"][0]
+
+    # Declare input arguments.
+    # TODO(scotttodd): compute max_seq_len from tokens, _not_ config here
+    arg0_prefill_tokens = np.zeros(
+        [prefill_batch_size, config["max_seq_len"]], dtype=np.int64
+    )
+    arg1_prefill_seq_lens = np.zeros(prefill_batch_size, dtype=np.int64)
+    # TODO(scotttodd): arg2 - attention block indices
+    # TODO(scotttodd): arg3 - attention block buffer
+
+    # Populate input arguments.
+    # TODO(scotttodd): loop over 1 prompt per batch here (or duplicate)
+    prompt = args.prompt
+    prompt_tokens = tokenizer.encode(prompt, return_tensors="pt")[0].tolist()
+    logger.info(f"prompt -> encoded tokens: {prompt_tokens}")
+    prompt_seq_len = len(prompt_tokens)
+    arg0_prefill_tokens[0, 0:prompt_seq_len] = prompt_tokens
+    arg1_prefill_seq_lens[0] = prompt_seq_len
+    with np.printoptions(threshold=np.inf):
+        logger.debug("arg0_prefill_tokens:")
+        logger.debug(arg0_prefill_tokens)
+        logger.debug("arg1_prefill_seq_lens:")
+        logger.debug(arg1_prefill_seq_lens)
+
+    logger.info(f"Writing argument .bin files to '{args.output_dir}'")
+    args.output_dir.mkdir(parents=True, exist_ok=True)
+    write_ndarray_to_bin(arg0_prefill_tokens, args.output_dir / "arg0.bin")
+    write_ndarray_to_bin(arg1_prefill_seq_lens, args.output_dir / "arg1.bin")
+
+
+if __name__ == "__main__":
+    main(argv=sys.argv[1:])
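
The `write_ndarray_to_bin` helper above is imported from `.data_utils`, which is not touched by this commit. Assuming it writes the raw array buffer, which is what `iree-run-module` expects when shape and dtype come from the `--input` flag (e.g. `--input=4xi64=@arg1.bin`), a minimal sketch could look like:

```python
# Hypothetical sketch of data_utils.write_ndarray_to_bin; not the actual
# sharktank implementation. Writes the header-less buffer in C order, so
# shape/dtype must be supplied out of band by the --input flag.
from pathlib import Path

import numpy as np


def write_ndarray_to_bin(ndarray: np.ndarray, file_path: Path) -> None:
    with open(file_path, "wb") as f:
        ndarray.tofile(f)
```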

sharktank/sharktank/utils/cli.py

Lines changed: 3 additions & 21 deletions
@@ -205,27 +205,9 @@ def add_save_tensor_options(parser: argparse.ArgumentParser):
         help="save module forward outputs to safetensors, ex: run_0 will save to run_0_prefill.savetensors",
     )
     parser.add_argument(
-        "--dump-path",
-        help="Path to dump prefill/decode input tensors to npy files",
-        type=str,
-        default=None,
-    )
-    parser.add_argument(
-        "--dump-decode-steps",
-        help="Number of decode steps to dump decode input tensors",
-        type=int,
-        default=1,
-    )
-    parser.add_argument(
-        "--prompt-seq-len",
-        help="Seq len to generate input prompts for prefill",
-        type=int,
-    )
-    parser.add_argument(
-        "--bs",
-        help="Batch size",
-        type=int,
-        default="4",
+        "--dump-bins",
+        help="dump input tensors to bin files",
+        action="store_true",
     )
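
Net effect of this hunk: four tensor-dumping knobs (`--dump-path`, `--dump-decode-steps`, `--prompt-seq-len`, `--bs`) collapse back into the single boolean `--dump-bins`. A stand-alone plain-argparse illustration of the restored switch (sharktank's `cli.create_parser()` wrapper is not reproduced here):

```python
# Illustration only: the restored --dump-bins flag modeled with plain argparse.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dump-bins",
    help="dump input tensors to bin files",
    action="store_true",  # boolean switch; argparse exposes it as args.dump_bins
)

assert parser.parse_args(["--dump-bins"]).dump_bins is True
assert parser.parse_args([]).dump_bins is False
```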