|
3 | 3 | import os |
4 | 4 | from typing import Any, Dict |
5 | 5 |
|
| 6 | +import yaml |
6 | 7 | from transformers import AutoConfig |
7 | 8 |
|
8 | 9 | from slime.backends.sglang_utils.arguments import add_sglang_arguments |
@@ -1009,6 +1010,12 @@ def add_sglang_tp_size(): |
1009 | 1010 | # For megatron |
1010 | 1011 | parser = add_custom_megatron_plugins_arguments(parser) |
1011 | 1012 | try: |
| 1013 | + parser.add_argument( |
| 1014 | + "--custom-config-path", |
| 1015 | + type=str, |
| 1016 | + default=None, |
| 1017 | + help="Path to the YAML config for custom function arguments.", |
| 1018 | + ) |
1012 | 1019 | parser.add_argument("--padded-vocab-size", type=int, default=None) |
1013 | 1020 | except: |
1014 | 1021 | pass |
@@ -1231,6 +1238,15 @@ def slime_validate_args(args): |
1231 | 1238 | "num_epoch is not set, but num_rollout is not set, " "please set --num-rollout or --num-epoch" |
1232 | 1239 | ) |
1233 | 1240 |
|
| 1241 | + if args.custom_config_path: |
| 1242 | + with open(args.custom_config_path, "r") as f: |
| 1243 | + data = yaml.safe_load(f) or {} |
| 1244 | + for k, v in data.items(): |
| 1245 | + if not hasattr(args, k): |
| 1246 | + setattr(args, k, v) |
| 1247 | + else: |
| 1248 | + print(f"Warning: Argument {k} is already set to {getattr(args, k)}, will not override with {v}.") |
| 1249 | + |
1234 | 1250 |
|
1235 | 1251 | def hf_validate_args(args, hf_config): |
1236 | 1252 | equal = lambda x, y: x == y |
|
0 commit comments