Skip to content

Commit 57f0aed

Browse files
committed
fixed failing test
1 parent aff87c6 commit 57f0aed

File tree

4 files changed

+45
-63
lines changed

4 files changed

+45
-63
lines changed

ellar/common/interfaces/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .interceptor_consumer import IInterceptorsConsumer
1616
from .middleware import IEllarMiddleware
1717
from .module import IModuleSetup
18+
from .operation import IWebSocketConnectionAttributes
1819
from .response_model import IResponseModel
1920
from .templating import IModuleTemplateLoader, ITemplateRenderingService
2021
from .versioning import IAPIVersioning, IAPIVersioningResolver
@@ -43,4 +44,5 @@
4344
"IIdentitySchemes",
4445
"IApplicationReady",
4546
"ITemplateRenderingService",
47+
"IWebSocketConnectionAttributes",
4648
]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import typing as t
2+
3+
4+
class IWebSocketConnectionAttributes(t.Protocol):
5+
"""
6+
Interface for WebSocket connection attributes.
7+
"""
8+
9+
def connect(self, websocket_handler: t.Callable) -> t.Callable:
10+
"""
11+
Register the connect handler to a websocket handler.
12+
13+
:param websocket_handler: The websocket handler to register the connect handler to.
14+
:return: The connect handler.
15+
"""
16+
17+
def disconnect(self, websocket_handler: t.Callable) -> t.Callable:
18+
"""
19+
Register the disconnect handler to a websocket handler.
20+
21+
:param websocket_handler: The websocket handler to register the disconnect handler to.
22+
:return: The disconnect handler.
23+
"""

ellar/common/operations/base.py

Lines changed: 17 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22
import typing as t
3-
from functools import partial
43
from types import FunctionType
54

65
from ellar.common.constants import (
@@ -15,55 +14,18 @@
1514
ROUTE_OPERATION_PARAMETERS,
1615
TRACE,
1716
)
17+
from ellar.common.interfaces.operation import IWebSocketConnectionAttributes
1818
from ellar.reflect import ensure_target
1919

2020
from .schema import RouteParameters, WsRouteParameters
2121

2222

