
Commit cde5a5e

Override HF config.json via CLI (#3722)
* hf config overrides
* TM support
* add default val
* fix for yaml safe dump
* add testcases
* change hf_overrides positions
* optimize
* fix arg helper, add warnings
* remove UT
* fix UT
1 parent 9625d82 commit cde5a5e

File tree

9 files changed: +207, -44 lines changed


lmdeploy/cli/cli.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -1,16 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.

-import argparse
 import os

 from ..version import __version__
-from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters
+from .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args,
+                    get_chat_template, get_lora_adapters)


 class CLI(object):
     _desc = 'The CLI provides a unified API for converting, ' \
             'compressing and deploying large language models.'
-    parser = argparse.ArgumentParser(prog='lmdeploy', description=_desc, add_help=True)
+    parser = FlexibleArgumentParser(prog='lmdeploy', description=_desc, add_help=True)
     parser.add_argument('-v', '--version', action='version', version=__version__)
     subparsers = parser.add_subparsers(title='Commands', description='lmdeploy has following commands:', dest='command')
```
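
For illustration, a minimal sketch (not part of the commit) of what the swap to `FlexibleArgumentParser` buys: option names may be spelled with underscores or dashes interchangeably. It assumes the class is importable from `lmdeploy.cli.utils`, where this commit defines it.

```python
from lmdeploy.cli.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(prog='demo')
parser.add_argument('--max-batch-size', type=int)

# Underscores in option names are rewritten to dashes before argparse sees them,
# so both spellings reach the same destination.
assert parser.parse_args(['--max_batch_size', '8']).max_batch_size == 8
assert parser.parse_args(['--max-batch-size=8']).max_batch_size == 8
```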

lmdeploy/cli/serve.py

Lines changed: 27 additions & 21 deletions
```diff
@@ -165,6 +165,7 @@ def add_parser_api_server():
         max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group)
         quant_policy = ArgumentHelper.quant_policy(pt_group)
         model_format = ArgumentHelper.model_format(pt_group)
+        hf_overrides = ArgumentHelper.hf_overrides(pt_group)
         ArgumentHelper.dp(pt_group)
         ArgumentHelper.ep(pt_group)
         ArgumentHelper.enable_microbatch(pt_group)
@@ -189,6 +190,7 @@ def add_parser_api_server():
         tb_group._group_actions.append(max_prefill_token_num_act)
         tb_group._group_actions.append(quant_policy)
         tb_group._group_actions.append(model_format)
+        tb_group._group_actions.append(hf_overrides)
         ArgumentHelper.rope_scaling_factor(tb_group)
         ArgumentHelper.num_tokens_per_iter(tb_group)
         ArgumentHelper.max_prefill_iters(tb_group)
@@ -318,26 +320,29 @@ def api_server(args):
     if backend == 'pytorch':
         from lmdeploy.messages import PytorchEngineConfig
         adapters = get_lora_adapters(args.adapters)
-        backend_config = PytorchEngineConfig(dtype=args.dtype,
-                                             tp=args.tp,
-                                             dp=args.dp,
-                                             ep=args.ep,
-                                             max_batch_size=max_batch_size,
-                                             cache_max_entry_count=args.cache_max_entry_count,
-                                             block_size=args.cache_block_seq_len,
-                                             session_len=args.session_len,
-                                             adapters=adapters,
-                                             enable_prefix_caching=args.enable_prefix_caching,
-                                             device_type=args.device,
-                                             quant_policy=args.quant_policy,
-                                             eager_mode=args.eager_mode,
-                                             max_prefill_token_num=args.max_prefill_token_num,
-                                             enable_microbatch=args.enable_microbatch,
-                                             enable_eplb=args.enable_eplb,
-                                             enable_metrics=args.enable_metrics,
-                                             role=EngineRole[args.role],
-                                             migration_backend=MigrationBackend[args.migration_backend],
-                                             model_format=args.model_format)
+        backend_config = PytorchEngineConfig(
+            dtype=args.dtype,
+            tp=args.tp,
+            dp=args.dp,
+            ep=args.ep,
+            max_batch_size=max_batch_size,
+            cache_max_entry_count=args.cache_max_entry_count,
+            block_size=args.cache_block_seq_len,
+            session_len=args.session_len,
+            adapters=adapters,
+            enable_prefix_caching=args.enable_prefix_caching,
+            device_type=args.device,
+            quant_policy=args.quant_policy,
+            eager_mode=args.eager_mode,
+            max_prefill_token_num=args.max_prefill_token_num,
+            enable_microbatch=args.enable_microbatch,
+            enable_eplb=args.enable_eplb,
+            enable_metrics=args.enable_metrics,
+            role=EngineRole[args.role],
+            migration_backend=MigrationBackend[args.migration_backend],
+            model_format=args.model_format,
+            hf_overrides=args.hf_overrides,
+        )
     else:
         from lmdeploy.messages import TurbomindEngineConfig
         backend_config = TurbomindEngineConfig(dtype=args.dtype,
@@ -351,7 +356,8 @@ def api_server(args):
                                                cache_block_seq_len=args.cache_block_seq_len,
                                                enable_prefix_caching=args.enable_prefix_caching,
                                                max_prefill_token_num=args.max_prefill_token_num,
-                                               communicator=args.communicator)
+                                               communicator=args.communicator,
+                                               hf_overrides=args.hf_overrides)
     chat_template_config = get_chat_template(args.chat_template)

     from lmdeploy.messages import VisionConfig
```
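
A hedged sketch of the end-to-end effect (not from the commit; the serve entry point and `--backend` flag are lmdeploy's existing CLI, and the override payload is illustrative): the parsed JSON is forwarded verbatim into the engine config.

```python
from lmdeploy.messages import PytorchEngineConfig

# CLI form:
#   lmdeploy serve api_server <model> --backend pytorch \
#       --hf-overrides '{"rope_scaling": {"rope_type": "dynamic", "factor": 2.0}}'
# Python form of what api_server() now builds from that flag:
backend_config = PytorchEngineConfig(hf_overrides={'rope_scaling': {'rope_type': 'dynamic', 'factor': 2.0}})
```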

lmdeploy/cli/utils.py

Lines changed: 103 additions & 1 deletion
```diff
@@ -1,7 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.

 import argparse
-from typing import List
+import json
+import re
+import sys
+from collections import defaultdict
+from typing import Any, List


 class DefaultsAndTypesHelpFormatter(argparse.HelpFormatter):
@@ -231,6 +235,14 @@ def rope_scaling_factor(parser):

         return parser.add_argument('--rope-scaling-factor', type=float, default=0.0, help='Rope scaling factor')

+    @staticmethod
+    def hf_overrides(parser):
+        """Add argument hf_overrides to parser."""
+        return parser.add_argument('--hf-overrides',
+                                   type=json.loads,
+                                   default=None,
+                                   help='Extra arguments to be forwarded to the HuggingFace config.')
+
     @staticmethod
     def use_logn_attn(parser):
         """Add argument use_logn_attn to parser."""
@@ -580,3 +592,93 @@ def migration_backend(parser):
                                    default='DLSlime',
                                    choices=['DLSlime', 'Mooncake'],
                                    help='kvcache migration management backend when PD disaggregation')
+
+
+# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
+class FlexibleArgumentParser(argparse.ArgumentParser):
+    """More flexible argument parser."""
+
+    def parse_args(self, args=None, namespace=None):
+        # If args is not provided, use arguments from the command line
+        if args is None:
+            args = sys.argv[1:]
+
+        def repl(match: re.Match) -> str:
+            """Replace underscores with dashes in the matched string."""
+            return match.group(0).replace('_', '-')
+
+        # Everything between the first -- and the first .
+        pattern = re.compile(r'(?<=--)[^\.]*')
+
+        # Convert underscores to dashes in argument names
+        processed_args = []
+        for arg in args:
+            if arg.startswith('--'):
+                if '=' in arg:
+                    key, value = arg.split('=', 1)
+                    key = pattern.sub(repl, key, count=1)
+                    processed_args.append(f'{key}={value}')
+                else:
+                    key = pattern.sub(repl, arg, count=1)
+                    processed_args.append(key)
+            elif arg.startswith('-O') and arg != '-O' and arg[2] != '.':
+                # allow the -O flag to be used without a space, e.g. -O3
+                processed_args.append('-O')
+                processed_args.append(arg[2:])
+            else:
+                processed_args.append(arg)
+
+        def _try_convert(value: str):
+            """Try to parse a string as JSON, falling back to the raw string."""
+            if not isinstance(value, str):
+                return value
+            try:
+                return json.loads(value)
+            except json.JSONDecodeError:
+                pass
+            return value
+
+        def create_nested_dict(keys: list[str], value: str):
+            """Create a nested dictionary from a list of keys and a value.
+
+            For example, `keys = ["a", "b", "c"]` and `value = 1` will create `{"a": {"b": {"c": 1}}}`.
+            """
+            nested_dict: Any = _try_convert(value)
+            for key in reversed(keys):
+                nested_dict = {key: nested_dict}
+            return nested_dict
+
+        def recursive_dict_update(original: dict, update: dict):
+            """Recursively update a dictionary with another dictionary."""
+            for k, v in update.items():
+                if isinstance(v, dict) and isinstance(original.get(k), dict):
+                    recursive_dict_update(original[k], v)
+                else:
+                    original[k] = v
+
+        delete = set()
+        dict_args: dict[str, dict] = defaultdict(dict)
+        for i, processed_arg in enumerate(processed_args):
+            if processed_arg.startswith('--') and '.' in processed_arg:
+                if '=' in processed_arg:
+                    processed_arg, value = processed_arg.split('=', 1)
+                    if '.' not in processed_arg:
+                        # False positive: the . was only in the value
+                        continue
+                else:
+                    value = processed_args[i + 1]
+                    delete.add(i + 1)
+                key, *keys = processed_arg.split('.')
+                # Merge all values with the same key into a single dict
+                arg_dict = create_nested_dict(keys, value)
+                recursive_dict_update(dict_args[key], arg_dict)
+                delete.add(i)
+        # Drop the dot-notation args that were folded into dict_args
+        processed_args = [a for i, a in enumerate(processed_args) if i not in delete]
+        # Add the dict args back as if they were originally passed as JSON
+        for dict_arg, dict_value in dict_args.items():
+            processed_args.append(dict_arg)
+            processed_args.append(json.dumps(dict_value))
+
+        return super().parse_args(processed_args, namespace)
```
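
A quick smoke test of the parser above (not part of the commit, but grounded in this diff: the import path and the `json.loads` argument type both appear in `lmdeploy/cli/utils.py`): the JSON spelling and the dot-notation spelling of `--hf-overrides` should collapse to the same nested dict.

```python
import json

from lmdeploy.cli.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(prog='demo')
parser.add_argument('--hf-overrides', type=json.loads, default=None)

# Dot-notation args are folded back into a JSON string before argparse runs,
# so both spellings parse to the same nested dict.
ns_json = parser.parse_args(['--hf-overrides', '{"rope_scaling": {"factor": 2.0}}'])
ns_dots = parser.parse_args(['--hf-overrides.rope_scaling.factor', '2.0'])
assert ns_json.hf_overrides == ns_dots.hf_overrides == {'rope_scaling': {'factor': 2.0}}
```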

lmdeploy/messages.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -2,7 +2,7 @@
 import enum
 import time
 from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Literal, Optional
+from typing import Any, Callable, Dict, List, Literal, Optional

 import torch
 from pydantic.dataclasses import dataclass as pydantic_dataclass
@@ -223,6 +223,8 @@ class TurbomindEngineConfig:
         devices(List[int]): the used devices
         empty_init (bool): Whether to load the model weights, you should set
             it to True if you want to update weights after create the pipeline
+        hf_overrides (Dict[str, Any]): HuggingFace overrides for the model.
+            It can be used to override the default config of the model.
     """

     dtype: str = 'auto'
@@ -252,6 +254,7 @@ class TurbomindEngineConfig:
     devices: Optional[List[int]] = None
     empty_init: bool = False
     communicator: str = 'nccl'
+    hf_overrides: Optional[Dict[str, Any]] = None

     def __post_init__(self):
         """Check input validation."""
@@ -322,6 +325,8 @@ class PytorchEngineConfig:
             Default to `MigrationBackend.DLSlime`.
         enable_mp_engine (bool): run engine in multi-process mode.
         model_format (str): weight quantization policy, options: ['fp8'].
+        hf_overrides (Dict[str, Any]): HuggingFace overrides for the model.
+            It can be used to override the default config of the model.
     """
     dtype: str = 'auto'
     tp: int = 1
@@ -352,6 +357,7 @@ class PytorchEngineConfig:
     enable_mp_engine: bool = False
     model_format: str = None
     enable_metrics: bool = False
+    hf_overrides: Optional[Dict[str, Any]] = None

     role: EngineRole = EngineRole.Hybrid
     migration_backend: MigrationBackend = MigrationBackend.DLSlime
```
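
For context, a hedged sketch of the pipeline-level usage this enables (not in this diff; `pipeline` is lmdeploy's public API and the model id is only illustrative): both engine configs now accept the field, so overrides can be supplied without the CLI.

```python
from lmdeploy import pipeline
from lmdeploy.messages import TurbomindEngineConfig

# Illustrative model id; any HF checkpoint directory works the same way.
pipe = pipeline('internlm/internlm2_5-7b-chat',
                backend_config=TurbomindEngineConfig(
                    hf_overrides={'rope_scaling': {'rope_type': 'dynamic', 'factor': 2.0}}))
```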

lmdeploy/pytorch/config.py

Lines changed: 24 additions & 8 deletions
```diff
@@ -158,7 +158,8 @@ def from_pretrained(cls,
                         pretrained_model_name_or_path: str,
                         trust_remote_code: bool = True,
                         dtype: str = 'auto',
-                        dist_config: DistConfig = None):
+                        dist_config: DistConfig = None,
+                        hf_overrides: Dict[str, Any] = None):
         """Instantiate one of the configuration classes of the library from a
         pretrained model configuration.
@@ -168,13 +169,28 @@ def from_pretrained(cls,
                 models defined on the Hub in their own modeling files.
             dtype (str): user specified data type for model weights and
                 activations. Refer to `PyTorchEngineConfig` for details
+            hf_overrides (Dict[str, Any]): overrides for the HF config.
         """
         from transformers import AutoConfig
+
+        from lmdeploy.utils import get_logger
+
         hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
         if getattr(hf_config, 'model_type', None) in ['phi3']:
             # phi3 + trust_remote_code leads to error when tp.
             hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
-        return cls.from_hf_config(hf_config, pretrained_model_name_or_path, dtype=dtype, dist_config=dist_config)
+
+        model_config = cls.from_hf_config(hf_config,
+                                          pretrained_model_name_or_path,
+                                          dtype=dtype,
+                                          dist_config=dist_config)
+
+        if hf_overrides is not None:
+            logger = get_logger('lmdeploy')
+            logger.warning(f'Overriding HF config with {hf_overrides}')
+            model_config.hf_config.update(hf_overrides)
+
+        return model_config

     @classmethod
     def from_hf_config(cls,
@@ -223,14 +239,14 @@ class MiscConfig:
     custom_module_map: str = None
     empty_init: bool = False
     model_format: str = None
+    hf_overrides: Dict[str, Any] = None

     @classmethod
     def from_engine_config(cls, engine_config: PytorchEngineConfig):
         """From engine config."""
-        misc_config = cls(
-            custom_module_map=engine_config.custom_module_map,
-            empty_init=engine_config.empty_init,
-            prefill_interval=engine_config.prefill_interval,
-            model_format=engine_config.model_format,
-        )
+        misc_config = cls(custom_module_map=engine_config.custom_module_map,
+                          empty_init=engine_config.empty_init,
+                          prefill_interval=engine_config.prefill_interval,
+                          model_format=engine_config.model_format,
+                          hf_overrides=engine_config.hf_overrides)
         return misc_config
```
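
As a hedged illustration of what `model_config.hf_config.update(hf_overrides)` does, assuming `hf_config` is a `transformers` `PretrainedConfig` (which is what `AutoConfig.from_pretrained` returns): each override key is assigned as an attribute on the config.

```python
from transformers import PretrainedConfig

cfg = PretrainedConfig(max_position_embeddings=4096)
# PretrainedConfig.update() sets each key/value pair as an attribute.
cfg.update({'max_position_embeddings': 8192,
            'rope_scaling': {'rope_type': 'dynamic', 'factor': 2.0}})
assert cfg.max_position_embeddings == 8192
assert cfg.rope_scaling['factor'] == 2.0
```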

lmdeploy/pytorch/engine/executor/__init__.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -68,7 +68,11 @@ def build_executor(model_path: str,
     dp = dist_config.dp
     world_size = dist_config.world_size

-    model_config = ModelConfig.from_pretrained(model_path, trust_remote_code=True, dtype=dtype, dist_config=dist_config)
+    model_config = ModelConfig.from_pretrained(model_path,
+                                               trust_remote_code=True,
+                                               dtype=dtype,
+                                               hf_overrides=misc_config.hf_overrides,
+                                               dist_config=dist_config)

     if distributed_executor_backend is None:
         distributed_executor_backend = get_distributed_executor_backend(world_size, dp, device_type, logger)
```

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -257,7 +257,11 @@ def __init__(

         from lmdeploy.tokenizer import Tokenizer
         tokenizer = Tokenizer(model_path).model.model
-        model_config = ModelConfig.from_pretrained(model_path, dtype=dtype, dist_config=dist_config)
+        model_config = ModelConfig.from_pretrained(model_path,
+                                                   dtype=dtype,
+                                                   hf_overrides=misc_config.hf_overrides,
+                                                   dist_config=dist_config)
+
         super().__init__(
             model_path=model_path,
             cache_config=cache_config,
```

lmdeploy/turbomind/deploy/config.py

Lines changed: 27 additions & 7 deletions
```diff
@@ -8,6 +8,9 @@
 from pydantic.dataclasses import dataclass

 from lmdeploy.messages import TurbomindEngineConfig
+from lmdeploy.utils import get_logger
+
+logger = get_logger('lmdeploy')


 def config_from_dict(cls, env):
@@ -150,15 +153,32 @@ def update_from_engine_config(self, config: TurbomindEngineConfig):
             if hasattr(self.attention_config, key):
                 setattr(self.attention_config, key, value)

+        # update from hf_overrides
+        if hasattr(config, 'hf_overrides') and config.hf_overrides:
+            hf_overrides = config.hf_overrides
+
+            if hf_overrides.get('rope_scaling'):
+                override_params = hf_overrides.get('rope_scaling')
+
+                rope_param = self.attention_config.rope_param or RopeParam(type='', base=0, dim=0)
+                rope_param.type = override_params.get('rope_type', '')
+                rope_param.factor = override_params.get('factor', 1.0)
+                rope_param.max_position_embeddings = override_params.get('original_max_position_embeddings', None)
+
+                self.attention_config.rope_param = rope_param
+            logger.warning(f'Overriding HF config with {hf_overrides}')
+
         # use dynamic ntk
         if config.rope_scaling_factor:
-            if self.attention_config.rope_param is None:
-                # some ut will create empty RopeParam, will check base/dim in src code
-                self.attention_config.rope_param = RopeParam(type='', base=0, dim=0)
-            self.attention_config.rope_param.__dict__.update(
-                type='dynamic',
-                factor=config.rope_scaling_factor,
-                max_position_embeddings=self.attention_config.max_position_embeddings)
+            # some unit tests create an empty RopeParam; base/dim are checked in the source code
+            rope_param = self.attention_config.rope_param or RopeParam(type='', base=0, dim=0)
+            rope_param.type = 'dynamic'
+            rope_param.factor = config.rope_scaling_factor
+            rope_param.max_position_embeddings = self.attention_config.max_position_embeddings
+
+            self.attention_config.rope_param = rope_param
+            logger.warning(
+                '`--rope-scaling-factor` will be removed in a future release. Please use `--hf-overrides` instead.')

     @classmethod
     def from_dict(cls, config: dict = {}):
```
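
For the TurboMind path, only the `rope_scaling` keys read above are consumed. A hedged sketch of the replacement for the now-deprecated `rope_scaling_factor` (field names are taken from this diff; the values are illustrative):

```python
from lmdeploy.messages import TurbomindEngineConfig

# Replaces the deprecated rope_scaling_factor / `--rope-scaling-factor`.
engine_config = TurbomindEngineConfig(hf_overrides={
    'rope_scaling': {
        'rope_type': 'dynamic',                     # -> rope_param.type
        'factor': 2.0,                              # -> rope_param.factor
        'original_max_position_embeddings': 32768,  # -> rope_param.max_position_embeddings
    }
})
```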
