Skip to content

Commit 3029f0a

Browse files
authored
Merge pull request #248 from djarecka/fix/slurm_errorfiles
fixing error files for the slurm worker
2 parents d009be1 + 6414c21 commit 3029f0a

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

pydra/engine/workers.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def __init__(
214214
poll_delay = 0
215215
self.poll_delay = poll_delay
216216
self.sbatch_args = sbatch_args or ""
217+
self.error = {}
217218

218219
def run_el(self, runnable, rerun=False):
219220
"""Worker submission API."""
@@ -224,28 +225,34 @@ def run_el(self, runnable, rerun=False):
224225

225226
async def _submit_job(self, task, batchscript):
226227
"""Coroutine that submits task runscript and polls job until completion or error."""
228+
script_dir = (
229+
task.cache_dir / f"{self.__class__.__name__}_scripts" / task.checksum
230+
)
227231
sargs = self.sbatch_args.split()
228232
jobname = re.search(r"(?<=-J )\S+|(?<=--job-name=)\S+", self.sbatch_args)
229233
if not jobname:
230234
jobname = ".".join((task.name, task.checksum))
231235
sargs.append(f"--job-name={jobname}")
232236
output = re.search(r"(?<=-o )\S+|(?<=--output=)\S+", self.sbatch_args)
233237
if not output:
234-
self.output = str(batchscript.parent / "slurm-%j.out")
235-
sargs.append(f"--output={self.output}")
238+
output_file = str(script_dir / "slurm-%j.out")
239+
sargs.append(f"--output={output_file}")
236240
error = re.search(r"(?<=-e )\S+|(?<=--error=)\S+", self.sbatch_args)
237241
if not error:
238-
self.error = str(batchscript.parent / "slurm-%j.err")
239-
sargs.append(f"--error={self.error}")
242+
error_file = str(script_dir / "slurm-%j.err")
243+
sargs.append(f"--error={error_file}")
244+
else:
245+
error_file = None
240246
sargs.append(str(batchscript))
241247
# TO CONSIDER: add random sleep to avoid overloading calls
242248
_, stdout, _ = await read_and_display_async("sbatch", *sargs, hide_display=True)
243249
jobid = re.search(r"\d+", stdout)
244250
if not jobid:
245251
raise RuntimeError("Could not extract job ID")
246252
jobid = jobid.group()
247-
self.output = self.output.replace("%j", jobid)
248-
self.error = self.error.replace("%j", jobid)
253+
if error_file:
254+
error_file = error_file.replace("%j", jobid)
255+
self.error[jobid] = error_file.replace("%j", jobid)
249256
# intermittent polling
250257
while True:
251258
# 3 possibilities
@@ -273,12 +280,13 @@ async def _verify_exit_code(self, jobid):
273280
if not stdout:
274281
raise RuntimeError("Job information not found")
275282
m = self._sacct_re.search(stdout)
283+
error_file = self.error[jobid]
276284
if int(m.group("exit_code")) != 0 or m.group("status") != "COMPLETED":
277285
if m.group("status") in ["RUNNING", "PENDING"]:
278286
return False
279287
# TODO: potential for requeuing
280288
# parsing the error message
281-
error_line = Path(self.error).read_text().split("\n")[-2]
289+
error_line = Path(error_file).read_text().split("\n")[-2]
282290
if "Exception" in error_line:
283291
error_message = error_line.replace("Exception: ", "")
284292
elif "Error" in error_line:

0 commit comments

Comments
 (0)