Skip to content

Commit a914987

Browse files
committed
fix:registries do not keep connections
1 parent a5aa239 commit a914987

24 files changed

+135
-148
lines changed

scheduler/helpers/queues/getters.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,6 @@ def _get_connection(config: QueueConfiguration, use_strict_broker: bool = False)
4646
)
4747

4848

49-
def refresh_queue_connection(queue: Queue) -> None:
50-
"""Refreshes the connection of a given Queue"""
51-
queue_settings = get_queue_configuration(queue.name)
52-
connection = _get_connection(queue_settings)
53-
queue.refresh_connection(connection)
54-
55-
5649
def get_queue(name: str = "default") -> Queue:
5750
"""Returns an DjangoQueue using parameters defined in `SCHEDULER_QUEUES`"""
5851
queue_settings = get_queue_configuration(name)

scheduler/helpers/queues/queue_logic.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -82,21 +82,12 @@ def __init__(self, connection: ConnectionType, name: str, is_async: bool = True)
8282
self.connection: ConnectionType = connection
8383
self.name = name
8484
self._is_async = is_async
85-
self.queued_job_registry = QueuedJobRegistry(connection=self.connection, name=self.name)
86-
self.active_job_registry = ActiveJobRegistry(connection=self.connection, name=self.name)
87-
self.failed_job_registry = FailedJobRegistry(connection=self.connection, name=self.name)
88-
self.finished_job_registry = FinishedJobRegistry(connection=self.connection, name=self.name)
89-
self.scheduled_job_registry = ScheduledJobRegistry(connection=self.connection, name=self.name)
90-
self.canceled_job_registry = CanceledJobRegistry(connection=self.connection, name=self.name)
91-
92-
def refresh_connection(self, connection: ConnectionType) -> None:
93-
self.connection = connection
94-
self.queued_job_registry.connection = connection
95-
self.active_job_registry.connection = connection
96-
self.failed_job_registry.connection = connection
97-
self.finished_job_registry.connection = connection
98-
self.scheduled_job_registry.connection = connection
99-
self.canceled_job_registry.connection = connection
85+
self.queued_job_registry = QueuedJobRegistry(name=self.name)
86+
self.active_job_registry = ActiveJobRegistry(name=self.name)
87+
self.failed_job_registry = FailedJobRegistry(name=self.name)
88+
self.finished_job_registry = FinishedJobRegistry(name=self.name)
89+
self.scheduled_job_registry = ScheduledJobRegistry(name=self.name)
90+
self.canceled_job_registry = CanceledJobRegistry(name=self.name)
10091

