# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

| 7 | +"""Generates data files for calling iree-run-module from a prompt and config. |
| 8 | +
|
| 9 | +Usage: |
| 10 | + $ python -m sharktank.models.llama.tools.generate_data \ |
| 11 | + --tokenizer=openlm-research/open_llama_3b_v2 \ |
| 12 | + --config=/tmp/open-llama-3b-v2-f16.json \ |
| 13 | + --output-dir=/tmp/inputs \ |
| 14 | + --prompt="What is the meaning of life?" |
| 15 | +
|
  $ ls /tmp/inputs

  arg0.bin
  arg1.bin

  (arg2.bin and arg3.bin are not generated yet; see the TODOs below.)

  $ iree-run-module \
    --module=/tmp/open-llama-3b-v2-f16_cpu.vmfb \
    --parameters=model=/tmp/open-llama-3b-v2-f16.gguf \
    --function=prefill_bs4 \
    --device=local-task \
    --input=4x1xi64=@/tmp/inputs/arg0.bin \
    --input=4xi64=@/tmp/inputs/arg1.bin \
    --input=4x1xi64=@/tmp/inputs/arg2.bin \
    --input=1x2662400xf16=@/tmp/inputs/arg3.bin

# TODO(scotttodd): similar script to convert outputs to text via tokenizer
# TODO(scotttodd): teach service_v1_cli to also dump its inputs/outputs?
# TODO(scotttodd): generate expected outputs using reference model?
"""

from pathlib import Path
import logging
import sys
import json
import numpy as np

from transformers import LlamaTokenizer  # type: ignore

from ....utils.logging import get_logger
from .data_utils import write_ndarray_to_bin

logger = get_logger("sharktank.models.llama.tools.generate_data")


def main(argv):
    from ....utils import cli

    parser = cli.create_parser()
    parser.add_argument(
        "--tokenizer", help="name of the Hugging Face tokenizer to use", required=True
    )
    parser.add_argument(
        "--config",
        type=Path,
        help="json config file with hyperparameters",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        help="Generate .bin files into this directory",
        required=True,
    )
    parser.add_argument("--prompt", help="Prompt string", required=True)
    # TODO(scotttodd): output path (directory to dump .bin/.npy files)
    args = cli.parse(parser, args=argv)

    # Load config hyperparameters.
    with open(args.config) as f:
        config = json.load(f)
    logger.info("Loaded config with hyperparameters:")
    logger.info(json.dumps(config, indent=4))

    # Load tokenizer.
    # TODO(scotttodd): Unify tokenizer flags across sharktank and shortfin?
    # cli.add_tokenizer_options(parser)
    # tokenizer = cli.get_tokenizer(args)
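    # legacy=False opts in to the updated (non-legacy) Llama tokenizer
    # behavior in transformers, avoiding the legacy-mode deprecation warning.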
    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer, legacy=False)

    # TODO(scotttodd): loop over batch sizes (generate one dataset per batch size)
    prefill_batch_size = config["prefill_batch_sizes"][0]
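    # Note: only the first configured batch size is used for now; the
    # docstring's prefill_bs4 example assumes this resolves to 4.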

    # Declare input arguments.
    # TODO(scotttodd): compute max_seq_len from tokens, _not_ config here
    arg0_prefill_tokens = np.zeros(
        [prefill_batch_size, config["max_seq_len"]], dtype=np.int64
    )
    arg1_prefill_seq_lens = np.zeros(prefill_batch_size, dtype=np.int64)
    # TODO(scotttodd): arg2 - attention block indices
    # TODO(scotttodd): arg3 - attention block buffer
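    # Hypothetical sketch of the missing arguments, with shapes taken from the
    # iree-run-module example in the module docstring (a real implementation
    # would derive them from the config rather than hardcoding them):
    # arg2_prefill_block_indices = np.zeros([prefill_batch_size, 1], dtype=np.int64)
    # arg3_prefill_block_buffer = np.zeros([1, 2662400], dtype=np.float16)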

    # Populate input arguments.
    # TODO(scotttodd): loop over 1 prompt per batch here (or duplicate)
    prompt = args.prompt
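    # encode() with return_tensors="pt" returns a [1, seq_len] tensor; take
    # row 0 and convert it to a plain Python list of token ids.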
    prompt_tokens = tokenizer.encode(prompt, return_tensors="pt")[0].tolist()
    logger.info(f"prompt -> encoded tokens: {prompt_tokens}")
    prompt_seq_len = len(prompt_tokens)
    arg0_prefill_tokens[0, 0:prompt_seq_len] = prompt_tokens
    arg1_prefill_seq_lens[0] = prompt_seq_len
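    # Positions past prompt_seq_len in row 0, and all other batch rows, remain
    # zero-padded.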
    with np.printoptions(threshold=np.inf):
        logger.debug("arg0_prefill_tokens:")
        logger.debug(arg0_prefill_tokens)
        logger.debug("arg1_prefill_seq_lens:")
        logger.debug(arg1_prefill_seq_lens)

    logger.info(f"Writing argument .bin files to '{args.output_dir}'")
    args.output_dir.mkdir(parents=True, exist_ok=True)
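    # iree-run-module's --input=...=@file.bin flags read raw buffer contents,
    # so each array is (presumably) dumped as flat binary by
    # write_ndarray_to_bin; see data_utils.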
    write_ndarray_to_bin(arg0_prefill_tokens, args.output_dir / "arg0.bin")
    write_ndarray_to_bin(arg1_prefill_seq_lens, args.output_dir / "arg1.bin")


if __name__ == "__main__":
    main(argv=sys.argv[1:])