@@ -286,6 +286,7 @@ class ResourceRequest:
         container_mounts: list[str] = field(default_factory=list)
         env_vars: dict[str, str] = field(default_factory=dict)
         srun_args: Optional[list[str]] = None
+        job_details: SlurmJobDetails = field(default_factory=SlurmJobDetails)
 
     account: str
     partition: Optional[str] = None
@@ -359,6 +360,7 @@ def merge(
                 gpus_per_node=main_executor.gpus_per_node,
                 gpus_per_task=main_executor.gpus_per_task,
                 srun_args=main_executor.srun_args,
+                job_details=copy.deepcopy(main_executor.job_details),
             )
         ]
 
@@ -374,6 +376,7 @@ def merge(
                     gpus_per_node=executor.gpus_per_node,
                     gpus_per_task=executor.gpus_per_task,
                     srun_args=executor.srun_args,
+                    job_details=copy.deepcopy(executor.job_details),
                 )
             )
 
@@ -720,10 +723,12 @@ def materialize(self) -> str:
         )
         # add necessary parameters
         original_job_name: str = self.jobs[0]  # type: ignore
-        if self.slurm_config.job_name_prefix is None:
-            job_name = f"{self.slurm_config.account}-{self.slurm_config.account.split('_')[-1]}.{original_job_name}"
-        else:
-            job_name = f"{self.slurm_config.job_name_prefix}{original_job_name}"
+        job_name_prefix = (
+            self.slurm_config.job_name_prefix
+            if self.slurm_config.job_name_prefix
+            else f"{self.slurm_config.account}-{self.slurm_config.account.split('_')[-1]}."
+        )
+        job_name = f"{job_name_prefix}{original_job_name}"
         slurm_job_dir = (
             self.slurm_config.tunnel.job_dir
             if self.slurm_config.tunnel
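The hunk above also changes the check from `is None` to truthiness, so an empty `job_name_prefix` now falls back to the account-derived prefix as well. A minimal sketch of that fallback, with a hypothetical account value:

# Sketch of the prefix fallback above; the account value is made up for illustration.
account = "coreai_dlalgo_llm"
job_name_prefix = None  # None or "" both trigger the fallback now

prefix = job_name_prefix if job_name_prefix else f"{account}-{account.split('_')[-1]}."
print(f"{prefix}train")  # -> coreai_dlalgo_llm-llm.train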
@@ -812,9 +817,9 @@ def materialize(self) -> str:
             if self.slurm_config.stderr_to_stdout
             else ["--error", noquote(job_details.srun_stderr)]
         )
-        memory_measure = None
+        memory_measure_out = None
         if self.slurm_config.memory_measure:
-            memory_measure = srun_stdout
+            memory_measure_out = srun_stdout
 
         def get_container_flags(
             base_mounts: list[str], src_job_dir: str, container_image: Optional[str]
@@ -833,113 +838,90 @@ def get_container_flags(
             return _container_flags
 
         for group_ind, command_group in enumerate(self.command_groups):
-            if self.slurm_config.heterogeneous:
+            if self.slurm_config.run_as_group and len(self.slurm_config.resource_group) == len(
+                self.command_groups
+            ):
                 resource_req = self.slurm_config.resource_group[group_ind]
+                if not resource_req.job_details.job_name:
+                    resource_req.job_details.job_name = f"{job_name_prefix}{self.jobs[group_ind]}"
 
+                if not resource_req.job_details.folder:
+                    resource_req.job_details.folder = os.path.join(
+                        slurm_job_dir, job_directory_name
+                    )
+
+                cmd_stdout = noquote(resource_req.job_details.srun_stdout)
+                cmd_stderr = (
+                    []
+                    if self.slurm_config.stderr_to_stdout
+                    else [
+                        "--error",
+                        noquote(resource_req.job_details.srun_stderr),
+                    ]
+                )
                 current_env_vars = []
                 for key, value in resource_req.env_vars.items():
                     current_env_vars.append(f"export {key.upper()}={value}")
 
                 group_env_vars.append(current_env_vars)
 
-                het_group = f"--het-group={group_ind}"
-                het_stdout = srun_stdout.replace(original_job_name, self.jobs[group_ind])
-                het_stderr = stderr_flags.copy()
-                if het_stderr:
-                    het_stderr[-1] = het_stderr[-1].replace(original_job_name, self.jobs[group_ind])
-
-                _group_srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
-                _group_srun_args.extend(resource_req.srun_args or [])
-                srun_cmd = " ".join(
-                    list(
-                        map(
-                            lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg),
-                            [
-                                "srun",
-                                het_group,
-                                "--output",
-                                het_stdout,
-                                *het_stderr,
-                                *get_container_flags(
-                                    base_mounts=resource_req.container_mounts,
-                                    src_job_dir=os.path.join(slurm_job_dir, job_directory_name),
-                                    container_image=resource_req.container_image,
-                                ),
-                                *_group_srun_args,
-                            ],
-                        )
-                    )
+                _container_flags = get_container_flags(
+                    base_mounts=resource_req.container_mounts,
+                    src_job_dir=os.path.join(
+                        slurm_job_dir,
+                        job_directory_name,
+                    ),
+                    container_image=resource_req.container_image,
                 )
-
-                command = ";\n".join(command_group)
-
-                srun_command = f"{srun_cmd} {command} & pids[{group_ind}]=$!"
-                if group_ind != len(self.slurm_config.resource_group) - 1:
-                    srun_command += f"\n\nsleep {self.slurm_config.wait_time_for_group_job}\n"
-                srun_commands.append(srun_command)
+                _srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
+                _srun_args.extend(resource_req.srun_args or [])
             else:
                 cmd_stdout = srun_stdout.replace(original_job_name, self.jobs[group_ind])
                 cmd_stderr = stderr_flags.copy()
                 if cmd_stderr:
                     cmd_stderr[-1] = cmd_stderr[-1].replace(original_job_name, self.jobs[group_ind])
+                _container_flags = get_container_flags(
+                    base_mounts=self.slurm_config.container_mounts,
+                    src_job_dir=os.path.join(
+                        slurm_job_dir,
+                        job_directory_name,
+                    ),
+                    container_image=self.slurm_config.container_image,
+                )
+                _srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
+                _srun_args.extend(self.slurm_config.srun_args or [])
 
-                if self.slurm_config.run_as_group and len(self.slurm_config.resource_group) == len(
-                    self.command_groups
-                ):
-                    resource_req = self.slurm_config.resource_group[group_ind]
-                    current_env_vars = []
-                    for key, value in resource_req.env_vars.items():
-                        current_env_vars.append(f"export {key.upper()}={value}")
-
-                    group_env_vars.append(current_env_vars)
-
-                    _container_flags = get_container_flags(
-                        base_mounts=resource_req.container_mounts,
-                        src_job_dir=os.path.join(
-                            slurm_job_dir,
-                            job_directory_name,
-                        ),
-                        container_image=resource_req.container_image,
-                    )
-                    _srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
-                    _srun_args.extend(resource_req.srun_args or [])
-                else:
-                    _container_flags = get_container_flags(
-                        base_mounts=self.slurm_config.container_mounts,
-                        src_job_dir=os.path.join(
-                            slurm_job_dir,
-                            job_directory_name,
-                        ),
-                        container_image=self.slurm_config.container_image,
-                    )
-                    _srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
-                    _srun_args.extend(self.slurm_config.srun_args or [])
-
-                srun_cmd = " ".join(
-                    list(
-                        map(
-                            lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg),
-                            [
-                                "srun",
-                                "--output",
-                                cmd_stdout,
-                                *cmd_stderr,
-                                *_container_flags,
-                                *_srun_args,
-                            ],
-                        )
+            if self.slurm_config.run_as_group and self.slurm_config.heterogeneous:
+                het_group_flag = [f"--het-group={group_ind}"]
+            else:
+                het_group_flag = []
+
+            srun_cmd = " ".join(
+                list(
+                    map(
+                        lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg),
+                        [
+                            "srun",
+                            *het_group_flag,
+                            "--output",
+                            cmd_stdout,
+                            *cmd_stderr,
+                            *_container_flags,
+                            *_srun_args,
+                        ],
                     )
                 )
-                command = " ".join(command_group)
+            )
+            command = " ".join(command_group)
 
-                if self.slurm_config.run_as_group:
-                    srun_command = f"{srun_cmd} {command} & pids[{group_ind}]=$!"
-                    if group_ind != len(self.command_groups) - 1:
-                        srun_command += f"\n\nsleep {self.slurm_config.wait_time_for_group_job}\n"
-                else:
-                    srun_command = f"{srun_cmd} {command}"
+            if self.slurm_config.run_as_group:
+                srun_command = f"{srun_cmd} {command} & pids[{group_ind}]=$!"
+                if group_ind != len(self.command_groups) - 1:
+                    srun_command += f"\n\nsleep {self.slurm_config.wait_time_for_group_job}\n"
+            else:
+                srun_command = f"{srun_cmd} {command}"
 
-                srun_commands.append(srun_command)
+            srun_commands.append(srun_command)
 
         vars_to_fill = {
             "sbatch_command": sbatch_cmd,
@@ -948,7 +930,7 @@ def get_container_flags(
             "env_vars": env_vars,
             "head_node_ip_var": SlurmExecutor.HEAD_NODE_IP_VAR,
             "setup_lines": self.slurm_config.setup_lines,
-            "memory_measure": memory_measure,
+            "memory_measure": memory_measure_out,
             "srun_commands": srun_commands,
             "group_env_vars": group_env_vars,
             "heterogeneous": self.slurm_config.heterogeneous,