22import argparse
33import os
44import shlex
5+ from collections .abc import Callable
56from dataclasses import dataclass , field
6- from typing import Any , Callable , Optional , Type
7+ from typing import Any
78
89import yaml
910from deepmerge import always_merger
1718
1819@dataclass
1920class ArgSpec :
20-
2121 arg : str = ""
22- arg_type : Type [Any ] = field (default = str )
22+ arg_type : type [Any ] = field (default = str )
2323 description : str = ""
24- map_fn : Optional [ Callable ] = field (default = None )
24+ map_fn : Callable | None = field (default = None )
2525
2626
2727class Converter (abc .ABC ):
@@ -43,7 +43,7 @@ def get_lite_template(self, template_path: str):
4343 """
4444 Load the areal template from the specified file.
4545 """
46- with open (template_path , "r" , encoding = "utf-8" ) as f :
46+ with open (template_path , encoding = "utf-8" ) as f :
4747 return yaml .safe_load (f )
4848
4949 def flatten_dict (self , d , parent_key = "" , sep = "." ):
@@ -74,7 +74,7 @@ def convert_to_nested_args(self, args, ARG_MAP: dict) -> dict:
7474 argspec = ARG_MAP [k ]
7575 if not argspec .arg :
7676 print (
77- colored (f "## Warning: For " , "yellow" )
77+ colored ("## Warning: For " , "yellow" )
7878 + colored (f"{ k :>40} " , "yellow" , attrs = ["bold" ])
7979 + colored (f", # { argspec .description } !" , "yellow" )
8080 )
@@ -87,23 +87,23 @@ def convert_to_nested_args(self, args, ARG_MAP: dict) -> dict:
8787 if v is not None :
8888 try :
8989 # type conversion
90- if arg_type == bool :
90+ if arg_type is bool :
9191 v = (
9292 bool (v )
9393 if isinstance (v , bool )
9494 else v .lower () in ("1" , "true" , "yes" , "on" )
9595 )
96- elif arg_type == int :
96+ elif arg_type is int :
9797 v = int (v )
98- elif arg_type == float :
98+ elif arg_type is float :
9999 v = float (v )
100- elif arg_type == str :
100+ elif arg_type is str :
101101 v = str (v )
102102 else :
103103 raise ValueError (f"Unsupported type: { arg_type } " )
104104 except Exception as e :
105105 print (
106- colored (f "## Error: For " , "red" )
106+ colored ("## Error: For " , "red" )
107107 + colored (f"{ k :>40} { v } " , "red" , attrs = ["bold" ])
108108 + colored (f", # { e } !" , "red" )
109109 )
@@ -117,7 +117,7 @@ def convert_to_nested_args(self, args, ARG_MAP: dict) -> dict:
117117 else :
118118 unmapped [k ] = v
119119 print (
120- colored (f "## Warning: For " , "yellow" )
120+ colored ("## Warning: For " , "yellow" )
121121 + colored (f"{ k :>50} " , "yellow" , attrs = ["bold" ])
122122 + colored (f", # { CVRT_WARNING } !" , "yellow" )
123123 )
@@ -135,7 +135,6 @@ def set_nested(self, d: dict, keys, value):
135135
136136
137137class OpenRLHFConverter (Converter ):
138-
139138 ARG_MAP = {
140139 # Ray and vLLM
141140 "ref_num_nodes" : ArgSpec ("" , int , CVRT_WARNING ),
@@ -361,7 +360,7 @@ def _parse_args_from_script(
361360 in_command_block = False
362361
363362 try :
364- with open (script_path , "r" , encoding = "utf-8" ) as f :
363+ with open (script_path , encoding = "utf-8" ) as f :
365364 for line in f :
366365 stripped_line = line .strip ()
367366
@@ -425,7 +424,6 @@ def _parse_args_from_script(
425424
426425
427426def post_process_args (args : dict ):
428-
429427 if "allocation_mode" in args :
430428 # convert allocation_mode to sglang.dX.tY.pZ
431429 dp = args ["cluster" ]["n_nodes" ] * args ["cluster" ]["n_gpus_per_node" ]
@@ -442,7 +440,7 @@ def post_process_args(args: dict):
442440 allocation_mode += "t1"
443441 else :
444442 allocation_mode += f"t{ args ['allocation_mode' ]['sglang' ]['t' ]} "
445- allocation_mode += f "p1"
443+ allocation_mode += "p1"
446444 allocation_mode += "+"
447445 if "engine" not in args ["allocation_mode" ]:
448446 allocation_mode += f"d{ dp } t1p1"
@@ -452,7 +450,7 @@ def post_process_args(args: dict):
452450 allocation_mode += "t1"
453451 else :
454452 allocation_mode += f".t{ args ['allocation_mode' ]['engine' ]['t' ]} "
455- allocation_mode += f "p1"
453+ allocation_mode += "p1"
456454
457455 args ["allocation_mode" ] = allocation_mode
458456 args ["cluster" ]["n_nodes" ] = args ["cluster" ]["n_nodes" ] * 2
@@ -830,7 +828,7 @@ def __init__(self, src_config_path: str, template_path: str):
830828 self .template_path = template_path
831829
832830 def parse (self ) -> dict :
833- with open (self .src_config_path , "r" , encoding = "utf-8" ) as f :
831+ with open (self .src_config_path , encoding = "utf-8" ) as f :
834832 cfg = yaml .safe_load (f )
835833 return cfg
836834
@@ -895,7 +893,7 @@ def main():
895893 ** converter_args [args .convert_src ]
896894 )
897895 lite_args = converter .convert ()
898- yaml_str = yaml .dump (lite_args , sort_keys = False , allow_unicode = True )
896+ # yaml_str = yaml.dump(lite_args, sort_keys=False, allow_unicode=True)
899897 with open (args .output_path , "w" , encoding = "utf-8" ) as f :
900898 yaml .dump (lite_args , f , sort_keys = False , allow_unicode = True )
901899 print (f"Converted areal config saved to { args .output_path } " )
0 commit comments