Skip to content

Commit 7bff43b

Browse files
committed
wip
1 parent 30dbd7b commit 7bff43b

File tree

6 files changed

+45
-43
lines changed

6 files changed

+45
-43
lines changed

scheduler/worker/commands/kill_worker.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
1919

2020
def process_command(self, connection: ConnectionType) -> None:
2121
from scheduler.worker import Worker
22-
if self.worker_name is None:
23-
raise ValueError("Worker name must be provided")
2422
logger.info(f"Received kill-worker command for {self.worker_name}")
2523
worker_model = WorkerModel.get(self.worker_name, connection)
2624
if worker_model is None or worker_model.pid is None:

scheduler/worker/commands/stop_job.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import os
22
import signal
3-
from typing import Dict, Any, Optional
3+
from typing import Dict, Any
44

5-
from scheduler.types import ConnectionType
65
from scheduler.redis_models import WorkerModel, JobModel
76
from scheduler.settings import logger
7+
from scheduler.types import ConnectionType
88
from scheduler.worker.commands.worker_commands import WorkerCommand, WorkerCommandError
99

1010

@@ -13,13 +13,13 @@ class StopJobCommand(WorkerCommand):
1313

1414
command_name = "stop-job"
1515

16-
def __init__(self, *args, job_name: str, worker_name: Optional[str], **kwargs) -> None:
16+
def __init__(self, *args: Any, job_name: str, worker_name: str, **kwargs: Any) -> None:
1717
super().__init__(*args, worker_name=worker_name, **kwargs)
1818
self.job_name = job_name
1919
if self.job_name is None:
2020
raise WorkerCommandError("job_name for kill-job command is required")
2121

22-
def command_payload(self, **kwargs) -> Dict[str, Any]:
22+
def command_payload(self, **kwargs: Any) -> Dict[str, Any]:
2323
return super().command_payload(job_name=self.job_name)
2424

2525
def process_command(self, connection: ConnectionType) -> None:

scheduler/worker/commands/suspend_worker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from scheduler.types import ConnectionType
21
from scheduler.redis_models import WorkerModel
32
from scheduler.settings import logger
3+
from scheduler.types import ConnectionType
44
from scheduler.worker.commands.worker_commands import WorkerCommand
55

66

@@ -32,6 +32,7 @@ def process_command(self, connection: ConnectionType) -> None:
3232
worker_model = WorkerModel.get(self.worker_name, connection)
3333
if worker_model is None:
3434
logger.warning(f"Worker {self.worker_name} not found")
35+
return
3536
if not worker_model.is_suspended:
3637
logger.warning(f"Worker {self.worker_name} not suspended and therefore can't be resumed")
3738
return

scheduler/worker/commands/worker_commands.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import json
22
from abc import ABC
33
from datetime import datetime, timezone
4-
from typing import Type, Dict, Any, Optional
4+
from typing import Type, Dict, Any
55

66
from scheduler.settings import logger
77
from scheduler.types import ConnectionType, Self
88

99
_PUBSUB_CHANNEL_TEMPLATE: str = ":workers:pubsub:{}"
10+
_WORKER_COMMANDS_REGISTRY: Dict[str, Type["WorkerCommand"]] = dict()
1011

1112

1213
class WorkerCommandError(Exception):
@@ -16,13 +17,12 @@ class WorkerCommandError(Exception):
1617
class WorkerCommand(ABC):
1718
"""Abstract class for commands to be sent to a worker and processed by worker"""
1819

19-
_registry: Dict[str, Type[Self]] = dict()
2020
command_name: str = ""
2121

22-
def __init__(self, *args, worker_name: Optional[str], **kwargs) -> None:
23-
self.worker_name: Optional[str] = worker_name
22+
def __init__(self, *args: Any, worker_name: str, **kwargs: Any) -> None:
23+
self.worker_name: str = worker_name
2424

