Skip to content

Commit 81d9b45

Browse files
committed
feat:add --worker-class kwarg to rqworker
1 parent ad064ab commit 81d9b45

File tree

6 files changed

+107
-13
lines changed

6 files changed

+107
-13
lines changed

scheduler/management/commands/rqworker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def add_arguments(self, parser):
6969
help="Fork job execution to another process",
7070
)
7171
parser.add_argument("--job-class", action="store", dest="job_class", help="Jobs class to use")
72+
parser.add_argument('--worker-class', action='store', dest='worker_class', help='RQ Worker class to use')
7273
parser.add_argument(
7374
"queues",
7475
nargs="*",
@@ -102,6 +103,7 @@ def handle(self, **options):
102103
*queues,
103104
name=options["name"],
104105
job_class=options.get("job_class"),
106+
worker_class=options.get("worker_class"),
105107
default_worker_ttl=options["worker_ttl"],
106108
fork_job_execution=options["fork_job_execution"],
107109
)

scheduler/queues.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def get_queues(*queue_names, **kwargs) -> List[DjangoQueue]:
135135
"""
136136
from .settings import QUEUES
137137

138-
kwargs["job_class"] = JobExecution
138+
kwargs["job_class"] = kwargs.get("job_class") or JobExecution
139139
queue_params = QUEUES[queue_names[0]]
140140
queues = [get_queue(queue_names[0], **kwargs)]
141141
# perform consistency checks while building return list

scheduler/rq_classes.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(self, *args, **kwargs):
9090
# Update kwargs with the potentially modified job_class
9191
kwargs["job_class"] = job_class
9292
kwargs["queue_class"] = DjangoQueue
93+
kwargs.pop("worker_class", None)
9394
super(DjangoWorker, self).__init__(*args, **kwargs)
9495

9596
def __eq__(self, other):
@@ -184,7 +185,8 @@ class DjangoQueue(Queue):
184185
"""
185186

186187
def __init__(self, *args, **kwargs):
187-
kwargs["job_class"] = JobExecution
188+
kwargs["job_class"] = kwargs.get("job_class") or JobExecution
189+
self.job_class = kwargs["job_class"]
188190
super(DjangoQueue, self).__init__(*args, **kwargs)
189191

190192
def get_registry(self, name: str) -> Union[None, BaseRegistry, "DjangoQueue"]:
@@ -204,39 +206,39 @@ def started_job_registry(self):
204206
return StartedJobRegistry(
205207
self.name,
206208
self.connection,
207-
job_class=JobExecution,
209+
job_class=self.job_class,
208210
)
209211

210212
@property
211213
def deferred_job_registry(self):
212214
return DeferredJobRegistry(
213215
self.name,
214216
self.connection,
215-
job_class=JobExecution,
217+
job_class=self.job_class,
216218
)
217219

218220
@property
219221
def failed_job_registry(self):
220222
return FailedJobRegistry(
221223
self.name,
222224
self.connection,
223-
job_class=JobExecution,
225+
job_class=self.job_class,
224226
)
225227

226228
@property
227229
def scheduled_job_registry(self):
228230
return ScheduledJobRegistry(
229231
self.name,
230232
self.connection,
231-
job_class=JobExecution,
233+
job_class=self.job_class,
232234
)
233235

234236
@property
235237
def canceled_job_registry(self):
236238
return CanceledJobRegistry(
237239
self.name,
238240
self.connection,
239-
job_class=JobExecution,
241+
job_class=self.job_class,
240242
)
241243

242244
def get_all_job_ids(self) -> List[str]:

scheduler/tests/test_mgmt_cmds.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818

1919
class RqworkerTestCase(TestCase):
20-
2120
def test_rqworker__no_queues_params(self):
2221
queue = get_queue("default")
2322

@@ -49,7 +48,11 @@ def test_rqworker__job_class_param__green(self):
4948

5049
# Create a worker to execute these jobs
5150
call_command(
52-
"rqworker", "--job-class", "scheduler.rq_classes.JobExecution", fork_job_execution=False, burst=True
51+
"rqworker",
52+
"--job-class",
53+
"scheduler.rq_classes.JobExecution",
54+
fork_job_execution=False,
55+
burst=True,
5356
)
5457

5558
# check if all jobs are really failed
@@ -69,7 +72,58 @@ def test_rqworker__bad_job_class__fail(self):
6972

