@@ -1,15 +1,15 @@
 """Generator of mlc-chat-config.json and tokenizer configuration."""
-
-import dataclasses
+# pylint: disable=E1101
 import json
 import re
 import shutil
 from dataclasses import asdict
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Optional
 
 from mlc_llm.conversation_template import ConvTemplateRegistry
 from mlc_llm.model import Model
+from mlc_llm.protocol.mlc_chat_config import MLCChatConfig
 from mlc_llm.quantization import Quantization
 from mlc_llm.support import convert_tiktoken, logging
 from mlc_llm.support.style import bold, green, red
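Annotation (not part of the diff): replacing `import dataclasses` with a module-level `# pylint: disable=E1101` suggests the relocated `MLCChatConfig` is a Pydantic model; E1101 is pylint's no-member check, which is known to raise false positives on attributes that Pydantic materializes at runtime. A minimal illustration, assuming Pydantic is the backing library (the class and field below are hypothetical):

    from pydantic import BaseModel

    class Example(BaseModel):  # hypothetical stand-in for MLCChatConfig
        temperature: float = 0.7

    cfg = Example()
    print(cfg.temperature)  # some pylint versions flag such access as E1101 (no-member)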
@@ -22,63 +22,13 @@
 FOUND = green("Found")
 NOT_FOUND = red("Not found")
 FAILED = red("Failed")
-VERSION = "0.1.0"
-
-
-@dataclasses.dataclass
-class MLCChatConfig:  # pylint: disable=too-many-instance-attributes
-    """Fields in the dumped `mlc-chat-config.json` file."""
 
-    model_type: str
-    quantization: str
-    model_config: Dict[str, Any]
-    vocab_size: int
-    context_window_size: int
-    sliding_window_size: int
-    prefill_chunk_size: int
-    attention_sink_size: int
-    tensor_parallel_shards: int
-    # Control the behavior of the runtime
-    mean_gen_len: int = None
-    max_gen_len: int = None
-    shift_fill_factor: float = None
-    # Configuration of text generation
-    temperature: float = None
-    presence_penalty: float = None
-    frequency_penalty: float = None
-    repetition_penalty: float = None
-    top_p: float = None
-    # Conversation template
-    conv_template: Union[str, Dict[str, Any]] = None
-    pad_token_id: int = None
-    bos_token_id: int = None
-    eos_token_id: int = None
-    # Tokenizer configuration
-    tokenizer_files: List[str] = dataclasses.field(default_factory=list)
-    # The content of tokenizer.TokenizerInfo
-    tokenizer_info: Dict[str, Any] = dataclasses.field(default_factory=dict)
-    # Version control
-    version: str = VERSION
 
-    def apply_defaults(self) -> None:
-        """Apply system default value."""
-        defaults = {
-            "pad_token_id": 0,
-            "bos_token_id": 1,
-            "eos_token_id": 2,
-            "temperature": 0.7,
-            "presence_penalty": 0.0,
-            "frequency_penalty": 0.0,
-            "repetition_penalty": 1.0,
-            "top_p": 0.95,
-            "mean_gen_len": 128,
-            "max_gen_len": 512,
-            "shift_fill_factor": 0.3,
-        }
-        for key, value in defaults.items():
-            if getattr(self, key) is None:
-                setattr(self, key, value)
-                logger.info("[System default] Setting %s: %s", bold(key), value)
+def apply_system_defaults_for_missing_fields(mlc_chat_config: MLCChatConfig) -> None:
+    """Apply system default value."""
+    for key, value in mlc_chat_config.get_system_defaults_for_missing_fields().items():
+        setattr(mlc_chat_config, key, value)
+        logger.info("[System default] Setting %s: %s", bold(key), value)
 
 
 def check_string(s: str) -> bool:
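Annotation (not part of the diff): the dataclass removed above now lives in `mlc_llm.protocol.mlc_chat_config`, which this change imports but does not show. Below is a trimmed, hypothetical sketch of what that Pydantic model plausibly provides; the field names and default values are taken from the removed code, while the class body and method implementation are assumptions inferred from the call sites in this diff:

    from typing import Any, Dict, List, Optional

    from pydantic import BaseModel, Field


    class MLCChatConfig(BaseModel):
        """Fields in the dumped `mlc-chat-config.json` file (trimmed sketch)."""

        model_type: str
        quantization: str
        vocab_size: int
        # Optional fields start as None so that unset values can be told
        # apart from user-supplied ones before system defaults are applied.
        temperature: Optional[float] = None
        top_p: Optional[float] = None
        pad_token_id: Optional[int] = None
        bos_token_id: Optional[int] = None
        eos_token_id: Optional[int] = None
        tokenizer_files: List[str] = Field(default_factory=list)

        def get_system_defaults_for_missing_fields(self) -> Dict[str, Any]:
            """Return system defaults only for fields that are still None."""
            system_defaults = {
                "pad_token_id": 0,
                "bos_token_id": 1,
                "eos_token_id": 2,
                "temperature": 0.7,
                "top_p": 0.95,
            }
            return {
                key: value
                for key, value in system_defaults.items()
                if getattr(self, key) is None
            }

This preserves the behavior of the removed `apply_defaults`: only fields still at None receive a system default, and the new free function logs each default as it is applied.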
@@ -265,10 +215,10 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
     logger.info("Detected tokenizer info: %s", mlc_chat_config.tokenizer_info)
 
     # Step 4. Load system default value
-    mlc_chat_config.apply_defaults()
+    apply_system_defaults_for_missing_fields(mlc_chat_config)
     # Step 5. Dump the configuration file to output directory
     with (output / "mlc-chat-config.json").open("w", encoding="utf-8") as out_file:
-        json.dump(dataclasses.asdict(mlc_chat_config), out_file, indent=2)
+        json.dump(mlc_chat_config.model_dump(), out_file, indent=2)
     logger.info("Dumping configuration file to: %s", bold(out_file.name))
 
 
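Annotation (not part of the diff): `dataclasses.asdict` no longer applies once `MLCChatConfig` is not a dataclass; `model_dump()` is Pydantic v2's serialization method and returns a plain dict of built-in types, so it drops straight into `json.dump`. A usage sketch against the hypothetical model above:

    import json

    cfg = MLCChatConfig(model_type="llama", quantization="q4f16_1", vocab_size=32000)

    # Mirrors apply_system_defaults_for_missing_fields: fill in only the
    # fields the user left unset.
    for key, value in cfg.get_system_defaults_for_missing_fields().items():
        setattr(cfg, key, value)

    # model_dump() replaces dataclasses.asdict() now that the config is a
    # Pydantic model rather than a dataclass.
    print(json.dumps(cfg.model_dump(), indent=2))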