10192
def __len__(self) -> int:
10293
return self.count
@@ -114,7 +105,7 @@ def clean_registries(self, timestamp: Optional[float] = None) -> None:
114105
Removed jobs are added to the global failed job queue.
115106
"""
116107
before_score = timestamp or current_timestamp()
117-
self.queued_job_registry.compact()
108+
self.queued_job_registry.compact(self.connection)
118109
started_jobs: List[Tuple[str, float]] = self.active_job_registry.get_job_names_before(
119110
self.connection, before_score
120111
)
@@ -142,7 +133,7 @@ def clean_registries(self, timestamp: Optional[float] = None) -> None:
142133
getattr(self, registry).cleanup(connection=self.connection, timestamp=before_score)
143134

144135
def first_queued_job_name(self) -> Optional[str]:
145-
return self.queued_job_registry.get_first()
136+
return self.queued_job_registry.get_first(self.connection)
146137

147138
@property
148139
def count(self) -> int:
@@ -160,12 +151,12 @@ def get_registry(self, name: str) -> JobNamesRegistry:
160151

161152
def get_all_job_names(self) -> List[str]:
162153
all_job_names = list()
163-
all_job_names.extend(self.queued_job_registry.all())
164-
all_job_names.extend(self.finished_job_registry.all())
165-
all_job_names.extend(self.active_job_registry.all())
166-
all_job_names.extend(self.failed_job_registry.all())
167-
all_job_names.extend(self.scheduled_job_registry.all())
168-
all_job_names.extend(self.canceled_job_registry.all())
154+
all_job_names.extend(self.queued_job_registry.all(self.connection))
155+
all_job_names.extend(self.finished_job_registry.all(self.connection))
156+
all_job_names.extend(self.active_job_registry.all(self.connection))
157+
all_job_names.extend(self.failed_job_registry.all(self.connection))
158+
all_job_names.extend(self.scheduled_job_registry.all(self.connection))
159+
all_job_names.extend(self.canceled_job_registry.all(self.connection))
169160
res = list(filter(lambda job_name: JobModel.exists(job_name, self.connection), all_job_names))
170161
return res
171162

@@ -307,7 +298,7 @@ def dequeue_any(
307298
while True:
308299
registries = [q.queued_job_registry for q in queues]
309300
for registry in registries:
310-
registry.compact()
301+
registry.compact(connection)
311302

312303
registry_key, job_name = QueuedJobRegistry.pop(connection, registries, timeout)
313304
if job_name is None:
@@ -416,7 +407,7 @@ def enqueue_job(
416407
if at_front:
417408
score = current_timestamp()
418409
else:
419-
score = self.queued_job_registry.get_last_timestamp() or current_timestamp()
410+
score = self.queued_job_registry.get_last_timestamp(self.connection) or current_timestamp()
420411
self.scheduled_job_registry.delete(connection=pipe, job_name=job_model.name)
421412
self.queued_job_registry.add(connection=pipe, score=score, job_name=job_model.name)
422413
pipe.execute()

scheduler/management/commands/delete_failed_executions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def add_arguments(self, parser: CommandParser) -> None:
1616

1717
def handle(self, *args, **options):
1818
queue = get_queue(options.get("queue", "default"))
19-
job_names = queue.failed_job_registry.all()
19+
job_names = queue.failed_job_registry.all(queue.connection)
2020
jobs = JobModel.get_many(job_names, connection=queue.connection)
2121
func_name = options.get("func", None)
2222
if func_name is not None:

scheduler/models/task.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -186,12 +186,14 @@ def is_scheduled(self) -> bool:
186186
if self.job_name is None: # no job_id => is not scheduled
187187
return False
188188
# check whether job_id is in scheduled/queued/active jobs
189-
res = (
190-
(self.job_name in self.rqueue.scheduled_job_registry.all())
191-
or (self.job_name in self.rqueue.queued_job_registry.all())
192-
or (self.job_name in self.rqueue.active_job_registry.all())
193-
)
194-
# If the job_id is not scheduled/queued/started,
189+
with self.rqueue.connection.pipeline() as pipeline:
190+
self.rqueue.scheduled_job_registry.exists(pipeline, self.job_name)
191+
self.rqueue.queued_job_registry.exists(pipeline, self.job_name)
192+
self.rqueue.active_job_registry.exists(pipeline, self.job_name)
193+
results = pipeline.execute()
194+
res = any([item is not None for item in results])
195+
196+
# If the job_name is not scheduled/queued/started,
195197
# update the job_id to None. (The job_id belongs to a previous run which is completed)
196198
if not res:
197199
self.job_name = None

scheduler/redis_models/registry/base_registry.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,54 +32,52 @@ def delete(self, connection: ConnectionType, job_name: str) -> None:
3232
logger.debug(f"[registry {self._key}] Deleting {job_name}")
3333
connection.zrem(self._key, job_name)
3434

35+
def exists(self, connection: ConnectionType, job_name: str) -> bool:
36+
return connection.zrank(self._key, job_name) is not None
37+
3538

3639
class JobNamesRegistry(ZSetModel):
3740
_element_key_template: ClassVar[str] = ":registry:{}"
3841

39-
def __init__(self, connection: ConnectionType, name: str) -> None:
42+
def __init__(self, name: str) -> None:
4043
super().__init__(name=name)
41-
self.connection = connection
42-
43-
def __len__(self) -> int:
44-
return self.count(self.connection)
45-
46-
def __contains__(self, item: str) -> bool:
47-
return self.connection.zrank(self._key, item) is not None
4844

49-
def all(self, start: int = 0, end: int = -1) -> List[str]:
45+
def all(self, connection: ConnectionType, start: int = 0, end: int = -1) -> List[str]:
5046
"""Returns a list of all job names.
5147
48+
:param connection: Broker connection
5249
:param start: Start score/timestamp, default to 0.
5350
:param end: End score/timestamp, default to -1 (i.e., no max score).
5451
:returns: Returns a list of all job names with timestamp from start to end
5552
"""
56-
self.cleanup(self.connection)
57-
res = [as_str(job_name) for job_name in self.connection.zrange(self._key, start, end)]
58-
logger.debug(f"Getting jobs for registry {self._key}: {len(res)} found.")
53+
self.cleanup(connection)
54+
res = [as_str(job_name) for job_name in connection.zrange(self._key, start, end)]
55+
logger.debug(f"Getting jobs for registry {self.key}: {len(res)} found.")
5956
return res
6057

61-
def all_with_timestamps(self, start: int = 0, end: int = -1) -> List[Tuple[str, float]]:
58+
def all_with_timestamps(self, connection: ConnectionType, start: int = 0, end: int = -1) -> List[Tuple[str, float]]:
6259
"""Returns a list of all job names with their timestamps.
6360
61+
:param connection: Broker connection
6462
:param start: Start score/timestamp, default to 0.
6563
:param end: End score/timestamp, default to -1 (i.e., no max score).
6664
:returns: Returns a list of all job names with timestamp from start to end
6765
"""
68-
self.cleanup(self.connection)
69-
res = self.connection.zrange(self._key, start, end, withscores=True)
66+
self.cleanup(connection)
67+
res = connection.zrange(self._key, start, end, withscores=True)
7068
logger.debug(f"Getting jobs for registry {self._key}: {len(res)} found.")
7169
return [(as_str(job_name), timestamp) for job_name, timestamp in res]
7270

73-
def get_first(self) -> Optional[str]:
71+
def get_first(self, connection: ConnectionType) -> Optional[str]:
7472
"""Returns the first job in the registry."""
75-
self.cleanup(self.connection)
76-
first_job = self.connection.zrange(self._key, 0, 0)
73+
self.cleanup(connection)
74+
first_job = connection.zrange(self._key, 0, 0)
7775
return first_job[0].decode() if first_job else None
7876

79-
def get_last_timestamp(self) -> Optional[int]:
77+
def get_last_timestamp(self, connection: ConnectionType) -> Optional[int]:
8078
"""Returns the latest timestamp in the registry."""
81-
self.cleanup(self.connection)
82-
last_timestamp = self.connection.zrange(self._key, -1, -1, withscores=True)
79+
self.cleanup(connection)
80+
last_timestamp = connection.zrange(self._key, -1, -1, withscores=True)
8381
return int(last_timestamp[0][1]) if last_timestamp else None
8482

8583
@property
@@ -88,7 +86,7 @@ def key(self) -> str:
8886

8987
@classmethod
9088
def pop(
91-
cls, connection: ConnectionType, registries: Sequence[Self], timeout: Optional[int]
89+
cls, connection: ConnectionType, registries: Sequence[Self], timeout: Optional[int]
9290
) -> Tuple[Optional[str], Optional[str]]:
9391
"""Helper method to abstract away from some Redis API details
9492

