|
1 | 1 | # Copyright (c) OpenMMLab. All rights reserved. |
2 | 2 |
|
3 | 3 | import argparse |
4 | | -from typing import List |
| 4 | +import json |
| 5 | +import re |
| 6 | +import sys |
| 7 | +from collections import defaultdict |
| 8 | +from typing import Any, List |
5 | 9 |
|
6 | 10 |
|
7 | 11 | class DefaultsAndTypesHelpFormatter(argparse.HelpFormatter): |
@@ -231,6 +235,14 @@ def rope_scaling_factor(parser): |
231 | 235 |
|
232 | 236 | return parser.add_argument('--rope-scaling-factor', type=float, default=0.0, help='Rope scaling factor') |
233 | 237 |
|
| 238 | + @staticmethod |
| 239 | + def hf_overrides(parser): |
| 240 | + """Add argument hf_overrides to parser.""" |
| 241 | + return parser.add_argument('--hf-overrides', |
| 242 | + type=json.loads, |
| 243 | + default=None, |
| 244 | + help='Extra arguments to be forwarded to the HuggingFace config.') |
| 245 | + |
234 | 246 | @staticmethod |
235 | 247 | def use_logn_attn(parser): |
236 | 248 | """Add argument use_logn_attn to parser.""" |
@@ -580,3 +592,93 @@ def migration_backend(parser): |
580 | 592 | default='DLSlime', |
581 | 593 | choices=['DLSlime', 'Mooncake'], |
582 | 594 | help='kvcache migration management backend when PD disaggregation') |
| 595 | + |
| 596 | + |
| 597 | +# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py |
| 598 | +class FlexibleArgumentParser(argparse.ArgumentParser): |
| 599 | + """"More flexible argument parser.""" |
| 600 | + |
| 601 | + def parse_args(self, args=None, namespace=None): |
| 602 | + # If args is not provided, use arguments from the command line |
| 603 | + if args is None: |
| 604 | + args = sys.argv[1:] |
| 605 | + |
| 606 | + def repl(match: re.Match) -> str: |
| 607 | + """Replaces underscores with dashes in the matched string.""" |
| 608 | + return match.group(0).replace('_', '-') |
| 609 | + |
| 610 | + # Everything between the first -- and the first . |
| 611 | + pattern = re.compile(r'(?<=--)[^\.]*') |
| 612 | + |
| 613 | + # Convert underscores to dashes and vice versa in argument names |
| 614 | + processed_args = [] |
| 615 | + for arg in args: |
| 616 | + if arg.startswith('--'): |
| 617 | + if '=' in arg: |
| 618 | + key, value = arg.split('=', 1) |
| 619 | + key = pattern.sub(repl, key, count=1) |
| 620 | + processed_args.append(f'{key}={value}') |
| 621 | + else: |
| 622 | + key = pattern.sub(repl, arg, count=1) |
| 623 | + processed_args.append(key) |
| 624 | + elif arg.startswith('-O') and arg != '-O' and len(arg) == 2: |
| 625 | + # allow -O flag to be used without space, e.g. -O3 |
| 626 | + processed_args.append('-O') |
| 627 | + processed_args.append(arg[2:]) |
| 628 | + else: |
| 629 | + processed_args.append(arg) |
| 630 | + |
| 631 | + def _try_convert(value: str): |
| 632 | + """Try to convert string to float or int.""" |
| 633 | + if not isinstance(value, str): |
| 634 | + return value |
| 635 | + # try loads from json |
| 636 | + try: |
| 637 | + return json.loads(value) |
| 638 | + except json.JSONDecodeError: |
| 639 | + pass |
| 640 | + return value |
| 641 | + |
| 642 | + def create_nested_dict(keys: list[str], value: str): |
| 643 | + """Creates a nested dictionary from a list of keys and a value. |
| 644 | +
|
| 645 | + For example, `keys = ["a", "b", "c"]` and `value = 1` will create: `{"a": {"b": {"c": 1}}}` |
| 646 | + """ |
| 647 | + nested_dict: Any = _try_convert(value) |
| 648 | + for key in reversed(keys): |
| 649 | + nested_dict = {key: nested_dict} |
| 650 | + return nested_dict |
| 651 | + |
| 652 | + def recursive_dict_update(original: dict, update: dict): |
| 653 | + """Recursively updates a dictionary with another dictionary.""" |
| 654 | + for k, v in update.items(): |
| 655 | + if isinstance(v, dict) and isinstance(original.get(k), dict): |
| 656 | + recursive_dict_update(original[k], v) |
| 657 | + else: |
| 658 | + original[k] = v |
| 659 | + |
| 660 | + delete = set() |
| 661 | + dict_args: dict[str, dict] = defaultdict(dict) |
| 662 | + for i, processed_arg in enumerate(processed_args): |
| 663 | + if processed_arg.startswith('--') and '.' in processed_arg: |
| 664 | + if '=' in processed_arg: |
| 665 | + processed_arg, value = processed_arg.split('=', 1) |
| 666 | + if '.' not in processed_arg: |
| 667 | + # False positive, . was only in the value |
| 668 | + continue |
| 669 | + else: |
| 670 | + value = processed_args[i + 1] |
| 671 | + delete.add(i + 1) |
| 672 | + key, *keys = processed_arg.split('.') |
| 673 | + # Merge all values with the same key into a single dict |
| 674 | + arg_dict = create_nested_dict(keys, value) |
| 675 | + recursive_dict_update(dict_args[key], arg_dict) |
| 676 | + delete.add(i) |
| 677 | + # Filter out the dict args we set to None |
| 678 | + processed_args = [a for i, a in enumerate(processed_args) if i not in delete] |
| 679 | + # Add the dict args back as if they were originally passed as JSON |
| 680 | + for dict_arg, dict_value in dict_args.items(): |
| 681 | + processed_args.append(dict_arg) |
| 682 | + processed_args.append(json.dumps(dict_value)) |
| 683 | + |
| 684 | + return super().parse_args(processed_args, namespace) |
0 commit comments