33import subprocess
44import time
55import traceback
6+ import weakref
67from abc import abstractmethod
78from datetime import timedelta
89from pathlib import Path
@@ -639,45 +640,51 @@ def __init__(self, name: str, config: BatchSchedulerExecutorConfig,
639640 self .name = name
640641 self .daemon = True
641642 self .config = config
642- self .executor = executor
643+ self .done = False
644+ self .executor = weakref .ref (executor , self ._stop )
643645 # native_id -> job
644646 self ._jobs : Dict [str , List [Job ]] = {}
645647 # counts consecutive errors while invoking qstat or equivalent
646648 self ._poll_error_count = 0
647649 self ._jobs_lock = RLock ()
648650
649651 def run (self ) -> None :
650- logger .debug ('Executor %s: queue poll thread started' , self .executor )
652+ logger .debug ('Executor %s: queue poll thread started' , self .executor () )
651653 time .sleep (self .config .initial_queue_polling_delay )
652- while True :
654+ while not self . done :
653655 self ._poll ()
654656 time .sleep (self .config .queue_polling_interval )
655657
658+ def _stop (self , exec : object ) -> None :
659+ self .done = True
660+
656661 def _poll (self ) -> None :
662+ exec = self .executor ()
663+ if exec is None :
664+ return
657665 with self ._jobs_lock :
658666 if len (self ._jobs ) == 0 :
659667 return
660668 jobs_copy = dict (self ._jobs )
661669 logger .info ('Polling for %s jobs' , len (jobs_copy ))
662670 try :
663- out = self .executor ._run_command (self .executor .get_status_command (jobs_copy .keys ()))
671+ if exec :
672+ out = exec ._run_command (exec .get_status_command (jobs_copy .keys ()))
664673 except subprocess .CalledProcessError as ex :
665674 out = ex .output
666675 exit_code = ex .returncode
667676 except Exception as ex :
668- self ._handle_poll_error (True ,
669- ex ,
677+ self ._handle_poll_error (exec , True , ex ,
670678 f'Failed to poll for job status: { traceback .format_exc ()} ' )
671679 return
672680 else :
673681 exit_code = 0
674682 self ._poll_error_count = 0
675683 logger .debug ('Output from status command: %s' , out )
676684 try :
677- status_map = self . executor .parse_status_output (exit_code , out )
685+ status_map = exec .parse_status_output (exit_code , out )
678686 except Exception as ex :
679- self ._handle_poll_error (False ,
680- ex ,
687+ self ._handle_poll_error (exec , False , ex ,
681688 f'Failed to poll for job status: { traceback .format_exc ()} ' )
682689 return
683690 try :
@@ -689,21 +696,22 @@ def _poll(self) -> None:
689696 message = 'Failed to update job status: %s' %
690697 traceback .format_exc ())
691698 for job in job_list :
692- self . executor ._set_job_status (job , status )
699+ exec ._set_job_status (job , status )
693700 if status .state .final :
694701 with self ._jobs_lock :
695702 del self ._jobs [native_id ]
696703 except Exception as ex :
697704 msg = traceback .format_exc ()
698- self ._handle_poll_error (True , ex , 'Error updating job statuses {}' .format (msg ))
705+ self ._handle_poll_error (exec , True , ex , 'Error updating job statuses {}' .format (msg ))
699706
700707 def _get_job_status (self , native_id : str , status_map : Dict [str , JobStatus ]) -> JobStatus :
701708 if native_id in status_map :
702709 return status_map [native_id ]
703710 else :
704711 return JobStatus (JobState .COMPLETED )
705712
706- def _handle_poll_error (self , immediate : bool , ex : Exception , msg : str ) -> None :
713+ def _handle_poll_error (self , exec : BatchSchedulerExecutor , immediate : bool , ex : Exception ,
714+ msg : str ) -> None :
707715 logger .warning ('Polling error: %s' , msg )
708716 self ._poll_error_count += 1
709717 if immediate or (self ._poll_error_count > self .config .queue_polling_error_threshold ):
@@ -720,7 +728,7 @@ def _handle_poll_error(self, immediate: bool, ex: Exception, msg: str) -> None:
720728 self ._jobs .clear ()
721729 for job_list in jobs_copy .values ():
722730 for job in job_list :
723- self . executor ._set_job_status (job , JobStatus (JobState .FAILED , message = msg ))
731+ exec ._set_job_status (job , JobStatus (JobState .FAILED , message = msg ))
724732
725733 def register_job (self , job : Job ) -> None :
726734 assert job .native_id
0 commit comments