scheduler/redis_models/registry/queue_registries.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@ def cleanup(self, connection: ConnectionType, timestamp: Optional[float] = None)
1717
and `all()` methods implemented in JobIdsRegistry."""
1818
pass
1919

20-
def compact(self) -> None:
20+
def compact(self, connection: ConnectionType) -> None:
2121
"""Removes all "dead" jobs from the queue by cycling through it, while guaranteeing FIFO semantics."""
22-
jobs_with_ts = self.all_with_timestamps()
22+
jobs_with_ts = self.all_with_timestamps(connection)
2323
for job_name, timestamp in jobs_with_ts:
24-
if not JobModel.exists(job_name, self.connection):
25-
self.delete(connection=self.connection, job_name=job_name)
24+
if not JobModel.exists(job_name, connection):
25+
self.delete(connection=connection, job_name=job_name)
2626

27-
def empty(self) -> None:
28-
queued_jobs_count = self.count(connection=self.connection)
29-
with self.connection.pipeline() as pipe:
27+
def empty(self, connection: ConnectionType) -> None:
28+
queued_jobs_count = self.count(connection=connection)
29+
with connection.pipeline() as pipe:
3030
for offset in range(0, queued_jobs_count, 1000):
31-
job_names = self.all(offset, 1000)
31+
job_names = self.all(connection, offset, 1000)
3232
for job_name in job_names:
3333
self.delete(connection=pipe, job_name=job_name)
3434
JobModel.delete_many(job_names, connection=pipe)
@@ -76,24 +76,26 @@ def schedule(self, connection: ConnectionType, job_name: str, scheduled_datetime
7676
timestamp = scheduled_datetime.timestamp()
7777
return self.add(connection=connection, job_name=job_name, score=timestamp)
7878

79-
def get_jobs_to_schedule(self, timestamp: int, chunk_size: int = 1000) -> List[str]:
79+
def get_jobs_to_schedule(self, connection: ConnectionType, timestamp: int, chunk_size: int = 1000) -> List[str]:
8080
"""Gets a list of job names that should be scheduled.
8181
82+
:param connection: Broker connection
8283
:param timestamp: timestamp/score of jobs in SortedSet.
8384
:param chunk_size: Max results to return.
8485
:returns: A list of job names
8586
"""
86-
jobs_to_schedule = self.connection.zrangebyscore(self._key, 0, max=timestamp, start=0, num=chunk_size)
87+
jobs_to_schedule = connection.zrangebyscore(self._key, 0, max=timestamp, start=0, num=chunk_size)
8788
return [as_str(job_name) for job_name in jobs_to_schedule]
8889

89-
def get_scheduled_time(self, job_name: str) -> Optional[datetime]:
90+
def get_scheduled_time(self, connection: ConnectionType, job_name: str) -> Optional[datetime]:
9091
"""Returns datetime (UTC) at which job is scheduled to be enqueued
9192
93+
:param connection: Broker connection
9294
:param job_name: Job name
9395
:returns: The scheduled time as datetime object, or None if job is not found
9496
"""
9597

96-
score: Optional[float] = self.connection.zscore(self._key, job_name)
98+
score: Optional[float] = connection.zscore(self._key, job_name)
9799
if not score:
98100
return None
99101

scheduler/templatetags/scheduler_tags.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,4 @@ def job_runtime(job: JobModel):
6868

6969
@register.filter
7070
def job_scheduled_time(job: JobModel, queue: Queue):
71-
return queue.scheduled_job_registry.get_scheduled_time(job.name)
71+
return queue.scheduled_job_registry.get_scheduled_time(queue.connection, job.name)

scheduler/tests/test_job_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def test_job_decorator_different_queue(self):
100100

101101
def _assert_job_with_func_and_props(self, queue_name, expected_func, expected_result_ttl, expected_timeout):
102102
queue = get_queue(queue_name)
103-
jobs = JobModel.get_many(queue.queued_job_registry.all(), queue.connection)
103+
jobs = JobModel.get_many(queue.queued_job_registry.all(queue.connection), queue.connection)
104104
self.assertEqual(1, len(jobs))
105105

106106
j = jobs[0]

scheduler/tests/test_mgmt_commands/test_delete_failed_executions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ def test_delete_failed_executions__delete_jobs(self):
1212
queue = get_queue("default")
1313
call_command("delete_failed_executions", queue="default")
1414
queue.create_and_enqueue_job(failing_job)
15-
self.assertEqual(1, len(queue.queued_job_registry))
15+
self.assertEqual(1, queue.queued_job_registry.count(queue.connection))
1616
worker = create_worker("default", burst=True)
1717
worker.work()
18-
self.assertEqual(1, len(queue.failed_job_registry))
18+
self.assertEqual(1, queue.failed_job_registry.count(queue.connection))
1919
call_command("delete_failed_executions", queue="default")
20-
self.assertEqual(0, len(queue.failed_job_registry))
20+
self.assertEqual(0, queue.failed_job_registry.count(queue.connection))

scheduler/tests/test_mgmt_commands/test_run_job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
class RunJobTest(TestCase):
1111
def test_run_job__should_schedule_job(self):
1212
queue = get_queue("default")
13-
queue.queued_job_registry.empty()
13+
queue.queued_job_registry.empty(queue.connection)
1414
func_name = f"{test_job.__module__}.{test_job.__name__}"
1515
# act
1616
call_command("run_job", func_name, queue="default")
1717
# assert
18-
job_list = JobModel.get_many(queue.queued_job_registry.all(), queue.connection)
18+
job_list = JobModel.get_many(queue.queued_job_registry.all(queue.connection), queue.connection)
1919
self.assertEqual(1, len(job_list))
2020
self.assertEqual(func_name + "()", job_list[0].get_call_string())

0 commit comments

Comments
 (0)