Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

See breaking changes in 4.0.0 beta versions.

### 🐛 Bug Fixes

- Fix issue with non-primitive parameters for @job #249

## v4.0.0b3 🌈

Refactor the code to make it more organized and easier to maintain. This includes:
Expand Down
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions scheduler/redis_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _deserialize(value: str, _type: Type) -> Any:
return int(value)
elif _type is float or _type == Optional[float]:
return float(value)
elif _type in {List[Any], List[str], Dict[str, str]}:
elif _type in {List[str], Dict[str, str]}:
return json.loads(value)
elif _type == Optional[Any]:
return json.loads(value)
Expand All @@ -78,6 +78,9 @@ def _deserialize(value: str, _type: Type) -> Any:
class BaseModel:
name: str
_element_key_template: ClassVar[str] = ":element:{}"
# fields that are not serializable using method above and should be dealt with in the subclass
# e.g. args/kwargs for a job
_non_serializable_fields: ClassVar[Set[str]] = set()

@classmethod
def key_for(cls, name: str) -> str:
Expand All @@ -92,14 +95,14 @@ def serialize(self, with_nones: bool = False) -> Dict[str, str]:
self, dict_factory=lambda fields: {key: value for (key, value) in fields if not key.startswith("_")}
)
if not with_nones:
data = {k: v for k, v in data.items() if v is not None}
data = {k: v for k, v in data.items() if v is not None and k not in self._non_serializable_fields}
for k in data:
data[k] = _serialize(data[k])
return data

@classmethod
def deserialize(cls, data: Dict[str, Any]) -> Self:
types = {f.name: f.type for f in dataclasses.fields(cls)}
types = {f.name: f.type for f in dataclasses.fields(cls) if f.name not in cls._non_serializable_fields}
for k in data:
if k not in types:
logger.warning(f"Unknown field {k} in {cls.__name__}")
Expand Down
27 changes: 21 additions & 6 deletions scheduler/redis_models/job.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import base64
import dataclasses
import inspect
import numbers
import pickle
from datetime import datetime
from enum import Enum
from typing import ClassVar, Dict, Optional, List, Callable, Any, Union, Tuple
Expand Down Expand Up @@ -35,13 +37,15 @@ class JobModel(HashModel):
_list_key: ClassVar[str] = ":jobs:ALL:"
_children_key_template: ClassVar[str] = ":{}:jobs:"
_element_key_template: ClassVar[str] = ":jobs:{}"
_non_serializable_fields = {"args", "kwargs"}

args: List[Any]
kwargs: Dict[str, str]

queue_name: str
description: str
func_name: str

args: List[Any]
kwargs: Dict[str, str]
timeout: int = SCHEDULER_CONFIG.DEFAULT_JOB_TIMEOUT
success_ttl: int = SCHEDULER_CONFIG.DEFAULT_SUCCESS_TTL
job_info_ttl: int = SCHEDULER_CONFIG.DEFAULT_JOB_TTL
Expand All @@ -63,10 +67,6 @@ class JobModel(HashModel):
task_type: Optional[str] = None
scheduled_task_id: Optional[int] = None

def serialize(self, with_nones: bool = False) -> Dict[str, str]:
res = super(JobModel, self).serialize()
return res

def __hash__(self):
return hash(self.name)

Expand Down Expand Up @@ -166,6 +166,21 @@ def stopped_callback(self) -> Optional[Callable[..., Any]]:
def get_call_string(self):
return _get_call_string(self.func_name, self.args, self.kwargs)

def serialize(self, with_nones: bool = False) -> Dict[str, str]:
"""Serialize the job model to a dictionary."""
res = super(JobModel, self).serialize(with_nones=with_nones)
res["args"] = base64.encodebytes(pickle.dumps(self.args)).decode("utf-8")
res["kwargs"] = base64.encodebytes(pickle.dumps(self.kwargs)).decode("utf-8")
return res

@classmethod
def deserialize(cls, data: Dict[str, Any]) -> Self:
"""Deserialize the job model from a dictionary."""
res = super(JobModel, cls).deserialize(data)
res.args = pickle.loads(base64.decodebytes(data.get("args").encode("utf-8")))
res.kwargs = pickle.loads(base64.decodebytes(data.get("kwargs").encode("utf-8")))
return res

@classmethod
def create(
cls,
Expand Down
34 changes: 33 additions & 1 deletion scheduler/tests/test_job_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from scheduler.helpers.queues import get_queue
from . import test_settings # noqa
from ..decorators import JOB_METHODS_LIST, job
from ..redis_models import JobStatus
from ..redis_models.job import JobModel
from ..worker import create_worker


@job()
Expand All @@ -32,12 +34,27 @@ def test_job_result_ttl():
return 1 + 1


class MyClass:
def run(self):
print("Hello")

def __eq__(self, other):
if not isinstance(other, MyClass):
return False
return True


@job()
def long_running_func(x):
x.run()


class JobDecoratorTest(TestCase):
def setUp(self) -> None:
get_queue("default").connection.flushall()

def test_all_job_methods_registered(self):
self.assertEqual(4, len(JOB_METHODS_LIST))
self.assertEqual(5, len(JOB_METHODS_LIST))

def test_job_decorator_no_params(self):
test_job.delay()
Expand Down Expand Up @@ -92,3 +109,18 @@ def test_job_decorator_bad_queue(self):
def test_job_bad_queue():
time.sleep(1)
return 1 + 1

def test_job_decorator_delay_with_param(self):
queue_name = "default"
long_running_func.delay(MyClass())

worker = create_worker(queue_name, burst=True)
worker.work()

jobs_list = worker.queues[0].get_all_jobs()
self.assertEqual(1, len(jobs_list))
job = jobs_list[0]
self.assertEqual(job.func, long_running_func)
self.assertEqual(job.kwargs, {})
self.assertEqual(job.status, JobStatus.FINISHED)
self.assertEqual(job.args, (MyClass(),))
8 changes: 4 additions & 4 deletions scheduler/tests/test_task_types/test_task_model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import zoneinfo
from datetime import datetime, timedelta

import time_machine
from django.contrib.messages import get_messages
from django.core.exceptions import ValidationError
from django.test import override_settings
from django.urls import reverse
from django.utils import timezone
from freezegun import freeze_time

from scheduler import settings
from scheduler.models import TaskType, Task, TaskArg, TaskKwarg, run_task
from scheduler.helpers.queues import get_queue
from scheduler.helpers.queues import perform_job
from scheduler.models import TaskType, Task, TaskArg, TaskKwarg, run_task
from scheduler.redis_models import JobStatus, JobModel
from scheduler.tests import jobs, test_settings # noqa
from scheduler.tests.testtools import (
Expand Down Expand Up @@ -480,14 +480,14 @@ class TestSchedulableTask(TestBaseTask):
# Currently ScheduledJob and RepeatableJob
task_type = TaskType.ONCE

@freeze_time("2016-12-25")
@time_machine.travel(datetime(2016, 12, 25))
@override_settings(USE_TZ=False)
def test_schedule_time_no_tz(self):
task = task_factory(self.task_type)
task.scheduled_time = datetime(2016, 12, 25, 8, 0, 0, tzinfo=None)
self.assertEqual("2016-12-25T08:00:00", task._schedule_time().isoformat())

@freeze_time("2016-12-25")
@time_machine.travel(datetime(2016, 12, 25))
@override_settings(USE_TZ=True)
def test_schedule_time_with_tz(self):
task = task_factory(self.task_type)
Expand Down
Loading