Skip to content

Commit c3f919e

Browse files
author
Andrei Neagu
committed
refactor
1 parent 2f57230 commit c3f919e

File tree

6 files changed

+73
-18
lines changed

6 files changed

+73
-18
lines changed

packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from ...aiohttp import status
1414
from ...long_running_tasks import lrt_api
15-
from ...long_running_tasks._error_serialization import (
15+
from ...long_running_tasks._redis_serialization import (
1616
BaseObjectSerializer,
1717
register_custom_serialization,
1818
)

packages/service-library/src/servicelib/long_running_tasks/_error_serialization.py renamed to packages/service-library/src/servicelib/long_running_tasks/_redis_serialization.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import pickle
44
from abc import ABC, abstractmethod
5-
from typing import Final, Generic, TypeVar
5+
from typing import Any, Final, Generic, TypeVar
66

77
_logger = logging.getLogger(__name__)
88

@@ -42,9 +42,9 @@ def register_custom_serialization(
4242
_MODULE_FIELD: Final[str] = "__pickle__module__field__"
4343

4444

45-
def error_to_string(e: Exception) -> str:
46-
"""Serialize exception to base64-encoded string."""
47-
to_serialize: Exception | dict = e
45+
def object_to_string(e: Any) -> str:
46+
"""Serialize object to base64-encoded string."""
47+
to_serialize: Any | dict = e
4848
object_class = type(e)
4949

5050
for registered_class, object_serializer in _SERIALIZERS.items():
@@ -59,8 +59,8 @@ def error_to_string(e: Exception) -> str:
5959
return base64.b85encode(pickle.dumps(to_serialize)).decode("utf-8")
6060

6161

62-
def error_from_string(error_str: str) -> Exception:
63-
"""Deserialize exception from base64-encoded string."""
62+
def string_to_object(error_str: str) -> Any:
63+
"""Deserialize object from base64-encoded string."""
6464
data = pickle.loads(base64.b85decode(error_str)) # noqa: S301
6565

6666
if isinstance(data, dict) and _TYPE_FIELD in data and _MODULE_FIELD in data:

packages/service-library/src/servicelib/long_running_tasks/_store/redis.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import Any, Final
23

34
import redis.asyncio as aioredis
@@ -9,6 +10,8 @@
910
from ..models import TaskContext, TaskData, TaskId
1011
from .base import BaseStore
1112

13+
_logger = logging.getLogger(__name__)
14+
1215
STORE_TYPE_TASK_DATA: Final[str] = "TD"
1316
STORE_TYPE_CANCELLED_TASKS: Final[str] = "CT"
1417

@@ -49,6 +52,9 @@ async def get_task_data(self, task_id: TaskId) -> TaskData | None:
4952
return TypeAdapter(TaskData).validate_json(result) if result else None
5053

5154
async def set_task_data(self, task_id: TaskId, value: TaskData) -> None:
55+
_logger.debug(
56+
"Setting task data for task_id=%s with data value=%s", task_id, value
57+
)
5258
await self.redis.hset(
5359
self._get_redis_hash_key(STORE_TYPE_TASK_DATA),
5460
task_id,

packages/service-library/src/servicelib/long_running_tasks/task.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from settings_library.redis import RedisDatabase, RedisSettings
1818

1919
from ..redis import RedisClientSDK, exclusive
20-
from ._error_serialization import error_from_string, error_to_string
20+
from ._redis_serialization import object_to_string, string_to_object
2121
from ._store.base import BaseStore
2222
from ._store.redis import RedisStore
2323
from .errors import (
@@ -263,16 +263,18 @@ async def _status_update_worker(self) -> None:
263263

264264
# get task result
265265
try:
266-
task_data.result_field = ResultField(result=task.result())
266+
task_data.result_field = ResultField(
267+
result=object_to_string(task.result())
268+
)
267269
except asyncio.InvalidStateError:
268270
# task was not completed try again next time and see if it is done
269271
continue
270272
except asyncio.CancelledError:
271273
task_data.result_field = ResultField(
272-
error=error_to_string(TaskCancelledError(task_id=task_id))
274+
error=object_to_string(TaskCancelledError(task_id=task_id))
273275
)
274276
except Exception as e: # pylint:disable=broad-except
275-
task_data.result_field = ResultField(error=error_to_string(e))
277+
task_data.result_field = ResultField(error=object_to_string(e))
276278

277279
await self._tasks_data.set_task_data(task_id, task_data)
278280

@@ -360,9 +362,9 @@ async def get_task_result(
360362
raise TaskNotCompletedError(task_id=task_id)
361363

362364
if tracked_task.result_field.error is not None:
363-
raise error_from_string(tracked_task.result_field.error)
365+
raise string_to_object(tracked_task.result_field.error)
364366

365-
return tracked_task.result_field.result
367+
return string_to_object(tracked_task.result_field.result)
366368

367369
async def cancel_task(
368370
self, task_id: TaskId, with_task_context: TaskContext

packages/service-library/tests/long_running_tasks/test_long_running_tasks__error_serialization.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import pytest
44
from aiohttp.web import HTTPException, HTTPInternalServerError
55
from servicelib.aiohttp.long_running_tasks._server import AiohttpHTTPExceptionSerializer
6-
from servicelib.long_running_tasks._error_serialization import (
7-
error_from_string,
8-
error_to_string,
6+
from servicelib.long_running_tasks._redis_serialization import (
7+
object_to_string,
98
register_custom_serialization,
9+
string_to_object,
1010
)
1111

1212
register_custom_serialization(HTTPException, AiohttpHTTPExceptionSerializer)
@@ -38,9 +38,9 @@ def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None):
3838
],
3939
)
4040
def test_serialization(obj: Any):
41-
str_data = error_to_string(obj)
41+
str_data = object_to_string(obj)
4242

43-
reconstructed_obj = error_from_string(str_data)
43+
reconstructed_obj = string_to_object(str_data)
4444

4545
assert type(reconstructed_obj) is type(obj)
4646
if hasattr(obj, "__dict__"):
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Any
2+
3+
import pytest
4+
from aiohttp.web import HTTPException, HTTPInternalServerError
5+
from servicelib.aiohttp.long_running_tasks._server import AiohttpHTTPExceptionSerializer
6+
from servicelib.long_running_tasks._redis_serialization import (
7+
object_to_string,
8+
register_custom_serialization,
9+
string_to_object,
10+
)
11+
12+
register_custom_serialization(HTTPException, AiohttpHTTPExceptionSerializer)
13+
14+
15+
class PositionalArguments:
16+
def __init__(self, arg1, arg2, *args):
17+
self.arg1 = arg1
18+
self.arg2 = arg2
19+
self.args = args
20+
21+
22+
class MixedArguments:
23+
def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None):
24+
self.arg1 = arg1
25+
self.arg2 = arg2
26+
self.kwarg1 = kwarg1
27+
self.kwarg2 = kwarg2
28+
29+
30+
@pytest.mark.parametrize(
31+
"obj",
32+
[
33+
HTTPInternalServerError(reason="Uh-oh!", text="Failure!"),
34+
PositionalArguments("arg1", "arg2", "arg3", "arg4"),
35+
MixedArguments("arg1", "arg2", kwarg1="kwarg1", kwarg2="kwarg2"),
36+
"a_string",
37+
1,
38+
],
39+
)
40+
def test_serialization(obj: Any):
41+
str_data = object_to_string(obj)
42+
43+
reconstructed_obj = string_to_object(str_data)
44+
45+
assert type(reconstructed_obj) is type(obj)
46+
if hasattr(obj, "__dict__"):
47+
assert reconstructed_obj.__dict__ == obj.__dict__

0 commit comments

Comments
 (0)