7
7
8
8
from mlc_llm .interface .help import HELP
9
9
from mlc_llm .interface .serve import serve
10
- from mlc_llm .serve .config import ModelConfigOverride
11
10
from mlc_llm .support import argparse
12
11
from mlc_llm .support .argparse import ArgumentParser
13
12
14
13
15
14
@dataclasses .dataclass
16
- class EngineAndModelConfigOverride : # pylint: disable=too-many-instance-attributes
15
+ class EngineConfigOverride : # pylint: disable=too-many-instance-attributes
17
16
"""Arguments for overriding engine config."""
18
17
19
18
# Overrides for EngineConfig (runtime)
@@ -24,8 +23,6 @@ class EngineAndModelConfigOverride: # pylint: disable=too-many-instance-attribu
24
23
gpu_memory_utilization : Optional [float ] = None
25
24
spec_draft_length : Optional [int ] = None
26
25
prefix_cache_max_num_recycling_seqs : Optional [int ] = None
27
-
28
- # Overrides for model config (compile time)
29
26
context_window_size : Optional [int ] = None
30
27
sliding_window_size : Optional [int ] = None
31
28
attention_sink_size : Optional [int ] = None
@@ -51,7 +48,7 @@ def __repr__(self) -> str:
51
48
return out .getvalue ().rstrip ()
52
49
53
50
@staticmethod
54
- def from_str (source : str ) -> "EngineAndModelConfigOverride " :
51
+ def from_str (source : str ) -> "EngineConfigOverride " :
55
52
"""Parse engine config override values from a string."""
56
53
parser = argparse .ArgumentParser (description = "Engine config override values" )
57
54
@@ -67,7 +64,7 @@ def from_str(source: str) -> "EngineAndModelConfigOverride":
67
64
parser .add_argument ("--attention_sink_size" , type = int , default = None )
68
65
parser .add_argument ("--tensor_parallel_shards" , type = int , default = None )
69
66
results = parser .parse_args ([f"--{ i } " for i in source .split (";" ) if i ])
70
- return EngineAndModelConfigOverride (
67
+ return EngineConfigOverride (
71
68
max_num_sequence = results .max_num_sequence ,
72
69
max_total_seq_length = results .max_total_seq_length ,
73
70
prefill_chunk_size = results .prefill_chunk_size ,
@@ -81,17 +78,6 @@ def from_str(source: str) -> "EngineAndModelConfigOverride":
81
78
tensor_parallel_shards = results .tensor_parallel_shards ,
82
79
)
83
80
84
- def to_model_config_overrides (self ) -> ModelConfigOverride :
85
- """Extract the model config overrides."""
86
- return ModelConfigOverride (
87
- context_window_size = self .context_window_size ,
88
- sliding_window_size = self .sliding_window_size ,
89
- prefill_chunk_size = self .prefill_chunk_size ,
90
- attention_sink_size = self .attention_sink_size ,
91
- max_batch_size = self .max_num_sequence ,
92
- tensor_parallel_shards = self .tensor_parallel_shards ,
93
- )
94
-
95
81
96
82
def main (argv ):
97
83
"""Parse command line arguments and call `mlc_llm.interface.serve`."""
@@ -145,7 +131,7 @@ def main(argv):
145
131
)
146
132
parser .add_argument (
147
133
"--overrides" ,
148
- type = EngineAndModelConfigOverride .from_str ,
134
+ type = EngineConfigOverride .from_str ,
149
135
default = "" ,
150
136
help = HELP ["overrides_serve" ],
151
137
)
@@ -199,16 +185,19 @@ def main(argv):
199
185
mode = parsed .mode ,
200
186
enable_debug = parsed .enable_debug ,
201
187
additional_models = additional_models ,
188
+ tensor_parallel_shards = parsed .overrides .tensor_parallel_shards ,
202
189
speculative_mode = parsed .speculative_mode ,
203
190
prefix_cache_mode = parsed .prefix_cache_mode ,
204
191
max_num_sequence = parsed .overrides .max_num_sequence ,
205
192
max_total_sequence_length = parsed .overrides .max_total_seq_length ,
193
+ max_single_sequence_length = parsed .overrides .context_window_size ,
206
194
prefill_chunk_size = parsed .overrides .prefill_chunk_size ,
195
+ sliding_window_size = parsed .overrides .sliding_window_size ,
196
+ attention_sink_size = parsed .overrides .attention_sink_size ,
207
197
max_history_size = parsed .overrides .max_history_size ,
208
198
gpu_memory_utilization = parsed .overrides .gpu_memory_utilization ,
209
199
spec_draft_length = parsed .overrides .spec_draft_length ,
210
200
prefix_cache_max_num_recycling_seqs = parsed .overrides .prefix_cache_max_num_recycling_seqs ,
211
- model_config_overrides = parsed .overrides .to_model_config_overrides (),
212
201
enable_tracing = parsed .enable_tracing ,
213
202
host = parsed .host ,
214
203
port = parsed .port ,
0 commit comments