@@ -17,8 +17,8 @@ def create_torch_dist_role(
     entrypoint: str,
     resource: Resource = NULL_RESOURCE,
     base_image: Optional[str] = None,
-    script_args: Optional[List[str]] = None,
-    script_envs: Optional[Dict[str, str]] = None,
+    args: Optional[List[str]] = None,
+    env: Optional[Dict[str, str]] = None,
     num_replicas: int = 1,
     max_retries: int = 0,
     port_map: Dict[str, int] = field(default_factory=dict),
@@ -54,7 +54,7 @@ def create_torch_dist_role(
     ...     image="<NONE>",
     ...     resource=NULL_RESOURCE,
     ...     entrypoint="my_train_script.py",
-    ...     script_args=["--script_arg", "foo", "--another_arg", "bar"],
+    ...     args=["--script_arg", "foo", "--another_arg", "bar"],
     ...     num_replicas=4, max_retries=1,
     ...     nproc_per_node=8, nnodes="2:4", max_restarts=3)
     ... # effectively runs:
@@ -72,8 +72,8 @@ def create_torch_dist_role(
         entrypoint: User binary or python script that will be launched.
         resource: Resource that is requested by scheduler
         base_image: Optional base image, if schedulers support image overlay
-        script_args: User provided arguments
-        script_envs: Env. variables that will be set on worker process that runs entrypoint
+        args: User provided arguments
+        env: Env. variables that will be set on worker process that runs entrypoint
         num_replicas: Number of role replicas to run
         max_retries: Max number of retries
         port_map: Port mapping for the role
@@ -84,11 +84,11 @@ def create_torch_dist_role(
         Role object that launches user entrypoint via the torchelastic as proxy

    """
-    script_args = script_args or []
-    script_envs = script_envs or {}
+    args = args or []
+    env = env or {}

     entrypoint_override = "python"
-    args: List[str] = ["-m", "torch.distributed.launch"]
+    torch_run_args: List[str] = ["-m", "torch.distributed.launch"]

     launch_kwargs.setdefault("rdzv_backend", "etcd")
     launch_kwargs.setdefault("rdzv_id", macros.app_id)
@@ -98,14 +98,14 @@ def create_torch_dist_role(
         if isinstance(val, bool):
             # treat boolean kwarg as a flag
             if val:
-                args += [f"--{arg}"]
+                torch_run_args += [f"--{arg}"]
         else:
-            args += [f"--{arg}", str(val)]
+            torch_run_args += [f"--{arg}", str(val)]
     if not os.path.isabs(entrypoint) and not entrypoint.startswith(macros.img_root):
         # make entrypoint relative to {img_root} ONLY if it is not an absolute path
         entrypoint = os.path.join(macros.img_root, entrypoint)

-    args += [entrypoint, *script_args]
+    args = [*torch_run_args, entrypoint, *args]
     return (
         Role(
             name,
@@ -114,7 +114,7 @@ def create_torch_dist_role(
             resource=resource,
             port_map=port_map,
         )
-        .runs(entrypoint_override, *args, **script_envs)
+        .runs(entrypoint_override, *args, **env)
         .replicas(num_replicas)
         .with_retry_policy(retry_policy, max_retries)
     )
0 commit comments