Skip to content

Commit 2f57230

Browse files
author
Andrei Neagu
committed
added custom serializers
1 parent e81328e commit 2f57230

File tree

4 files changed

+160
-15
lines changed

4 files changed

+160
-15
lines changed

packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,23 @@
55
from typing import Any
66

77
from aiohttp import web
8+
from aiohttp.web import HTTPException
89
from common_library.json_serialization import json_dumps
910
from pydantic import AnyHttpUrl, TypeAdapter
10-
from servicelib.long_running_tasks.task import Namespace
1111
from settings_library.redis import RedisSettings
1212

1313
from ...aiohttp import status
1414
from ...long_running_tasks import lrt_api
15+
from ...long_running_tasks._error_serialization import (
16+
BaseObjectSerializer,
17+
register_custom_serialization,
18+
)
1519
from ...long_running_tasks.constants import (
1620
DEFAULT_STALE_TASK_CHECK_INTERVAL,
1721
DEFAULT_STALE_TASK_DETECT_TIMEOUT,
1822
)
1923
from ...long_running_tasks.models import TaskContext, TaskGet
20-
from ...long_running_tasks.task import RegisteredTaskName
24+
from ...long_running_tasks.task import Namespace, RegisteredTaskName
2125
from ..typing_extension import Handler
2226
from . import _routes
2327
from ._constants import (
@@ -118,6 +122,22 @@ def _wrap_and_add_routes(
118122
)
119123

120124

125+
class AiohttpHTTPExceptionSerializer(BaseObjectSerializer[HTTPException]):
126+
@classmethod
127+
def get_init_kwargs_from_object(cls, obj: HTTPException) -> dict:
128+
return {
129+
"status_code": obj.status_code,
130+
"reason": obj.reason,
131+
"text": obj.text,
132+
"headers": dict(obj.headers) if obj.headers else None,
133+
}
134+
135+
@classmethod
136+
def prepare_object_init_kwargs(cls, data: dict) -> dict:
137+
data.pop("status_code")
138+
return data
139+
140+
121141
def setup(
122142
app: web.Application,
123143
*,
@@ -140,6 +160,8 @@ def setup(
140160
"""
141161

142162
async def on_cleanup_ctx(app: web.Application) -> AsyncGenerator[None, None]:
163+
register_custom_serialization(HTTPException, AiohttpHTTPExceptionSerializer)
164+
143165
# add error handlers
144166
app.middlewares.append(base_long_running_error_handler)
145167

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import base64
2+
import logging
3+
import pickle
4+
from abc import ABC, abstractmethod
5+
from typing import Final, Generic, TypeVar
6+
7+
_logger = logging.getLogger(__name__)
8+
9+
10+
T = TypeVar("T")
11+
12+
13+
class BaseObjectSerializer(ABC, Generic[T]):
14+
15+
@classmethod
16+
@abstractmethod
17+
def get_init_kwargs_from_object(cls, obj: T) -> dict:
18+
"""dictionary reppreseting the kwargs passed to the __init__ method"""
19+
20+
@classmethod
21+
@abstractmethod
22+
def prepare_object_init_kwargs(cls, data: dict) -> dict:
23+
"""cleanup data to be used as kwargs for the __init__ method if required"""
24+
25+
26+
_SERIALIZERS: Final[dict[type, type[BaseObjectSerializer]]] = {}
27+
28+
29+
def register_custom_serialization(
30+
object_type: type, object_serializer: type[BaseObjectSerializer]
31+
) -> None:
32+
"""Register a custom serializer for a specific object type.
33+
34+
Arguments:
35+
object_type -- the type or parent class of the object to be serialized
36+
object_serializer -- custom implementation of BaseObjectSerializer for the object type
37+
"""
38+
_SERIALIZERS[object_type] = object_serializer
39+
40+
41+
_TYPE_FIELD: Final[str] = "__pickle__type__field__"
42+
_MODULE_FIELD: Final[str] = "__pickle__module__field__"
43+
44+
45+
def error_to_string(e: Exception) -> str:
46+
"""Serialize exception to base64-encoded string."""
47+
to_serialize: Exception | dict = e
48+
object_class = type(e)
49+
50+
for registered_class, object_serializer in _SERIALIZERS.items():
51+
if issubclass(object_class, registered_class):
52+
to_serialize = {
53+
_TYPE_FIELD: type(e).__name__,
54+
_MODULE_FIELD: type(e).__module__,
55+
**object_serializer.get_init_kwargs_from_object(e),
56+
}
57+
break
58+
59+
return base64.b85encode(pickle.dumps(to_serialize)).decode("utf-8")
60+
61+
62+
def error_from_string(error_str: str) -> Exception:
63+
"""Deserialize exception from base64-encoded string."""
64+
data = pickle.loads(base64.b85decode(error_str)) # noqa: S301
65+
66+
if isinstance(data, dict) and _TYPE_FIELD in data and _MODULE_FIELD in data:
67+
try:
68+
# Import the module and get the exception class
69+
module = __import__(data[_MODULE_FIELD], fromlist=[data[_TYPE_FIELD]])
70+
exception_class = getattr(module, data[_TYPE_FIELD])
71+
72+
for registered_class, object_serializer in _SERIALIZERS.items():
73+
if issubclass(exception_class, registered_class):
74+
# remove unrequired
75+
data.pop(_TYPE_FIELD)
76+
data.pop(_MODULE_FIELD)
77+
78+
return exception_class( # type: ignore[no-any-return]
79+
**object_serializer.prepare_object_init_kwargs(data)
80+
)
81+
except (ImportError, AttributeError, TypeError) as e:
82+
msg = f"Could not reconstruct object from data: {data}"
83+
raise ValueError(msg) from e
84+
85+
return data # type: ignore[no-any-return]

packages/service-library/src/servicelib/long_running_tasks/task.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import asyncio
2-
import base64
32
import datetime
43
import functools
54
import inspect
65
import logging
7-
import pickle
86
import traceback
97
import urllib.parse
108
from contextlib import suppress
@@ -19,6 +17,7 @@
1917
from settings_library.redis import RedisDatabase, RedisSettings
2018

2119
from ..redis import RedisClientSDK, exclusive
20+
from ._error_serialization import error_from_string, error_to_string
2221
from ._store.base import BaseStore
2322
from ._store.redis import RedisStore
2423
from .errors import (
@@ -101,14 +100,6 @@ async def _get_tasks_to_remove(
101100
return tasks_to_remove
102101

103102

104-
def _error_to_string(e: Exception) -> str:
105-
return base64.b85encode(pickle.dumps(e)).decode("utf-8")
106-
107-
108-
def _error_from_string(error_str: str) -> Exception:
109-
return pickle.loads(base64.b85decode(error_str)) # type: ignore[no-any-return] # noqa: S301
110-
111-
112103
class TasksManager: # pylint:disable=too-many-instance-attributes
113104
"""
114105
Monitors execution and results retrieval of a collection of asyncio.Tasks
@@ -278,10 +269,10 @@ async def _status_update_worker(self) -> None:
278269
continue
279270
except asyncio.CancelledError:
280271
task_data.result_field = ResultField(
281-
error=_error_to_string(TaskCancelledError(task_id=task_id))
272+
error=error_to_string(TaskCancelledError(task_id=task_id))
282273
)
283274
except Exception as e: # pylint:disable=broad-except
284-
task_data.result_field = ResultField(error=_error_to_string(e))
275+
task_data.result_field = ResultField(error=error_to_string(e))
285276

286277
await self._tasks_data.set_task_data(task_id, task_data)
287278

@@ -369,7 +360,7 @@ async def get_task_result(
369360
raise TaskNotCompletedError(task_id=task_id)
370361

371362
if tracked_task.result_field.error is not None:
372-
raise _error_from_string(tracked_task.result_field.error)
363+
raise error_from_string(tracked_task.result_field.error)
373364

374365
return tracked_task.result_field.result
375366

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Any
2+
3+
import pytest
4+
from aiohttp.web import HTTPException, HTTPInternalServerError
5+
from servicelib.aiohttp.long_running_tasks._server import AiohttpHTTPExceptionSerializer
6+
from servicelib.long_running_tasks._error_serialization import (
7+
error_from_string,
8+
error_to_string,
9+
register_custom_serialization,
10+
)
11+
12+
register_custom_serialization(HTTPException, AiohttpHTTPExceptionSerializer)
13+
14+
15+
class PositionalArguments:
16+
def __init__(self, arg1, arg2, *args):
17+
self.arg1 = arg1
18+
self.arg2 = arg2
19+
self.args = args
20+
21+
22+
class MixedArguments:
23+
def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None):
24+
self.arg1 = arg1
25+
self.arg2 = arg2
26+
self.kwarg1 = kwarg1
27+
self.kwarg2 = kwarg2
28+
29+
30+
@pytest.mark.parametrize(
31+
"obj",
32+
[
33+
HTTPInternalServerError(reason="Uh-oh!", text="Failure!"),
34+
PositionalArguments("arg1", "arg2", "arg3", "arg4"),
35+
MixedArguments("arg1", "arg2", kwarg1="kwarg1", kwarg2="kwarg2"),
36+
"a_string",
37+
1,
38+
],
39+
)
40+
def test_serialization(obj: Any):
41+
str_data = error_to_string(obj)
42+
43+
reconstructed_obj = error_from_string(str_data)
44+
45+
assert type(reconstructed_obj) is type(obj)
46+
if hasattr(obj, "__dict__"):
47+
assert reconstructed_obj.__dict__ == obj.__dict__

0 commit comments

Comments
 (0)