23-
class _WebSocketConnectionAttributes:
24-
"""
25-
This class is used to add connection attributes to a websocket handler.
26-
"""
27-
28-
__slots__ = ("_original_func", "_func")
29-
30-
def __init__(self, func: t.Callable) -> None:
31-
self._original_func = func
32-
self._func: t.Optional[t.Callable] = None
33-
34-
def __call__(
35-
self,
36-
path: str = "/",
37-
*,
38-
name: t.Optional[str] = None,
39-
encoding: t.Optional[str] = "json",
40-
use_extra_handler: bool = False,
41-
extra_handler_type: t.Optional[t.Type] = None,
42-
) -> t.Callable:
43-
if self._func is None: # pragma: no cover
44-
raise Exception("Something went wrong")
45-
46-
res = self._func(
47-
path=path,
48-
name=name,
49-
encoding=encoding,
50-
use_extra_handler=use_extra_handler,
51-
extra_handler_type=extra_handler_type,
52-
)
53-
return t.cast(t.Callable, res)
54-
55-
def __get__(self, instance: t.Any, owner: t.Any) -> t.Callable:
56-
self._func = functools.partial(self._original_func, instance)
57-
return self
58-
59-
@classmethod
23+
def _websocket_connection_attributes(
24+
func: t.Callable,
25+
) -> IWebSocketConnectionAttributes:
6026
def _advance_function(
61-
cls, websocket_handler: t.Callable, handler_name: str
27+
websocket_handler: t.Callable, handler_name: str
6228
) -> t.Callable:
63-
"""
64-
This method is used to register the connection attributes to a websocket handler.
65-
"""
66-
6729
def _wrap(connect_handler: t.Callable) -> t.Callable:
6830
if not (
6931
callable(websocket_handler) and type(websocket_handler) is FunctionType
@@ -89,17 +51,10 @@ def _wrap(connect_handler: t.Callable) -> t.Callable:
8951

9052
return _wrap
9153

92-
def connect(self, f: t.Callable) -> t.Callable:
93-
"""
94-
This method is used to register the connect handler to a websocket handler.
95-
"""
96-
return self._advance_function(f, "on_connect")
54+
func.connect = functools.partial(_advance_function, handler_name="on_connect") # type: ignore[attr-defined]
55+
func.disconnect = functools.partial(_advance_function, handler_name="on_disconnect") # type: ignore[attr-defined]
9756

98-
def disconnect(self, f: t.Callable) -> t.Callable:
99-
"""
100-
This method is used to register the disconnect handler to a websocket handler.
101-
"""
102-
return self._advance_function(f, "on_disconnect")
57+
return t.cast(IWebSocketConnectionAttributes, func)
10358

10459

10560
class OperationDefinitions:
@@ -182,7 +137,7 @@ def get(
182137
] = None,
183138
) -> t.Callable:
184139
methods = [GET]
185-
endpoint_parameter_partial = partial(
140+
endpoint_parameter_partial = functools.partial(
186141
RouteParameters,
187142
name=name,
188143
methods=methods,
@@ -202,7 +157,7 @@ def post(
202157
] = None,
203158
) -> t.Callable:
204159
methods = [POST]
205-
endpoint_parameter_partial = partial(
160+
endpoint_parameter_partial = functools.partial(
206161
RouteParameters,
207162
name=name,
208163
methods=methods,
@@ -222,7 +177,7 @@ def put(
222177
] = None,
223178
) -> t.Callable:
224179
methods = [PUT]
225-
endpoint_parameter_partial = partial(
180+
endpoint_parameter_partial = functools.partial(
226181
RouteParameters,
227182
name=name,
228183
methods=methods,
@@ -242,7 +197,7 @@ def patch(
242197
] = None,
243198
) -> t.Callable:
244199
methods = [PATCH]
245-
endpoint_parameter_partial = partial(
200+
endpoint_parameter_partial = functools.partial(
246201
RouteParameters,
247202
name=name,
248203
methods=methods,
@@ -262,7 +217,7 @@ def delete(
262217
] = None,
263218
) -> t.Callable:
264219
methods = [DELETE]
265-
endpoint_parameter_partial = partial(
220+
endpoint_parameter_partial = functools.partial(
266221
RouteParameters,
267222
name=name,
268223
methods=methods,
@@ -282,7 +237,7 @@ def head(
282237
] = None,
283238
) -> t.Callable:
284239
methods = [HEAD]
285-
endpoint_parameter_partial = partial(
240+
endpoint_parameter_partial = functools.partial(
286241
RouteParameters,
287242
name=name,
288243
methods=methods,
@@ -302,7 +257,7 @@ def options(
302257
] = None,
303258
) -> t.Callable:
304259
methods = [OPTIONS]
305-
endpoint_parameter_partial = partial(
260+
endpoint_parameter_partial = functools.partial(
306261
RouteParameters,
307262
name=name,
308263
methods=methods,
@@ -322,7 +277,7 @@ def trace(
322277
] = None,
323278
) -> t.Callable:
324279
methods = [TRACE]
325-
endpoint_parameter_partial = partial(
280+
endpoint_parameter_partial = functools.partial(
326281
RouteParameters,
327282
name=name,
328283
methods=methods,
@@ -356,7 +311,7 @@ def _decorator(endpoint_handler: t.Callable) -> t.Callable:
356311

357312
return _decorator
358313

359-
@_WebSocketConnectionAttributes
314+
@_websocket_connection_attributes
360315
def ws_route(
361316
self,
362317
path: str = "/",

ellar/core/router_builders/controller.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def process_controller_routes(controller: t.Type[ControllerBase]) -> t.List[Base
3232
for _, item in get_functions_with_tag(
3333
controller, tag=constants.OPERATION_ENDPOINT_KEY
3434
):
35-
parameters = item.__dict__[constants.ROUTE_OPERATION_PARAMETERS]
35+
parameters = item.__dict__.get(constants.ROUTE_OPERATION_PARAMETERS)
36+
if parameters is None:
37+
print("Something is not right")
3638
operation: t.Union[ControllerRouteOperation, ControllerWebsocketRouteOperation]
3739

3840
if not isinstance(parameters, list):

0 commit comments

Comments
 (0)