Skip to content

Commit 7088769

Browse files
authored
Merge pull request #502 from ExaWorks/clean_queue_poll_threads
Fix thread "leak"
2 parents 7e7dcbe + 65a5c83 commit 7088769

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

src/psij/executors/batch/batch_scheduler_executor.py

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import subprocess
44
import time
55
import traceback
6+
import weakref
67
from abc import abstractmethod
78
from datetime import timedelta
89
from pathlib import Path
@@ -639,45 +640,51 @@ def __init__(self, name: str, config: BatchSchedulerExecutorConfig,
639640
self.name = name
640641
self.daemon = True
641642
self.config = config
642-
self.executor = executor
643+
self.done = False
644+
self.executor = weakref.ref(executor, self._stop)
643645
# native_id -> job
644646
self._jobs: Dict[str, List[Job]] = {}
645647
# counts consecutive errors while invoking qstat or equivalent
646648
self._poll_error_count = 0
647649
self._jobs_lock = RLock()
648650

649651
def run(self) -> None:
650-
logger.debug('Executor %s: queue poll thread started', self.executor)
652+
logger.debug('Executor %s: queue poll thread started', self.executor())
651653
time.sleep(self.config.initial_queue_polling_delay)
652-
while True:
654+
while not self.done:
653655
self._poll()
654656
time.sleep(self.config.queue_polling_interval)
655657

658+
def _stop(self, exec: object) -> None:
659+
self.done = True
660+
656661
def _poll(self) -> None:
662+
exec = self.executor()
663+
if exec is None:
664+
return
657665
with self._jobs_lock:
658666
if len(self._jobs) == 0:
659667
return
660668
jobs_copy = dict(self._jobs)
661669
logger.info('Polling for %s jobs', len(jobs_copy))
662670
try:
663-
out = self.executor._run_command(self.executor.get_status_command(jobs_copy.keys()))
671+
if exec:
672+
out = exec._run_command(exec.get_status_command(jobs_copy.keys()))
664673
except subprocess.CalledProcessError as ex:
665674
out = ex.output
666675
exit_code = ex.returncode
667676
except Exception as ex:
668-
self._handle_poll_error(True,
669-
ex,
677+
self._handle_poll_error(exec, True, ex,
670678
f'Failed to poll for job status: {traceback.format_exc()}')
671679
return
672680
else:
673681
exit_code = 0
674682
self._poll_error_count = 0
675683
logger.debug('Output from status command: %s', out)
676684
try:
677-
status_map = self.executor.parse_status_output(exit_code, out)
685+
status_map = exec.parse_status_output(exit_code, out)
678686
except Exception as ex:
679-
self._handle_poll_error(False,
680-
ex,
687+
self._handle_poll_error(exec, False, ex,
681688
f'Failed to poll for job status: {traceback.format_exc()}')
682689
return
683690
try:
@@ -689,21 +696,22 @@ def _poll(self) -> None:
689696
message='Failed to update job status: %s' %
690697
traceback.format_exc())
691698
for job in job_list:
692-
self.executor._set_job_status(job, status)
699+
exec._set_job_status(job, status)
693700
if status.state.final:
694701
with self._jobs_lock:
695702
del self._jobs[native_id]
696703
except Exception as ex:
697704
msg = traceback.format_exc()
698-
self._handle_poll_error(True, ex, 'Error updating job statuses {}'.format(msg))
705+
self._handle_poll_error(exec, True, ex, 'Error updating job statuses {}'.format(msg))
699706

700707
def _get_job_status(self, native_id: str, status_map: Dict[str, JobStatus]) -> JobStatus:
701708
if native_id in status_map:
702709
return status_map[native_id]
703710
else:
704711
return JobStatus(JobState.COMPLETED)
705712

706-
def _handle_poll_error(self, immediate: bool, ex: Exception, msg: str) -> None:
713+
def _handle_poll_error(self, exec: BatchSchedulerExecutor, immediate: bool, ex: Exception,
714+
msg: str) -> None:
707715
logger.warning('Polling error: %s', msg)
708716
self._poll_error_count += 1
709717
if immediate or (self._poll_error_count > self.config.queue_polling_error_threshold):
@@ -720,7 +728,7 @@ def _handle_poll_error(self, immediate: bool, ex: Exception, msg: str) -> None:
720728
self._jobs.clear()
721729
for job_list in jobs_copy.values():
722730
for job in job_list:
723-
self.executor._set_job_status(job, JobStatus(JobState.FAILED, message=msg))
731+
exec._set_job_status(job, JobStatus(JobState.FAILED, message=msg))
724732

725733
def register_job(self, job: Job) -> None:
726734
assert job.native_id

0 commit comments

Comments
 (0)