diff --git a/scheduler/management/commands/rqworker.py b/scheduler/management/commands/rqworker.py index cd812e2..89a1aa6 100644 --- a/scheduler/management/commands/rqworker.py +++ b/scheduler/management/commands/rqworker.py @@ -5,6 +5,7 @@ import click from django.core.management.base import BaseCommand from django.db import connections +from django.template.defaultfilters import default from redis.exceptions import ConnectionError from rq.logutils import setup_loghandlers @@ -47,6 +48,8 @@ def add_arguments(self, parser): help='Maximum number of jobs to execute before terminating worker') parser.add_argument('--fork-job-execution', action='store', default=True, dest='fork_job_execution', type=bool, help='Fork job execution to another process') + parser.add_argument('--job-class', action='store', dest='job_class', + help='Jobs class to use') parser.add_argument( 'queues', nargs='*', type=str, help='The queues to work on, separated by space, all queues should be using the same redis') @@ -71,6 +74,7 @@ def handle(self, **options): w = create_worker( *queues, name=options['name'], + job_class=options.get('job_class'), default_worker_ttl=options['worker_ttl'], fork_job_execution=options['fork_job_execution'], ) diff --git a/scheduler/rq_classes.py b/scheduler/rq_classes.py index d8b9238..f875dc5 100644 --- a/scheduler/rq_classes.py +++ b/scheduler/rq_classes.py @@ -70,9 +70,14 @@ def stop_execution(self, connection: Redis): class DjangoWorker(Worker): def __init__(self, *args, **kwargs): - self.fork_job_execution = kwargs.pop('fork_job_execution', True) - kwargs['job_class'] = JobExecution - kwargs['queue_class'] = DjangoQueue + self.fork_job_execution = kwargs.pop("fork_job_execution", True) + job_class = kwargs.get("job_class") or JobExecution + if not isinstance(job_class, type) or not issubclass(job_class, JobExecution): + raise ValueError("job_class must be a subclass of JobExecution") + + # Update kwargs with the potentially modified job_class + kwargs["job_class"] = job_class + kwargs["queue_class"] = DjangoQueue super(DjangoWorker, self).__init__(*args, **kwargs) def __eq__(self, other): diff --git a/scheduler/tests/test_mgmt_cmds.py b/scheduler/tests/test_mgmt_cmds.py index 5c86000..4b99193 100644 --- a/scheduler/tests/test_mgmt_cmds.py +++ b/scheduler/tests/test_mgmt_cmds.py @@ -36,6 +36,39 @@ def test_rqworker__no_queues_params(self): for job in jobs: self.assertTrue(job.is_failed) + def test_rqworker__job_class_param__green(self): + queue = get_queue('default') + + # enqueue some jobs that will fail + jobs = [] + job_ids = [] + for _ in range(0, 3): + job = queue.enqueue(failing_job) + jobs.append(job) + job_ids.append(job.id) + + # Create a worker to execute these jobs + call_command('rqworker', '--job-class', 'scheduler.rq_classes.JobExecution', fork_job_execution=False, burst=True) + + # check if all jobs are really failed + for job in jobs: + self.assertTrue(job.is_failed) + + def test_rqworker__bad_job_class__fail(self): + queue = get_queue('default') + + # enqueue some jobs that will fail + jobs = [] + job_ids = [] + for _ in range(0, 3): + job = queue.enqueue(failing_job) + jobs.append(job) + job_ids.append(job.id) + + # Create a worker to execute these jobs + with self.assertRaises(ImportError): + call_command('rqworker', '--job-class', 'rq.badclass', fork_job_execution=False, burst=True) + def test_rqworker__run_jobs(self): queue = get_queue('default') diff --git a/scheduler/tests/test_worker.py b/scheduler/tests/test_worker.py index 68e8366..9337929 100644 --- a/scheduler/tests/test_worker.py +++ b/scheduler/tests/test_worker.py @@ -1,6 +1,8 @@ import os import uuid +from rq.job import Job +from scheduler.rq_classes import JobExecution from scheduler.tests.testtools import SchedulerBaseCase from scheduler.tools import create_worker from . import test_settings # noqa @@ -38,3 +40,13 @@ def test_create_worker__scheduler_interval(self): worker.work(burst=True) self.assertEqual(worker.scheduler.interval, 1) settings.SCHEDULER_CONFIG['SCHEDULER_INTERVAL'] = prev + + def test_get_worker_with_custom_job_class(self): + # Test with string representation of job_class + worker = create_worker('default', job_class='scheduler.rq_classes.JobExecution') + self.assertTrue(issubclass(worker.job_class, Job)) + self.assertTrue(issubclass(worker.job_class, JobExecution)) + + def test_get_worker_without_custom_job_class(self): + worker = create_worker('default') + self.assertTrue(issubclass(worker.job_class, JobExecution)) diff --git a/scheduler/tools.py b/scheduler/tools.py index e6991fb..f7ce1a8 100644 --- a/scheduler/tools.py +++ b/scheduler/tools.py @@ -4,6 +4,7 @@ import croniter from django.apps import apps from django.utils import timezone +from django.utils.module_loading import import_string from scheduler.queues import get_queues, logger, get_queue from scheduler.rq_classes import DjangoWorker, MODEL_NAMES @@ -71,6 +72,15 @@ def create_worker(*queue_names, **kwargs): kwargs['name'] = _calc_worker_name(existing_worker_names) kwargs['name'] = kwargs['name'].replace('/', '.') + + # Handle job_class if provided + if 'job_class' not in kwargs or kwargs["job_class"] is None: + kwargs['job_class'] = 'scheduler.rq_classes.JobExecution' + try: + kwargs['job_class'] = import_string(kwargs['job_class']) + except ImportError: + raise ImportError(f"Could not import job class {kwargs['job_class']}") + worker = DjangoWorker(queues, connection=queues[0].connection, **kwargs) return worker