25-
def command_payload(self, **kwargs) -> Dict[str, Any]:
25+
def command_payload(self, **kwargs: Any) -> Dict[str, Any]:
2626
commands_channel = WorkerCommandsChannelListener._commands_channel(self.worker_name)
2727
payload = {
2828
"command": self.command_name,
@@ -41,17 +41,19 @@ def process_command(self, connection: ConnectionType) -> None:
4141
raise NotImplementedError
4242

4343
@classmethod
44-
def __init_subclass__(cls, *args, **kwargs):
44+
def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
4545
if cls is WorkerCommand:
4646
return
4747
if not cls.command_name:
48-
raise NotImplementedError(f"{cls.__name__} must have a name attribute")
49-
WorkerCommand._registry[cls.command_name] = cls
48+
raise NotImplementedError(f"{cls.__name__} must have a command_name attribute")
49+
_WORKER_COMMANDS_REGISTRY[cls.command_name] = cls
5050

5151
@classmethod
5252
def from_payload(cls, payload: Dict[str, Any]) -> Type[Self]:
5353
command_name = payload.get("command")
54-
command_class = WorkerCommand._registry.get(command_name)
54+
if command_name is None:
55+
raise WorkerCommandError("Payload must contain 'command' key")
56+
command_class = _WORKER_COMMANDS_REGISTRY.get(command_name)
5557
if command_class is None:
5658
raise WorkerCommandError(f"Invalid command: {command_name}")
5759
return command_class(**payload)

scheduler/worker/scheduler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ def __init__(
4141
interval = interval or SCHEDULER_CONFIG.SCHEDULER_INTERVAL
4242
self._queues = queues
4343
self._scheduled_job_registries: List[ScheduledJobRegistry] = []
44-
self.lock_acquisition_time = None
44+
self.lock_acquisition_time: Optional[datetime] = None
4545
self._pool_class = connection.connection_pool.connection_class
4646
self._pool_kwargs = connection.connection_pool.connection_kwargs.copy()
4747
self._locks: Dict[str, SchedulerLock] = dict()
4848
self.connection = connection
4949
self.interval = interval
5050
self._stop_requested = False
5151
self.status = SchedulerStatus.STOPPED
52-
self._thread = None
52+
self._thread: Optional[Thread] = None
5353
self._pid: Optional[int] = None
5454
self.worker_name = worker_name
5555

@@ -163,7 +163,7 @@ def enqueue_scheduled_jobs(self) -> None:
163163
self.status = SchedulerStatus.STARTED
164164

165165

166-
def run_scheduler(scheduler: WorkerScheduler):
166+
def run_scheduler(scheduler: WorkerScheduler) -> None:
167167
try:
168168
scheduler.work()
169169
except: # noqa

scheduler/worker/worker.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from enum import Enum
1515
from random import shuffle
1616
from types import FrameType
17-
from typing import List, Optional, Tuple, Any, Iterable, Collection
17+
from typing import List, Optional, Tuple, Any, Iterable, Collection, Union
1818

1919
import scheduler
2020
from scheduler.helpers.queues import get_queue
@@ -63,7 +63,7 @@ class QueueConnectionDiscrepancyError(Exception):
6363
)
6464

6565

66-
def signal_name(signum) -> str:
66+
def signal_name(signum: int) -> str:
6767
try:
6868
return signal.Signals(signum).name
6969
except KeyError:
@@ -99,19 +99,19 @@ def from_model(cls, model: WorkerModel) -> Self:
9999
return res
100100

101101
def __init__(
102-
self,
103-
queues,
104-
name: str,
105-
connection: ConnectionType,
106-
maintenance_interval: int = SCHEDULER_CONFIG.DEFAULT_MAINTENANCE_TASK_INTERVAL,
107-
job_monitoring_interval=SCHEDULER_CONFIG.DEFAULT_JOB_MONITORING_INTERVAL,
108-
dequeue_strategy: DequeueStrategy = DequeueStrategy.DEFAULT,
109-
disable_default_exception_handler: bool = False,
110-
fork_job_execution: bool = True,
111-
with_scheduler: bool = True,
112-
burst: bool = False,
113-
model: Optional[WorkerModel] = None,
114-
): # noqa
102+
self,
103+
queues: Iterable[Union[str, Queue]],
104+
name: str,
105+
connection: ConnectionType,
106+
maintenance_interval: int = SCHEDULER_CONFIG.DEFAULT_MAINTENANCE_TASK_INTERVAL,
107+
job_monitoring_interval: int = SCHEDULER_CONFIG.DEFAULT_JOB_MONITORING_INTERVAL,
108+
dequeue_strategy: DequeueStrategy = DequeueStrategy.DEFAULT,
109+
disable_default_exception_handler: bool = False,
110+
fork_job_execution: bool = True,
111+
with_scheduler: bool = True,
112+
burst: bool = False,
113+
model: Optional[WorkerModel] = None,
114+
) -> None:
115115
self.fork_job_execution = fork_job_execution
116116
self.job_monitoring_interval: int = job_monitoring_interval
117117
self.maintenance_interval = maintenance_interval
@@ -232,7 +232,7 @@ def work(self, max_jobs: Optional[int] = None, max_idle_time: Optional[int] = No
232232

233233
timeout = None if self.burst else (SCHEDULER_CONFIG.DEFAULT_WORKER_TTL - 15)
234234
job, queue = self.dequeue_job_and_maintain_ttl(timeout, max_idle_time)
235-
if job is None:
235+
if job is None or queue is None:
236236
if self.burst:
237237
logger.info(f"[Worker {self.name}/{self._pid}]: done, quitting")
238238
break
@@ -267,7 +267,7 @@ def work(self, max_jobs: Optional[int] = None, max_idle_time: Optional[int] = No
267267
self.teardown()
268268
return False
269269

270-
def handle_job_failure(self, job: JobModel, queue: Queue, exc_string="") -> None:
270+
def handle_job_failure(self, job: JobModel, queue: Queue, exc_string: str = "") -> None:
271271
"""
272272
Handles the failure of an executing job by:
273273
1. Setting the job status to failed
@@ -312,7 +312,7 @@ def handle_job_failure(self, job: JobModel, queue: Queue, exc_string="") -> None
312312
# Ensure that custom exception handlers are called even if the Broker is down
313313
pass
314314

315-
def bootstrap(self)-> None:
315+
def bootstrap(self) -> None:
316316
"""Bootstraps the worker.
317317
Runs the basic tasks that should run when the worker actually starts working.
318318
Used so that new workers can focus on the work loop implementation rather
@@ -327,7 +327,8 @@ def bootstrap(self)-> None:
327327
self._model.has_scheduler = True
328328
self._model.save(connection=self.connection)
329329
if self.with_scheduler and self.burst:
330-
self.scheduler.request_stop_and_wait()
330+
if self.scheduler is not None:
331+
self.scheduler.request_stop_and_wait()
331332
self._model.has_scheduler = False
332333
self._model.save(connection=self.connection)
333334
qnames = [queue.name for queue in self.queues]
@@ -375,8 +376,8 @@ def run_maintenance_tasks(self) -> None:
375376
self._model.save(connection=self.connection)
376377

377378
def dequeue_job_and_maintain_ttl(
378-
self, timeout: Optional[int], max_idle_time: Optional[int] = None
379-
) -> Tuple[JobModel, Queue]:
379+
self, timeout: Optional[int], max_idle_time: Optional[int] = None
380+
) -> Tuple[Optional[JobModel], Optional[Queue]]:
380381
"""Dequeues a job while maintaining the TTL.
381382
:param timeout: The timeout for the dequeue operation.
382383
:param max_idle_time: The maximum idle time for the worker.
@@ -550,7 +551,7 @@ def reorder_queues(self, reference_queue: Queue) -> None:
550551
return
551552
if self._dequeue_strategy == DequeueStrategy.ROUND_ROBIN:
552553
pos = self._ordered_queues.index(reference_queue)
553-
self._ordered_queues = self._ordered_queues[pos + 1 :] + self._ordered_queues[: pos + 1]
554+
self._ordered_queues = self._ordered_queues[pos + 1:] + self._ordered_queues[: pos + 1]
554555
return
555556
if self._dequeue_strategy == DequeueStrategy.RANDOM:
556557
shuffle(self._ordered_queues)
@@ -634,7 +635,7 @@ def monitor_job_execution_process(self, job: JobModel, queue: Queue) -> None:
634635
while True:
635636
try:
636637
with SCHEDULER_CONFIG.DEATH_PENALTY_CLASS(
637-
self.job_monitoring_interval, JobExecutionMonitorTimeoutException
638+
self.job_monitoring_interval, JobExecutionMonitorTimeoutException
638639
):
639640
retpid, ret_val = self.wait_for_job_execution_process()
640641
break
@@ -877,7 +878,7 @@ class RoundRobinWorker(Worker):
877878

878879
def reorder_queues(self, reference_queue: Queue) -> None:
879880
pos = self._ordered_queues.index(reference_queue)
880-
self._ordered_queues = self._ordered_queues[pos + 1 :] + self._ordered_queues[: pos + 1]
881+
self._ordered_queues = self._ordered_queues[pos + 1:] + self._ordered_queues[: pos + 1]
881882

882883

883884
class RandomWorker(Worker):

0 commit comments

Comments
 (0)