Skip to content

Commit 414a6e1

Browse files
authored
Override env vars if they exist in custom envs coming from commands (#524)
* env vars become a dictionary and values overridden * removed excessive arg * imported missing modules * added missing arg * removed excessive imports * fixed imports * fixed dict merge
1 parent 7504a2c commit 414a6e1

File tree

5 files changed

+134
-114
lines changed

5 files changed

+134
-114
lines changed

src/xpk/commands/workload.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,12 @@
2727
setup_k8s_env,
2828
)
2929
from ..core.commands import run_command_with_updates, run_commands
30-
from ..core.config import (
31-
VERTEX_TENSORBOARD_FEATURE_FLAG,
32-
XPK_CURRENT_VERSION,
33-
parse_env_config,
34-
)
30+
from ..core.config import (VERTEX_TENSORBOARD_FEATURE_FLAG, XPK_CURRENT_VERSION)
3531
from ..core.docker_container import (
3632
get_main_container_docker_image,
3733
get_user_workload_container,
3834
)
39-
from ..core.docker_resources import get_volumes
35+
from ..core.docker_resources import get_volumes, parse_env_config
4036
from ..core.gcloud_context import add_zone_and_project
4137
from ..core.kueue import LOCAL_QUEUE_NAME
4238
from ..core.monitoring import get_gke_outlier_dashboard
@@ -353,7 +349,7 @@ def workload_create(args) -> None:
353349
if not tensorboard_config:
354350
xpk_exit(1)
355351

356-
parse_env_config(args, tensorboard_config, system)
352+
parse_env_config(args, tensorboard_config)
357353

