diff --git a/docs/changelog.md b/docs/changelog.md index 3cbd056..dfa3944 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,10 @@ See breaking changes in 4.0.0 beta versions. +### 🐛 Bug Fixes + +- Fix issue with non-primitive parameters for @job #249 + ## v4.0.0b3 🌈 Refactor the code to make it more organized and easier to maintain. This includes: diff --git a/poetry.lock b/poetry.lock index 8a5e8dd..741a937 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1465,15 +1465,15 @@ jeepney = ">=0.6" [[package]] name = "sentry-sdk" -version = "2.25.1" +version = "2.26.0" description = "Python client for Sentry (https://sentry.io)" optional = true python-versions = ">=3.6" groups = ["main"] markers = "extra == \"sentry\"" files = [ - {file = "sentry_sdk-2.25.1-py2.py3-none-any.whl", hash = "sha256:60b016d0772789454dc55a284a6a44212044d4a16d9f8448725effee97aaf7f6"}, - {file = "sentry_sdk-2.25.1.tar.gz", hash = "sha256:f9041b7054a7cf12d41eadabe6458ce7c6d6eea7a97cfe1b760b6692e9562cf0"}, + {file = "sentry_sdk-2.26.0-py2.py3-none-any.whl", hash = "sha256:82496fc359296dac57ec923300b18cc1f14a1279c1e7108d46d35dbb4cf8f5f8"}, + {file = "sentry_sdk-2.26.0.tar.gz", hash = "sha256:88643459716dd0c6e412e5141fcc94ce3b5725e4b6b312210b91332b3b46a0e2"}, ] [package.dependencies] diff --git a/scheduler/redis_models/base.py b/scheduler/redis_models/base.py index 0b10418..2af8b16 100644 --- a/scheduler/redis_models/base.py +++ b/scheduler/redis_models/base.py @@ -63,7 +63,7 @@ def _deserialize(value: str, _type: Type) -> Any: return int(value) elif _type is float or _type == Optional[float]: return float(value) - elif _type in {List[Any], List[str], Dict[str, str]}: + elif _type in {List[str], Dict[str, str]}: return json.loads(value) elif _type == Optional[Any]: return json.loads(value) @@ -78,6 +78,9 @@ def _deserialize(value: str, _type: Type) -> Any: class BaseModel: name: str _element_key_template: ClassVar[str] = ":element:{}" + # fields that are not serializable using method above and should be dealt with in the subclass + # e.g. args/kwargs for a job + _non_serializable_fields: ClassVar[Set[str]] = set() @classmethod def key_for(cls, name: str) -> str: @@ -92,14 +95,14 @@ def serialize(self, with_nones: bool = False) -> Dict[str, str]: self, dict_factory=lambda fields: {key: value for (key, value) in fields if not key.startswith("_")} ) if not with_nones: - data = {k: v for k, v in data.items() if v is not None} + data = {k: v for k, v in data.items() if v is not None and k not in self._non_serializable_fields} for k in data: data[k] = _serialize(data[k]) return data @classmethod def deserialize(cls, data: Dict[str, Any]) -> Self: - types = {f.name: f.type for f in dataclasses.fields(cls)} + types = {f.name: f.type for f in dataclasses.fields(cls) if f.name not in cls._non_serializable_fields} for k in data: if k not in types: logger.warning(f"Unknown field {k} in {cls.__name__}") diff --git a/scheduler/redis_models/job.py b/scheduler/redis_models/job.py index a512c8e..90c5302 100644 --- a/scheduler/redis_models/job.py +++ b/scheduler/redis_models/job.py @@ -1,6 +1,8 @@ +import base64 import dataclasses import inspect import numbers +import pickle from datetime import datetime from enum import Enum from typing import ClassVar, Dict, Optional, List, Callable, Any, Union, Tuple @@ -35,13 +37,15 @@ class JobModel(HashModel): _list_key: ClassVar[str] = ":jobs:ALL:" _children_key_template: ClassVar[str] = ":{}:jobs:" _element_key_template: ClassVar[str] = ":jobs:{}" + _non_serializable_fields = {"args", "kwargs"} + + args: List[Any] + kwargs: Dict[str, str] queue_name: str description: str func_name: str - args: List[Any] - kwargs: Dict[str, str] timeout: int = SCHEDULER_CONFIG.DEFAULT_JOB_TIMEOUT success_ttl: int = SCHEDULER_CONFIG.DEFAULT_SUCCESS_TTL job_info_ttl: int = SCHEDULER_CONFIG.DEFAULT_JOB_TTL @@ -63,10 +67,6 @@ class JobModel(HashModel): task_type: Optional[str] = None scheduled_task_id: Optional[int] = None - def serialize(self, with_nones: bool = False) -> Dict[str, str]: - res = super(JobModel, self).serialize() - return res - def __hash__(self): return hash(self.name) @@ -166,6 +166,21 @@ def stopped_callback(self) -> Optional[Callable[..., Any]]: def get_call_string(self): return _get_call_string(self.func_name, self.args, self.kwargs) + def serialize(self, with_nones: bool = False) -> Dict[str, str]: + """Serialize the job model to a dictionary.""" + res = super(JobModel, self).serialize(with_nones=with_nones) + res["args"] = base64.encodebytes(pickle.dumps(self.args)).decode("utf-8") + res["kwargs"] = base64.encodebytes(pickle.dumps(self.kwargs)).decode("utf-8") + return res + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> Self: + """Deserialize the job model from a dictionary.""" + res = super(JobModel, cls).deserialize(data) + res.args = pickle.loads(base64.decodebytes(data.get("args").encode("utf-8"))) + res.kwargs = pickle.loads(base64.decodebytes(data.get("kwargs").encode("utf-8"))) + return res + @classmethod def create( cls, diff --git a/scheduler/tests/test_job_decorator.py b/scheduler/tests/test_job_decorator.py index ec10d16..736a80e 100644 --- a/scheduler/tests/test_job_decorator.py +++ b/scheduler/tests/test_job_decorator.py @@ -6,7 +6,9 @@ from scheduler.helpers.queues import get_queue from . import test_settings # noqa from ..decorators import JOB_METHODS_LIST, job +from ..redis_models import JobStatus from ..redis_models.job import JobModel +from ..worker import create_worker @job() @@ -32,12 +34,27 @@ def test_job_result_ttl(): return 1 + 1 +class MyClass: + def run(self): + print("Hello") + + def __eq__(self, other): + if not isinstance(other, MyClass): + return False + return True + + +@job() +def long_running_func(x): + x.run() + + class JobDecoratorTest(TestCase): def setUp(self) -> None: get_queue("default").connection.flushall() def test_all_job_methods_registered(self): - self.assertEqual(4, len(JOB_METHODS_LIST)) + self.assertEqual(5, len(JOB_METHODS_LIST)) def test_job_decorator_no_params(self): test_job.delay() @@ -92,3 +109,18 @@ def test_job_decorator_bad_queue(self): def test_job_bad_queue(): time.sleep(1) return 1 + 1 + + def test_job_decorator_delay_with_param(self): + queue_name = "default" + long_running_func.delay(MyClass()) + + worker = create_worker(queue_name, burst=True) + worker.work() + + jobs_list = worker.queues[0].get_all_jobs() + self.assertEqual(1, len(jobs_list)) + job = jobs_list[0] + self.assertEqual(job.func, long_running_func) + self.assertEqual(job.kwargs, {}) + self.assertEqual(job.status, JobStatus.FINISHED) + self.assertEqual(job.args, (MyClass(),)) diff --git a/scheduler/tests/test_task_types/test_task_model.py b/scheduler/tests/test_task_types/test_task_model.py index f23dc2f..16434f7 100644 --- a/scheduler/tests/test_task_types/test_task_model.py +++ b/scheduler/tests/test_task_types/test_task_model.py @@ -1,17 +1,17 @@ import zoneinfo from datetime import datetime, timedelta +import time_machine from django.contrib.messages import get_messages from django.core.exceptions import ValidationError from django.test import override_settings from django.urls import reverse from django.utils import timezone -from freezegun import freeze_time from scheduler import settings -from scheduler.models import TaskType, Task, TaskArg, TaskKwarg, run_task from scheduler.helpers.queues import get_queue from scheduler.helpers.queues import perform_job +from scheduler.models import TaskType, Task, TaskArg, TaskKwarg, run_task from scheduler.redis_models import JobStatus, JobModel from scheduler.tests import jobs, test_settings # noqa from scheduler.tests.testtools import ( @@ -480,14 +480,14 @@ class TestSchedulableTask(TestBaseTask): # Currently ScheduledJob and RepeatableJob task_type = TaskType.ONCE - @freeze_time("2016-12-25") + @time_machine.travel(datetime(2016, 12, 25)) @override_settings(USE_TZ=False) def test_schedule_time_no_tz(self): task = task_factory(self.task_type) task.scheduled_time = datetime(2016, 12, 25, 8, 0, 0, tzinfo=None) self.assertEqual("2016-12-25T08:00:00", task._schedule_time().isoformat()) - @freeze_time("2016-12-25") + @time_machine.travel(datetime(2016, 12, 25)) @override_settings(USE_TZ=True) def test_schedule_time_with_tz(self): task = task_factory(self.task_type)