diff --git a/scheduler/management/commands/rqworker.py b/scheduler/management/commands/rqworker.py index a4dbd9e..a7e271c 100644 --- a/scheduler/management/commands/rqworker.py +++ b/scheduler/management/commands/rqworker.py @@ -69,6 +69,7 @@ def add_arguments(self, parser): help="Fork job execution to another process", ) parser.add_argument("--job-class", action="store", dest="job_class", help="Jobs class to use") + parser.add_argument('--worker-class', action='store', dest='worker_class', help='RQ Worker class to use') parser.add_argument( "queues", nargs="*", @@ -102,6 +103,7 @@ def handle(self, **options): *queues, name=options["name"], job_class=options.get("job_class"), + worker_class=options.get("worker_class"), default_worker_ttl=options["worker_ttl"], fork_job_execution=options["fork_job_execution"], ) diff --git a/scheduler/queues.py b/scheduler/queues.py index 8edd8fe..0086612 100644 --- a/scheduler/queues.py +++ b/scheduler/queues.py @@ -135,7 +135,7 @@ def get_queues(*queue_names, **kwargs) -> List[DjangoQueue]: """ from .settings import QUEUES - kwargs["job_class"] = JobExecution + kwargs["job_class"] = kwargs.get("job_class") or JobExecution queue_params = QUEUES[queue_names[0]] queues = [get_queue(queue_names[0], **kwargs)] # perform consistency checks while building return list diff --git a/scheduler/rq_classes.py b/scheduler/rq_classes.py index 99bc1ec..e4138de 100644 --- a/scheduler/rq_classes.py +++ b/scheduler/rq_classes.py @@ -90,6 +90,7 @@ def __init__(self, *args, **kwargs): # Update kwargs with the potentially modified job_class kwargs["job_class"] = job_class kwargs["queue_class"] = DjangoQueue + kwargs.pop("worker_class", None) super(DjangoWorker, self).__init__(*args, **kwargs) def __eq__(self, other): @@ -184,7 +185,8 @@ class DjangoQueue(Queue): """ def __init__(self, *args, **kwargs): - kwargs["job_class"] = JobExecution + kwargs["job_class"] = kwargs.get("job_class") or JobExecution + self.job_class = kwargs["job_class"] super(DjangoQueue, self).__init__(*args, **kwargs) def get_registry(self, name: str) -> Union[None, BaseRegistry, "DjangoQueue"]: @@ -204,7 +206,7 @@ def started_job_registry(self): return StartedJobRegistry( self.name, self.connection, - job_class=JobExecution, + job_class=self.job_class, ) @property @@ -212,7 +214,7 @@ def deferred_job_registry(self): return DeferredJobRegistry( self.name, self.connection, - job_class=JobExecution, + job_class=self.job_class, ) @property @@ -220,7 +222,7 @@ def failed_job_registry(self): return FailedJobRegistry( self.name, self.connection, - job_class=JobExecution, + job_class=self.job_class, ) @property @@ -228,7 +230,7 @@ def scheduled_job_registry(self): return ScheduledJobRegistry( self.name, self.connection, - job_class=JobExecution, + job_class=self.job_class, ) @property @@ -236,7 +238,7 @@ def canceled_job_registry(self): return CanceledJobRegistry( self.name, self.connection, - job_class=JobExecution, + job_class=self.job_class, ) def get_all_job_ids(self) -> List[str]: diff --git a/scheduler/tests/test_mgmt_cmds.py b/scheduler/tests/test_mgmt_cmds.py index 6257935..b259bae 100644 --- a/scheduler/tests/test_mgmt_cmds.py +++ b/scheduler/tests/test_mgmt_cmds.py @@ -17,7 +17,6 @@ class RqworkerTestCase(TestCase): - def test_rqworker__no_queues_params(self): queue = get_queue("default") @@ -49,7 +48,11 @@ def test_rqworker__job_class_param__green(self): # Create a worker to execute these jobs call_command( - "rqworker", "--job-class", "scheduler.rq_classes.JobExecution", fork_job_execution=False, burst=True + "rqworker", + "--job-class", + "scheduler.rq_classes.JobExecution", + fork_job_execution=False, + burst=True, ) # check if all jobs are really failed @@ -69,7 +72,58 @@ def test_rqworker__bad_job_class__fail(self): # Create a worker to execute these jobs with self.assertRaises(ImportError): - call_command("rqworker", "--job-class", "rq.badclass", fork_job_execution=False, burst=True) + call_command( + "rqworker", + "--job-class", + "rq.badclass", + fork_job_execution=False, + burst=True, + ) + + def test_rqworker__worker_class_param__fail(self): + queue = get_queue("default") + + # enqueue some jobs that will fail + jobs = [] + job_ids = [] + for _ in range(0, 3): + job = queue.enqueue(failing_job) + jobs.append(job) + job_ids.append(job.id) + + # Create a worker to execute these jobs with a bad worker class + with self.assertRaises(ImportError): + call_command( + "rqworker", + "--worker-class", + "scheduler.bad_worker_class", + fork_job_execution=False, + burst=True, + ) + + def test_rqworker__worker_class_param__green(self): + queue = get_queue("default") + + # enqueue some jobs that will fail + jobs = [] + job_ids = [] + for _ in range(0, 3): + job = queue.enqueue(failing_job) + jobs.append(job) + job_ids.append(job.id) + + # Create a worker to execute these jobs with a good worker class + call_command( + "rqworker", + "--worker-class", + "scheduler.rq_classes.DjangoWorker", + fork_job_execution=False, + burst=True, + ) + + # check if all jobs are really failed + for job in jobs: + self.assertTrue(job.is_failed) def test_rqworker__run_jobs(self): queue = get_queue("default") @@ -105,7 +159,13 @@ def test_rqworker__worker_with_two_queues(self): job_ids.append(job.id) # Create a worker to execute these jobs - call_command("rqworker", "default", "django_tasks_scheduler_test", fork_job_execution=False, burst=True) + call_command( + "rqworker", + "default", + "django_tasks_scheduler_test", + fork_job_execution=False, + burst=True, + ) # check if all jobs are really failed for job in jobs: diff --git a/scheduler/tests/test_worker.py b/scheduler/tests/test_worker.py index 4b40bfb..d2cde1f 100644 --- a/scheduler/tests/test_worker.py +++ b/scheduler/tests/test_worker.py @@ -2,7 +2,7 @@ import uuid from rq.job import Job -from scheduler.rq_classes import JobExecution +from scheduler.rq_classes import JobExecution, DjangoWorker from scheduler.tests.testtools import SchedulerBaseCase from scheduler.tools import create_worker from . import test_settings # noqa @@ -50,3 +50,21 @@ def test_get_worker_with_custom_job_class(self): def test_get_worker_without_custom_job_class(self): worker = create_worker("default") self.assertTrue(issubclass(worker.job_class, JobExecution)) + + def test_get_worker_with_custom_worker_class(self): + worker = create_worker( + "default", worker_class="scheduler.rq_classes.DjangoWorker" + ) + self.assertIsInstance(worker, DjangoWorker) + + def test_get_worker_with_bad_custom_worker_class(self): + with self.assertRaises(ImportError): + create_worker("default", worker_class="scheduler.non_existent_class") + + def test_create_worker_with_rq_worker_class(self): + with self.assertRaises(ValueError): + create_worker("default", worker_class="rq.Worker") + + def test_get_worker_without_custom_worker_class(self): + worker = create_worker("default") + self.assertIsInstance(worker, DjangoWorker) diff --git a/scheduler/tools.py b/scheduler/tools.py index 476fff0..10151d4 100644 --- a/scheduler/tools.py +++ b/scheduler/tools.py @@ -75,12 +75,24 @@ def create_worker(*queue_names, **kwargs): # Handle job_class if provided if "job_class" not in kwargs or kwargs["job_class"] is None: kwargs["job_class"] = "scheduler.rq_classes.JobExecution" + try: kwargs["job_class"] = import_string(kwargs["job_class"]) except ImportError: raise ImportError(f"Could not import job class {kwargs['job_class']}") - worker = DjangoWorker(queues, connection=queues[0].connection, **kwargs) + # Handle worker_class if provided + if "worker_class" in kwargs and kwargs["worker_class"]: + try: + worker_class = import_string(kwargs["worker_class"]) + if not issubclass(worker_class, DjangoWorker): + raise ValueError("worker_class must be a subclass of DjangoWorker") + except ImportError: + raise ImportError(f"Could not import worker class {kwargs['worker_class']}") + else: + worker_class = DjangoWorker + + worker = worker_class(queues, connection=queues[0].connection, **kwargs) return worker