358354
autoprovisioning_args = ''
359355
autoprovisioning_enabled, return_code = is_autoprovisioning_enabled(

src/xpk/core/config.py

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,11 @@
1515
"""
1616

1717
import os
18-
import re
1918

2019
import ruamel.yaml
2120

2221
from ..utils import file
2322
from ..utils.console import xpk_print
24-
from .system_characteristics import AcceleratorType, SystemCharacteristics
2523

2624
# This is the version for XPK PyPI package
2725
__version__ = 'v0.8.0'
@@ -117,65 +115,3 @@ def get_all(
117115
return None
118116
val: dict[str, str] = config_yaml[CONFIGS_KEY]
119117
return val
120-
121-
122-
def parse_env_config(args, tensorboard_config, system: SystemCharacteristics):
123-
"""Parses the environment configurations to the jobset config.
124-
125-
Args:
126-
args: user provided arguments for running the command.
127-
tensorboard_config: configuration of Vertex Tensorboard.
128-
system: system characteristics.
129-
"""
130-
env = {}
131-
132-
env_pat = re.compile(r'(^[a-zA-Z_][a-zA-Z0-9_]*?)(?:=(.*))?$', re.M)
133-
if args.env_file:
134-
print('Setting container environment from', args.env_file)
135-
with open(file=args.env_file, mode='r', encoding='utf-8') as f:
136-
for match in env_pat.finditer(f.read()):
137-
variable = match.group(1)
138-
if match.group(2) is not None:
139-
env[variable] = match.group(2)
140-
else:
141-
assert variable in os.environ, (
142-
f'Variable {variable} is not set in the current '
143-
'environment, a value must be specified.'
144-
)
145-
env[variable] = os.environ[variable]
146-
if args.env:
147-
for var in args.env:
148-
match = env_pat.match(var)
149-
assert match and match.group(2) is not None, (
150-
'Invalid environment variable, format must be '
151-
f'`--env VARIABLE=value`: {var}'
152-
)
153-
variable = match.group(1)
154-
env[variable] = match.group(2)
155-
156-
if not args.use_pathways:
157-
if args.debug_dump_gcs:
158-
if 'XLA_FLAGS' in env:
159-
raise ValueError(
160-
'Conflict: XLA_FLAGS defined in both --debug_dump_gcs '
161-
'and environment file. Please choose one way to define '
162-
'XLA_FLAGS.'
163-
)
164-
env['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump/'
165-
166-
if tensorboard_config:
167-
env['UPLOAD_DATA_TO_TENSORBOARD'] = True
168-
for key, value in tensorboard_config.items():
169-
env[key.upper()] = value
170-
171-
if system.accelerator_type == AcceleratorType['GPU']:
172-
# For GPUs, it has two more spaces ahead of name and value respectively
173-
env_format = '''
174-
- name: {key}
175-
value: "{value}"'''
176-
else:
177-
env_format = '''
178-
- name: {key}
179-
value: "{value}"'''
180-
181-
args.env = ''.join(env_format.format(key=k, value=v) for k, v in env.items())

src/xpk/core/docker_resources.py

Lines changed: 128 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
limitations under the License.
1515
"""
1616

17+
import os
18+
import re
1719
from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
1820
from .cluster import setup_k8s_env
1921
from .storage import GCS_FUSE_TYPE, GCP_FILESTORE_TYPE, Storage, get_storages_to_mount
@@ -64,6 +66,25 @@ def get_env_container(args, system: SystemCharacteristics) -> str:
6466
str:
6567
YAML with the env config for the main container, as a YAML string.
6668
"""
69+
if system.accelerator_type == AcceleratorType['GPU']:
70+
return get_gpu_env(args, system)
71+
72+
if system.accelerator_type == AcceleratorType['CPU']:
73+
return get_cpu_env(args, system)
74+
75+
return format_env_dict(args.env, system) # pytype: disable=bad-return-type
76+
77+
78+
def get_gpu_env(args, system) -> str:
79+
"""Generate environment variables for GPU nodepools
80+
Args:
81+
num_slices: Number of slices to be used in the workload.
82+
env_vars: Environment variables, processed from user args.
83+
system: system characteristics
84+
85+
Returns:
86+
str: yaml containing env variables
87+
"""
6788
gpu_env_yaml = """
6889
- name: REPLICATED_JOB_NAME
6990
valueFrom:
@@ -73,8 +94,6 @@ def get_env_container(args, system: SystemCharacteristics) -> str:
7394
valueFrom:
7495
fieldRef:
7596
fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
76-
- name: JAX_COORDINATOR_ADDRESS
77-
value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
7897
- name: NNODES
7998
value: "{args.num_nodes}"
8099
- name: NODE_RANK
@@ -84,32 +103,37 @@ def get_env_container(args, system: SystemCharacteristics) -> str:
84103
- name: USE_GPUDIRECT
85104
value: {gpu_direct_name}
86105
- name: GPUS_PER_NODE
87-
value: "{system.chips_per_vm}"
88-
- name: JAX_COORDINATOR_PORT
89-
value: "6002"
106+
value: "{chips_per_vm}"
90107
- name: COMMAND
91108
value: "{args.command}"
92-
{args.env}"""
93-
94-
if system.accelerator_type == AcceleratorType['GPU']:
95-
gpu_direct_name = 'fastrak'
96-
if args.device_type == H100_DEVICE_TYPE:
97-
gpu_direct_name = 'tcpx'
98-
elif args.device_type == H100_MEGA_DEVICE_TYPE:
99-
gpu_direct_name = 'tcpxo'
100-
elif args.device_type == H200_DEVICE_TYPE:
101-
gpu_direct_name = 'rdma'
102-
return gpu_env_yaml.format(
103-
args=args, system=system, gpu_direct_name=gpu_direct_name
104-
)
105-
106-
if system.accelerator_type == AcceleratorType['CPU']:
107-
return get_cpu_env(args.num_slices, args.env, system)
108-
109-
return args.env # pytype: disable=bad-return-type
109+
{custom_envs}"""
110+
111+
gpu_direct_name = 'fastrak'
112+
if args.device_type == H100_DEVICE_TYPE:
113+
gpu_direct_name = 'tcpx'
114+
elif args.device_type == H100_MEGA_DEVICE_TYPE:
115+
gpu_direct_name = 'tcpxo'
116+
elif args.device_type == H200_DEVICE_TYPE:
117+
gpu_direct_name = 'rdma'
118+
119+
gpu_env_dic = {
120+
'JAX_COORDINATOR_PORT': '6002',
121+
'JAX_COORDINATOR_ADDRESS': (
122+
'$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)'
123+
),
124+
}
125+
126+
args.env = gpu_env_dic | args.env
127+
128+
return gpu_env_yaml.format(
129+
args=args,
130+
chips_per_vm=system.chips_per_vm,
131+
gpu_direct_name=gpu_direct_name,
132+
custom_envs=format_env_dict(args.env, system),
133+
)
110134

111135

112-
def get_cpu_env(num_slices, env_vars, system) -> str:
136+
def get_cpu_env(args, system) -> str:
113137
"""Generate environment variables for CPU nodepools
114138
Args:
115139
num_slices: Number of slices to be used in the workload.
@@ -132,19 +156,87 @@ def get_cpu_env(num_slices, env_vars, system) -> str:
132156
valueFrom:
133157
fieldRef:
134158
fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
135-
- name: PROCESSES_IN_JOB
136-
value: "{processes_in_job}"
137-
- name: JAX_PROCESS_COUNT
138-
value: "{process_count}"
139-
{env_vars}
140-
- name: JAX_COORDINATOR_ADDRESS
141-
value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
159+
{custom_envs}
142160
"""
143-
return yaml.format(
144-
processes_in_job=system.vms_per_slice,
145-
process_count=calculate_process_count(num_slices, system.vms_per_slice),
146-
env_vars=env_vars,
147-
)
161+
162+
cpu_env_dic = {
163+
'PROCESSES_IN_JOB': str(system.vms_per_slice),
164+
'JAX_PROCESS_COUNT': str(
165+
calculate_process_count(args.num_slices, system.vms_per_slice)
166+
),
167+
'JAX_COORDINATOR_ADDRESS': (
168+
'$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)'
169+
),
170+
}
171+
172+
args.env = cpu_env_dic | args.env
173+
174+
return yaml.format(custom_envs=format_env_dict(args.env, system))
175+
176+
177+
def format_env_dict(env, system: SystemCharacteristics) -> str:
178+
if system.accelerator_type == AcceleratorType['GPU']:
179+
# For GPUs, it has two more spaces ahead of name and value respectively
180+
env_format = '''
181+
- name: {key}
182+
value: "{value}"'''
183+
else:
184+
env_format = '''
185+
- name: {key}
186+
value: "{value}"'''
187+
return ''.join(env_format.format(key=k, value=v) for k, v in env.items())
188+
189+
190+
def parse_env_config(args, tensorboard_config):
191+
"""Parses the environment configurations to the a dictionary.
192+
193+
Args:
194+
args: user provided arguments for running the command.
195+
tensorboard_config: configuration of Vertex Tensorboard.
196+
system: system characteristics.
197+
"""
198+
env = {}
199+
200+
env_pat = re.compile(r'(^[a-zA-Z_][a-zA-Z0-9_]*?)(?:=(.*))?$', re.M)
201+
if args.env_file:
202+
print('Setting container environment from', args.env_file)
203+
with open(file=args.env_file, mode='r', encoding='utf-8') as f:
204+
for match in env_pat.finditer(f.read()):
205+
variable = match.group(1)
206+
if match.group(2) is not None:
207+
env[variable] = match.group(2)
208+
else:
209+
assert variable in os.environ, (
210+
f'Variable {variable} is not set in the current '
211+
'environment, a value must be specified.'
212+
)
213+
env[variable] = os.environ[variable]
214+
if args.env:
215+
for var in args.env:
216+
match = env_pat.match(var)
217+
assert match and match.group(2) is not None, (
218+
'Invalid environment variable, format must be '
219+
f'`--env VARIABLE=value`: {var}'
220+
)
221+
variable = match.group(1)
222+
env[variable] = match.group(2)
223+
224+
if not args.use_pathways:
225+
if args.debug_dump_gcs:
226+
if 'XLA_FLAGS' in env:
227+
raise ValueError(
228+
'Conflict: XLA_FLAGS defined in both --debug_dump_gcs '
229+
'and environment file. Please choose one way to define '
230+
'XLA_FLAGS.'
231+
)
232+
env['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump/'
233+
234+
if tensorboard_config:
235+
env['UPLOAD_DATA_TO_TENSORBOARD'] = True
236+
for key, value in tensorboard_config.items():
237+
env[key.upper()] = value
238+
239+
args.env = env
148240

149241

150242
def get_volumes(args, system: SystemCharacteristics) -> str:

src/xpk/core/kjob.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,8 @@
4040
XpkConfig,
4141
)
4242
from .network import get_cluster_subnetworks
43-
from .resources import (
44-
AcceleratorType,
45-
SystemCharacteristics,
46-
get_cluster_system_characteristics,
47-
)
43+
from .system_characteristics import AcceleratorType, SystemCharacteristics
44+
from .resources import get_cluster_system_characteristics
4845
from .storage import (
4946
GCS_FUSE_ANNOTATIONS,
5047
PARALLELSTORE_ANNOTATIONS,

src/xpk/core/pathways.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
from ..core.gcloud_context import zone_to_region
2020
from ..core.nodepool import get_all_nodepools_programmatic
2121
from ..utils.console import xpk_exit, xpk_print
22-
from .config import AcceleratorType
23-
from .system_characteristics import SystemCharacteristics
22+
from .system_characteristics import AcceleratorType, SystemCharacteristics
2423

2524

2625
def add_pw_resource_flavors(args):

0 commit comments

Comments
 (0)