1414limitations under the License.
1515"""
1616
17+ import os
18+ import re
1719from .capacity import H100_DEVICE_TYPE , H100_MEGA_DEVICE_TYPE , H200_DEVICE_TYPE
1820from .cluster import setup_k8s_env
1921from .storage import GCS_FUSE_TYPE , GCP_FILESTORE_TYPE , Storage , get_storages_to_mount
@@ -64,6 +66,25 @@ def get_env_container(args, system: SystemCharacteristics) -> str:
6466 str:
6567 YAML with the env config for the main container, as a YAML string.
6668 """
69+ if system .accelerator_type == AcceleratorType ['GPU' ]:
70+ return get_gpu_env (args , system )
71+
72+ if system .accelerator_type == AcceleratorType ['CPU' ]:
73+ return get_cpu_env (args , system )
74+
75+ return format_env_dict (args .env , system ) # pytype: disable=bad-return-type
76+
77+
78+ def get_gpu_env (args , system ) -> str :
79+ """Generate environment variables for GPU nodepools
80+ Args:
81+ num_slices: Number of slices to be used in the workload.
82+ env_vars: Environment variables, processed from user args.
83+ system: system characteristics
84+
85+ Returns:
86+ str: yaml containing env variables
87+ """
6788 gpu_env_yaml = """
6889 - name: REPLICATED_JOB_NAME
6990 valueFrom:
@@ -73,8 +94,6 @@ def get_env_container(args, system: SystemCharacteristics) -> str:
7394 valueFrom:
7495 fieldRef:
7596 fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
76- - name: JAX_COORDINATOR_ADDRESS
77- value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
7897 - name: NNODES
7998 value: "{args.num_nodes}"
8099 - name: NODE_RANK
@@ -84,32 +103,37 @@ def get_env_container(args, system: SystemCharacteristics) -> str:
84103 - name: USE_GPUDIRECT
85104 value: {gpu_direct_name}
86105 - name: GPUS_PER_NODE
87- value: "{system.chips_per_vm}"
88- - name: JAX_COORDINATOR_PORT
89- value: "6002"
106+ value: "{chips_per_vm}"
90107 - name: COMMAND
91108 value: "{args.command}"
92- {args.env}"""
93-
94- if system .accelerator_type == AcceleratorType ['GPU' ]:
95- gpu_direct_name = 'fastrak'
96- if args .device_type == H100_DEVICE_TYPE :
97- gpu_direct_name = 'tcpx'
98- elif args .device_type == H100_MEGA_DEVICE_TYPE :
99- gpu_direct_name = 'tcpxo'
100- elif args .device_type == H200_DEVICE_TYPE :
101- gpu_direct_name = 'rdma'
102- return gpu_env_yaml .format (
103- args = args , system = system , gpu_direct_name = gpu_direct_name
104- )
105-
106- if system .accelerator_type == AcceleratorType ['CPU' ]:
107- return get_cpu_env (args .num_slices , args .env , system )
108-
109- return args .env # pytype: disable=bad-return-type
109+ {custom_envs}"""
110+
111+ gpu_direct_name = 'fastrak'
112+ if args .device_type == H100_DEVICE_TYPE :
113+ gpu_direct_name = 'tcpx'
114+ elif args .device_type == H100_MEGA_DEVICE_TYPE :
115+ gpu_direct_name = 'tcpxo'
116+ elif args .device_type == H200_DEVICE_TYPE :
117+ gpu_direct_name = 'rdma'
118+
119+ gpu_env_dic = {
120+ 'JAX_COORDINATOR_PORT' : '6002' ,
121+ 'JAX_COORDINATOR_ADDRESS' : (
122+ '$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)'
123+ ),
124+ }
125+
126+ args .env = gpu_env_dic | args .env
127+
128+ return gpu_env_yaml .format (
129+ args = args ,
130+ chips_per_vm = system .chips_per_vm ,
131+ gpu_direct_name = gpu_direct_name ,
132+ custom_envs = format_env_dict (args .env , system ),
133+ )
110134
111135
112- def get_cpu_env (num_slices , env_vars , system ) -> str :
136+ def get_cpu_env (args , system ) -> str :
113137 """Generate environment variables for CPU nodepools
114138 Args:
115139 num_slices: Number of slices to be used in the workload.
@@ -132,19 +156,87 @@ def get_cpu_env(num_slices, env_vars, system) -> str:
132156 valueFrom:
133157 fieldRef:
134158 fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
135- - name: PROCESSES_IN_JOB
136- value: "{processes_in_job}"
137- - name: JAX_PROCESS_COUNT
138- value: "{process_count}"
139- {env_vars}
140- - name: JAX_COORDINATOR_ADDRESS
141- value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
159+ {custom_envs}
142160 """
143- return yaml .format (
144- processes_in_job = system .vms_per_slice ,
145- process_count = calculate_process_count (num_slices , system .vms_per_slice ),
146- env_vars = env_vars ,
147- )
161+
162+ cpu_env_dic = {
163+ 'PROCESSES_IN_JOB' : str (system .vms_per_slice ),
164+ 'JAX_PROCESS_COUNT' : str (
165+ calculate_process_count (args .num_slices , system .vms_per_slice )
166+ ),
167+ 'JAX_COORDINATOR_ADDRESS' : (
168+ '$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)'
169+ ),
170+ }
171+
172+ args .env = cpu_env_dic | args .env
173+
174+ return yaml .format (custom_envs = format_env_dict (args .env , system ))
175+
176+
177+ def format_env_dict (env , system : SystemCharacteristics ) -> str :
178+ if system .accelerator_type == AcceleratorType ['GPU' ]:
179+ # For GPUs, it has two more spaces ahead of name and value respectively
180+ env_format = '''
181+ - name: {key}
182+ value: "{value}"'''
183+ else :
184+ env_format = '''
185+ - name: {key}
186+ value: "{value}"'''
187+ return '' .join (env_format .format (key = k , value = v ) for k , v in env .items ())
188+
189+
190+ def parse_env_config (args , tensorboard_config ):
191+ """Parses the environment configurations to the a dictionary.
192+
193+ Args:
194+ args: user provided arguments for running the command.
195+ tensorboard_config: configuration of Vertex Tensorboard.
196+ system: system characteristics.
197+ """
198+ env = {}
199+
200+ env_pat = re .compile (r'(^[a-zA-Z_][a-zA-Z0-9_]*?)(?:=(.*))?$' , re .M )
201+ if args .env_file :
202+ print ('Setting container environment from' , args .env_file )
203+ with open (file = args .env_file , mode = 'r' , encoding = 'utf-8' ) as f :
204+ for match in env_pat .finditer (f .read ()):
205+ variable = match .group (1 )
206+ if match .group (2 ) is not None :
207+ env [variable ] = match .group (2 )
208+ else :
209+ assert variable in os .environ , (
210+ f'Variable { variable } is not set in the current '
211+ 'environment, a value must be specified.'
212+ )
213+ env [variable ] = os .environ [variable ]
214+ if args .env :
215+ for var in args .env :
216+ match = env_pat .match (var )
217+ assert match and match .group (2 ) is not None , (
218+ 'Invalid environment variable, format must be '
219+ f'`--env VARIABLE=value`: { var } '
220+ )
221+ variable = match .group (1 )
222+ env [variable ] = match .group (2 )
223+
224+ if not args .use_pathways :
225+ if args .debug_dump_gcs :
226+ if 'XLA_FLAGS' in env :
227+ raise ValueError (
228+ 'Conflict: XLA_FLAGS defined in both --debug_dump_gcs '
229+ 'and environment file. Please choose one way to define '
230+ 'XLA_FLAGS.'
231+ )
232+ env ['XLA_FLAGS' ] = '--xla_dump_to=/tmp/xla_dump/'
233+
234+ if tensorboard_config :
235+ env ['UPLOAD_DATA_TO_TENSORBOARD' ] = True
236+ for key , value in tensorboard_config .items ():
237+ env [key .upper ()] = value
238+
239+ args .env = env
148240
149241
150242def get_volumes (args , system : SystemCharacteristics ) -> str :
0 commit comments