simple_config.py
"""A simple config for Llama-2 building and generating scripts.
Modify directly if you want to change settings.
"""

from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Union


@dataclass
class SimpleConfig:
    """Experiment Configuration."""

    ### MODEL ARG ##################################################################################
    # Path or repo_id for a HF model directory.
    # The model directory should contain model weights and tokenizer configs. Model weights can be
    # provided in any of the following forms:
    # 1. Sharded checkpoint (multiple files) in the safetensors format
    # 2. Single, unsharded checkpoint in the safetensors format
    # 3. Single, unsharded checkpoint in the pytorch format (.pt/.pth file ending)
    model: str
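    # Example (illustrative only): model = "meta-llama/Llama-2-7b-hf" or a local directory such as
    # "/path/to/llama-2-7b" containing config.json, tokenizer files, and safetensors weights.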
    # Same as model; None defaults to model. Only used if customize_tokenizer is True.
    tokenizer: Optional[str] = None
    model_factory: Literal["AutoModelForCausalLM", "AutoModelForImageTextToText"] = (
        "AutoModelForCausalLM"
    )
    skip_loading_weights: bool = False  # only load the architecture, not the weights
    customize_tokenizer: bool = False  # True: tokenizer from the model factory, False: from LLM API

    ### MODEL EXTRA KWARGS #########################################################################
    # Extra kwargs for the model config class to customize the model config. These arguments take
    # precedence over the default values or config values in the model config file in the HF
    # directory. Arguments are resolved in the following order:
    # 1. Default values in the model config class
    # 2. Values in the model config file in the HF directory
    # 3. Values in model_kwargs
    # Note that if a kwarg does not exist in the model config class, it will be ignored.
    # An example model config class can be found [here](https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/models/llama/configuration_llama.py#L26).
    model_kwargs: Dict = field(default_factory=dict)
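    # Example (illustrative values): model_kwargs={"num_hidden_layers": 2, "max_position_embeddings": 1024}
    # overrides those fields from the HF config file, while a key unknown to the config class is ignored.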
    ### TOKENIZER EXTRA KWARGS #####################################################################
    # Extra kwargs for the tokenizer class to customize the tokenizer. Same as model_kwargs.
    # For example, the default HF Llama tokenizer can be initialized with the arguments specified
    # [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama_fast.py#L127).
    # NOTE: This is only used if customize_tokenizer is True.
    tokenizer_kwargs: Dict = field(default_factory=dict)
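    # Example (illustrative value): tokenizer_kwargs={"padding_side": "left"} passes padding_side
    # through to the tokenizer constructor when customize_tokenizer is True.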
    ### CONFIGURE BACKEND, RUNTIME, AND WORLD SIZE #################################################
    world_size: int = 1  # number of GPUs used for TP (0 --> no TP, no spawned processes)
    runtime: Literal["demollm", "trtllm"] = "trtllm"
    compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
        "torch-compile"
    )
    attn_backend: Literal["TritonWithFlattenedInputs", "FlashInfer"] = "FlashInfer"
    mla_backend: Literal["MultiHeadLatentAttention"] = "MultiHeadLatentAttention"
    max_seq_len: int = 512  # max sequence length for inference/cache
    max_batch_size: int = 8  # max dimension for statically allocated kv cache
    page_size: int = 64  # page size for attention
    simple_shard_only: bool = False  # if True, force simple sharding (all_gather) in TP;
    # otherwise auto-detect and use column+row (all_reduce) sharding
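    # Example (illustrative): a 2-GPU tensor-parallel run could set world_size=2 and, if needed,
    # raise max_seq_len / max_batch_size; note that __post_init__ sets page_size = max_seq_len
    # for the TritonWithFlattenedInputs backend, which does not support paging.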
    ### SOME SIMPLE PROMPTING CONFIG ###############################################################
    batch_size: int = 2  # example input shape
    device: str = "cuda"
    prompt: Union[str, List[str]] = field(
        default_factory=lambda: [
            "How big is the universe? ",
            "In simple words and in a single sentence, explain the concept of gravity: ",
            "How to fix slicing in golf? ",
            "Where is the capital of Iceland? ",
            "How big is the universe? ",
            "In simple words and in a single sentence, explain the concept of gravity: ",
            "How to fix slicing in golf? ",
            "Where is the capital of Iceland? ",
        ]
    )
    max_tokens: int = 100
    top_k: int = 200
    temperature: float = 1.0
    visualize: bool = False
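    # Example (illustrative): passing prompt="Hello" with batch_size=4 replicates the single prompt
    # to four copies in __post_init__; a prompt list longer than batch_size is truncated to batch_size.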
    ### BENCHMARKING CONFIG ########################################################################
    free_mem_ratio: float = 0.0  # specifies the fraction of available memory to occupy for cache
    benchmark: bool = False  # if True, set ISL to 2048 random tokens and OSL to 128
    benchmark_num: int = 10  # by default run 10 times and report the average
    benchmark_isl: int = 2048  # input seq length for benchmarking
    benchmark_osl: int = 128  # output seq length for benchmarking
    benchmark_bs: int = 1  # batch size for benchmarking
    benchmark_results_path: Optional[str] = "./benchmark_results.json"
    benchmark_store_results: bool = False  # if True, store benchmark results in benchmark_results_path
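    # Example (illustrative): benchmark=True with benchmark_bs=4 bumps max_batch_size to at least 4
    # and max_seq_len to at least benchmark_isl + benchmark_osl (see __post_init__ below).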
    ### POST INITIALIZATION ########################################################################
    def __post_init__(self):
        # check if model was supplied
        assert self.model, "model must be supplied!"

        # We don't want to lose the default values for model_kwargs unless explicitly set by the
        # user. They are not preserved by the standard initialization process since the whole dict
        # gets replaced by the user-provided one. We don't want that, though.
        f_default = self.__dataclass_fields__["model_kwargs"].default_factory()
        setattr(self, "model_kwargs", {**f_default, **getattr(self, "model_kwargs")})

        # special handling for torch_dtype in model_kwargs since HF does not correctly update
        # the torch_dtype string to an actual torch.dtype object (only with default)
        if "torch_dtype" in self.model_kwargs:
            import torch

            dtype = self.model_kwargs["torch_dtype"]
            if isinstance(dtype, str):
                dtype = getattr(torch, self.model_kwargs["torch_dtype"])
            assert isinstance(dtype, torch.dtype), f"Invalid dtype: {dtype}"
            self.model_kwargs["torch_dtype"] = dtype

        self.max_batch_size = max(self.max_batch_size, self.batch_size)

        # make sure benchmark isl/osl/bs fits into available resources
        if self.benchmark:
            self.max_batch_size = max(self.benchmark_bs, self.max_batch_size)
            self.max_seq_len = max(self.max_seq_len, self.benchmark_isl + self.benchmark_osl)

        # No paging allowed in TritonWithFlattenedInputs
        if self.attn_backend in ["TritonWithFlattenedInputs"]:
            self.page_size = self.max_seq_len

        # use min instead of max to avoid OOM for large batch size
        self.model_kwargs["max_position_embeddings"] = min(
            self.max_seq_len,
            self.model_kwargs.get("max_position_embeddings", self.max_seq_len),
        )

        if isinstance(self.prompt, str):
            self.prompt = [self.prompt]

        # replicate prompts to get to batch_size
        prompts = self.prompt * (self.batch_size // len(self.prompt) + 1)
        self.prompt = prompts[: self.batch_size]
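
if __name__ == "__main__":
    # Minimal usage sketch (illustrative values only, not project defaults): shows how
    # __post_init__ merges model_kwargs, clamps max_position_embeddings to max_seq_len,
    # and replicates a single prompt up to batch_size.
    cfg = SimpleConfig(
        model="meta-llama/Llama-2-7b-hf",  # any HF repo_id or local checkpoint directory
        model_kwargs={"max_position_embeddings": 4096},
        prompt="How big is the universe? ",
        batch_size=4,
    )
    print(cfg.model_kwargs["max_position_embeddings"])  # -> 512 (clamped to max_seq_len)
    print(len(cfg.prompt))  # -> 4 (single prompt replicated to batch_size)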