@@ -96,6 +96,20 @@ def update_image_pull_secrets(spec, image_pull_secrets):
9696 ]
9797
9898
99+ def update_volume_mounts (spec , volume_mounts : list ):
100+ containers = spec .get ("containers" )
101+ for volume_mount in volume_mounts :
102+ for container in containers :
103+ volumeMount = client .ApiClient ().sanitize_for_serialization (volume_mount )
104+ container ["volumeMounts" ].append (volumeMount )
105+
106+
107+ def update_volumes (spec , volumes : list ):
108+ for volume in volumes :
109+ new_volume = client .ApiClient ().sanitize_for_serialization (volume )
110+ spec ["volumes" ].append (new_volume )
111+
112+
99113def update_env (spec , env ):
100114 containers = spec .get ("containers" )
101115 for container in containers :
@@ -136,6 +150,8 @@ def update_nodes(
136150 head_cpus ,
137151 head_memory ,
138152 head_gpus ,
153+ volumes ,
154+ volume_mounts ,
139155):
140156 head = cluster_yaml .get ("spec" ).get ("headGroupSpec" )
141157 head ["rayStartParams" ]["num-gpus" ] = str (int (head_gpus ))
@@ -150,6 +166,8 @@ def update_nodes(
150166
151167 for comp in [head , worker ]:
152168 spec = comp .get ("template" ).get ("spec" )
169+ update_volume_mounts (spec , volume_mounts )
170+ update_volumes (spec , volumes )
153171 update_image_pull_secrets (spec , image_pull_secrets )
154172 update_image (spec , image )
155173 update_env (spec , env )
@@ -280,6 +298,8 @@ def generate_appwrapper(
280298 write_to_file : bool ,
281299 local_queue : Optional [str ],
282300 labels ,
301+ volumes : list [client .V1Volume ],
302+ volume_mounts : list [client .V1VolumeMount ],
283303):
284304 cluster_yaml = read_template (template )
285305 appwrapper_name , cluster_name = gen_names (name )
@@ -299,6 +319,8 @@ def generate_appwrapper(
299319 head_cpus ,
300320 head_memory ,
301321 head_gpus ,
322+ volumes ,
323+ volume_mounts ,
302324 )
303325 augment_labels (cluster_yaml , labels )
304326 notebook_annotations (cluster_yaml )
0 commit comments