@@ -214,6 +214,7 @@ def __init__(
            poll_delay = 0
        self.poll_delay = poll_delay
        self.sbatch_args = sbatch_args or ""
+        self.error = {}

    def run_el(self, runnable, rerun=False):
        """Worker submission API."""
@@ -224,28 +225,34 @@ def run_el(self, runnable, rerun=False):

    async def _submit_job(self, task, batchscript):
        """Coroutine that submits task runscript and polls job until completion or error."""
+        script_dir = (
+            task.cache_dir / f"{self.__class__.__name__}_scripts" / task.checksum
+        )
        sargs = self.sbatch_args.split()
        jobname = re.search(r"(?<=-J )\S+|(?<=--job-name=)\S+", self.sbatch_args)
        if not jobname:
            jobname = ".".join((task.name, task.checksum))
            sargs.append(f"--job-name={jobname}")
        output = re.search(r"(?<=-o )\S+|(?<=--output=)\S+", self.sbatch_args)
        if not output:
-            self.output = str(batchscript.parent / "slurm-%j.out")
-            sargs.append(f"--output={self.output}")
+            output_file = str(script_dir / "slurm-%j.out")
+            sargs.append(f"--output={output_file}")
        error = re.search(r"(?<=-e )\S+|(?<=--error=)\S+", self.sbatch_args)
        if not error:
-            self.error = str(batchscript.parent / "slurm-%j.err")
-            sargs.append(f"--error={self.error}")
+            error_file = str(script_dir / "slurm-%j.err")
+            sargs.append(f"--error={error_file}")
+        else:
+            error_file = None
        sargs.append(str(batchscript))
        # TO CONSIDER: add random sleep to avoid overloading calls
        _, stdout, _ = await read_and_display_async("sbatch", *sargs, hide_display=True)
        jobid = re.search(r"\d+", stdout)
        if not jobid:
            raise RuntimeError("Could not extract job ID")
        jobid = jobid.group()
-        self.output = self.output.replace("%j", jobid)
-        self.error = self.error.replace("%j", jobid)
+        if error_file:
+            error_file = error_file.replace("%j", jobid)
+        self.error[jobid] = error_file
        # intermittent polling
        while True:
            # 3 possibilities
@@ -273,12 +280,13 @@ async def _verify_exit_code(self, jobid):
        if not stdout:
            raise RuntimeError("Job information not found")
        m = self._sacct_re.search(stdout)
+        error_file = self.error[jobid]
        if int(m.group("exit_code")) != 0 or m.group("status") != "COMPLETED":
            if m.group("status") in ["RUNNING", "PENDING"]:
                return False
            # TODO: potential for requeuing
            # parsing the error message
-            error_line = Path(self.error).read_text().split("\n")[-2]
+            error_line = Path(error_file).read_text().split("\n")[-2]
            if "Exception" in error_line:
                error_message = error_line.replace("Exception: ", "")
            elif "Error" in error_line: