11import functools
22import typing as t
3- from functools import partial
43from types import FunctionType
54
65from ellar .common .constants import (
1514 ROUTE_OPERATION_PARAMETERS ,
1615 TRACE ,
1716)
17+ from ellar .common .interfaces .operation import IWebSocketConnectionAttributes
1818from ellar .reflect import ensure_target
1919
2020from .schema import RouteParameters , WsRouteParameters
2121
2222
23- class _WebSocketConnectionAttributes :
24- """
25- This class is used to add connection attributes to a websocket handler.
26- """
27-
28- __slots__ = ("_original_func" , "_func" )
29-
30- def __init__ (self , func : t .Callable ) -> None :
31- self ._original_func = func
32- self ._func : t .Optional [t .Callable ] = None
33-
34- def __call__ (
35- self ,
36- path : str = "/" ,
37- * ,
38- name : t .Optional [str ] = None ,
39- encoding : t .Optional [str ] = "json" ,
40- use_extra_handler : bool = False ,
41- extra_handler_type : t .Optional [t .Type ] = None ,
42- ) -> t .Callable :
43- if self ._func is None : # pragma: no cover
44- raise Exception ("Something went wrong" )
45-
46- res = self ._func (
47- path = path ,
48- name = name ,
49- encoding = encoding ,
50- use_extra_handler = use_extra_handler ,
51- extra_handler_type = extra_handler_type ,
52- )
53- return t .cast (t .Callable , res )
54-
55- def __get__ (self , instance : t .Any , owner : t .Any ) -> t .Callable :
56- self ._func = functools .partial (self ._original_func , instance )
57- return self
58-
59- @classmethod
23+ def _websocket_connection_attributes (
24+ func : t .Callable ,
25+ ) -> IWebSocketConnectionAttributes :
6026 def _advance_function (
61- cls , websocket_handler : t .Callable , handler_name : str
27+ websocket_handler : t .Callable , handler_name : str
6228 ) -> t .Callable :
63- """
64- This method is used to register the connection attributes to a websocket handler.
65- """
66-
6729 def _wrap (connect_handler : t .Callable ) -> t .Callable :
6830 if not (
6931 callable (websocket_handler ) and type (websocket_handler ) is FunctionType
@@ -89,17 +51,10 @@ def _wrap(connect_handler: t.Callable) -> t.Callable:
8951
9052 return _wrap
9153
92- def connect (self , f : t .Callable ) -> t .Callable :
93- """
94- This method is used to register the connect handler to a websocket handler.
95- """
96- return self ._advance_function (f , "on_connect" )
54+ func .connect = functools .partial (_advance_function , handler_name = "on_connect" ) # type: ignore[attr-defined]
55+ func .disconnect = functools .partial (_advance_function , handler_name = "on_disconnect" ) # type: ignore[attr-defined]
9756
98- def disconnect (self , f : t .Callable ) -> t .Callable :
99- """
100- This method is used to register the disconnect handler to a websocket handler.
101- """
102- return self ._advance_function (f , "on_disconnect" )
57+ return t .cast (IWebSocketConnectionAttributes , func )
10358
10459
10560class OperationDefinitions :
@@ -182,7 +137,7 @@ def get(
182137 ] = None ,
183138 ) -> t .Callable :
184139 methods = [GET ]
185- endpoint_parameter_partial = partial (
140+ endpoint_parameter_partial = functools . partial (
186141 RouteParameters ,
187142 name = name ,
188143 methods = methods ,
@@ -202,7 +157,7 @@ def post(
202157 ] = None ,
203158 ) -> t .Callable :
204159 methods = [POST ]
205- endpoint_parameter_partial = partial (
160+ endpoint_parameter_partial = functools . partial (
206161 RouteParameters ,
207162 name = name ,
208163 methods = methods ,
@@ -222,7 +177,7 @@ def put(
222177 ] = None ,
223178 ) -> t .Callable :
224179 methods = [PUT ]
225- endpoint_parameter_partial = partial (
180+ endpoint_parameter_partial = functools . partial (
226181 RouteParameters ,
227182 name = name ,
228183 methods = methods ,
@@ -242,7 +197,7 @@ def patch(
242197 ] = None ,
243198 ) -> t .Callable :
244199 methods = [PATCH ]
245- endpoint_parameter_partial = partial (
200+ endpoint_parameter_partial = functools . partial (
246201 RouteParameters ,
247202 name = name ,
248203 methods = methods ,
@@ -262,7 +217,7 @@ def delete(
262217 ] = None ,
263218 ) -> t .Callable :
264219 methods = [DELETE ]
265- endpoint_parameter_partial = partial (
220+ endpoint_parameter_partial = functools . partial (
266221 RouteParameters ,
267222 name = name ,
268223 methods = methods ,
@@ -282,7 +237,7 @@ def head(
282237 ] = None ,
283238 ) -> t .Callable :
284239 methods = [HEAD ]
285- endpoint_parameter_partial = partial (
240+ endpoint_parameter_partial = functools . partial (
286241 RouteParameters ,
287242 name = name ,
288243 methods = methods ,
@@ -302,7 +257,7 @@ def options(
302257 ] = None ,
303258 ) -> t .Callable :
304259 methods = [OPTIONS ]
305- endpoint_parameter_partial = partial (
260+ endpoint_parameter_partial = functools . partial (
306261 RouteParameters ,
307262 name = name ,
308263 methods = methods ,
@@ -322,7 +277,7 @@ def trace(
322277 ] = None ,
323278 ) -> t .Callable :
324279 methods = [TRACE ]
325- endpoint_parameter_partial = partial (
280+ endpoint_parameter_partial = functools . partial (
326281 RouteParameters ,
327282 name = name ,
328283 methods = methods ,
@@ -356,7 +311,7 @@ def _decorator(endpoint_handler: t.Callable) -> t.Callable:
356311
357312 return _decorator
358313
359- @_WebSocketConnectionAttributes
314+ @_websocket_connection_attributes
360315 def ws_route (
361316 self ,
362317 path : str = "/" ,
0 commit comments