7073
# Create a worker to execute these jobs
7174
with self.assertRaises(ImportError):
72-
call_command("rqworker", "--job-class", "rq.badclass", fork_job_execution=False, burst=True)
75+
call_command(
76+
"rqworker",
77+
"--job-class",
78+
"rq.badclass",
79+
fork_job_execution=False,
80+
burst=True,
81+
)
82+
83+
def test_rqworker__worker_class_param__fail(self):
84+
queue = get_queue("default")
85+
86+
# enqueue some jobs that will fail
87+
jobs = []
88+
job_ids = []
89+
for _ in range(0, 3):
90+
job = queue.enqueue(failing_job)
91+
jobs.append(job)
92+
job_ids.append(job.id)
93+
94+
# Create a worker to execute these jobs with a bad worker class
95+
with self.assertRaises(ImportError):
96+
call_command(
97+
"rqworker",
98+
"--worker-class",
99+
"scheduler.bad_worker_class",
100+
fork_job_execution=False,
101+
burst=True,
102+
)
103+
104+
def test_rqworker__worker_class_param__green(self):
105+
queue = get_queue("default")
106+
107+
# enqueue some jobs that will fail
108+
jobs = []
109+
job_ids = []
110+
for _ in range(0, 3):
111+
job = queue.enqueue(failing_job)
112+
jobs.append(job)
113+
job_ids.append(job.id)
114+
115+
# Create a worker to execute these jobs with a good worker class
116+
call_command(
117+
"rqworker",
118+
"--worker-class",
119+
"scheduler.rq_classes.DjangoWorker",
120+
fork_job_execution=False,
121+
burst=True,
122+
)
123+
124+
# check if all jobs are really failed
125+
for job in jobs:
126+
self.assertTrue(job.is_failed)
73127

74128
def test_rqworker__run_jobs(self):
75129
queue = get_queue("default")
@@ -105,7 +159,13 @@ def test_rqworker__worker_with_two_queues(self):
105159
job_ids.append(job.id)
106160

107161
# Create a worker to execute these jobs
108-
call_command("rqworker", "default", "django_tasks_scheduler_test", fork_job_execution=False, burst=True)
162+
call_command(
163+
"rqworker",
164+
"default",
165+
"django_tasks_scheduler_test",
166+
fork_job_execution=False,
167+
burst=True,
168+
)
109169

110170
# check if all jobs are really failed
111171
for job in jobs:

scheduler/tests/test_worker.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import uuid
33

44
from rq.job import Job
5-
from scheduler.rq_classes import JobExecution
5+
from rq.worker import Worker as RQWorker
6+
from scheduler.rq_classes import JobExecution, DjangoWorker
67
from scheduler.tests.testtools import SchedulerBaseCase
78
from scheduler.tools import create_worker
89
from . import test_settings # noqa
@@ -50,3 +51,20 @@ def test_get_worker_with_custom_job_class(self):
5051
def test_get_worker_without_custom_job_class(self):
5152
worker = create_worker("default")
5253
self.assertTrue(issubclass(worker.job_class, JobExecution))
54+
55+
def test_get_worker_with_custom_worker_class(self):
56+
57+
worker = create_worker("default", worker_class="scheduler.rq_classes.DjangoWorker")
58+
self.assertIsInstance(worker, DjangoWorker)
59+
60+
def test_get_worker_with_bad_custom_worker_class(self):
61+
with self.assertRaises(ImportError):
62+
create_worker("default", worker_class="scheduler.non_existent_class")
63+
64+
def test_create_worker_with_rq_worker_class(self):
65+
with self.assertRaises(ValueError):
66+
create_worker("default", worker_class="rq.Worker")
67+
68+
def test_get_worker_without_custom_worker_class(self):
69+
worker = create_worker("default")
70+
self.assertIsInstance(worker, DjangoWorker)

scheduler/tools.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,24 @@ def create_worker(*queue_names, **kwargs):
7575
# Handle job_class if provided
7676
if "job_class" not in kwargs or kwargs["job_class"] is None:
7777
kwargs["job_class"] = "scheduler.rq_classes.JobExecution"
78+
7879
try:
7980
kwargs["job_class"] = import_string(kwargs["job_class"])
8081
except ImportError:
8182
raise ImportError(f"Could not import job class {kwargs['job_class']}")
8283

83-
worker = DjangoWorker(queues, connection=queues[0].connection, **kwargs)
84+
# Handle worker_class if provided
85+
if "worker_class" in kwargs and kwargs["worker_class"]:
86+
try:
87+
worker_class = import_string(kwargs["worker_class"])
88+
if not issubclass(worker_class, DjangoWorker):
89+
raise ValueError("worker_class must be a subclass of DjangoWorker")
90+
except ImportError:
91+
raise ImportError(f"Could not import worker class {kwargs['worker_class']}")
92+
else:
93+
worker_class = DjangoWorker
94+
95+
worker = worker_class(queues, connection=queues[0].connection, **kwargs)
8496
return worker
8597

8698

0 commit comments

Comments
 (0)