Skip to content

Commit aab73fd

Browse files
committed
added support for gateway controller inheritance
1 parent f50cffc commit aab73fd

File tree

6 files changed

+57
-107
lines changed

6 files changed

+57
-107
lines changed

ellar/socket_io/decorators/gateway.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
GATEWAY_WATERMARK,
1111
)
1212
from ellar.socket_io.model import GatewayBase, GatewayType
13-
from ellar.utils import get_name, get_type_of_base
13+
from ellar.utils import get_name
1414

1515

1616
def WebSocketGateway(
@@ -51,27 +51,15 @@ def _decorator(cls: t.Type) -> t.Type:
5151
str(get_name(_gateway_type)).lower().replace("gateway", "")
5252
)
5353

54-
for base in get_type_of_base(GatewayBase, _gateway_type):
55-
if reflect.has_metadata(GATEWAY_WATERMARK, base) and hasattr(
56-
_gateway_type, "__GATEWAY_WATERMARK__"
57-
):
58-
raise ImproperConfiguration(
59-
f"`@WebSocketGateway` decorated classes does not support inheritance. \n"
60-
f"{_gateway_type}"
61-
)
62-
63-
if not reflect.has_metadata(GATEWAY_WATERMARK, _gateway_type) and not hasattr(
64-
_gateway_type, "__GATEWAY_WATERMARK__"
65-
):
66-
reflect.define_metadata(GATEWAY_WATERMARK, True, _gateway_type)
67-
reflect.define_metadata(
68-
GATEWAY_OPTIONS, _kwargs["socket_init_kwargs"], _gateway_type
69-
)
54+
reflect.define_metadata(GATEWAY_WATERMARK, True, _gateway_type)
55+
reflect.define_metadata(
56+
GATEWAY_OPTIONS, _kwargs["socket_init_kwargs"], _gateway_type
57+
)
7058

71-
injectable(RequestORTransientScope)(_gateway_type)
59+
injectable(RequestORTransientScope)(_gateway_type)
7260

73-
for key in GATEWAY_METADATA.keys:
74-
reflect.define_metadata(key, _kwargs[key], _gateway_type)
61+
for key in GATEWAY_METADATA.keys:
62+
reflect.define_metadata(key, _kwargs[key], _gateway_type)
7563

7664
if new_cls:
7765
_gateway_type.__GATEWAY_WATERMARK__ = True # type:ignore[attr-defined]

ellar/socket_io/factory.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import inspect
22
import typing as t
3-
from abc import ABC
43

54
import socketio
6-
from ellar.common.constants import CONTROLLER_CLASS_KEY
75
from ellar.core.router_builders import RouterBuilder
86
from ellar.reflect import reflect
97
from ellar.socket_io.adapter import SocketIOASGIApp
@@ -30,7 +28,9 @@ def _get_message_handler(
3028
cls,
3129
klass: t.Type,
3230
) -> t.Iterable[t.Union[t.Callable]]:
33-
for method in klass.__dict__.values():
31+
for _method_name, method in inspect.getmembers(
32+
klass, predicate=inspect.isfunction
33+
):
3434
if hasattr(method, MESSAGE_MAPPING_METADATA) and getattr(
3535
method, MESSAGE_MAPPING_METADATA
3636
):
@@ -40,28 +40,16 @@ def _get_message_handler(
4040
def _process_controller_routes(
4141
cls, klass: t.Type[GatewayBase]
4242
) -> t.List[t.Callable]:
43-
bases = inspect.getmro(klass)
43+
# bases = inspect.getmro(klass)
4444
results = []
4545

4646
if reflect.get_metadata(GATEWAY_METADATA.PROCESSED, klass):
4747
return reflect.get_metadata(GATEWAY_MESSAGE_HANDLER_KEY, klass) or []
4848

49-
for base_cls in reversed(bases):
50-
if base_cls not in [ABC, GatewayBase, object]:
51-
for method in cls._get_message_handler(base_cls):
52-
if reflect.has_metadata(CONTROLLER_CLASS_KEY, method):
53-
raise Exception(
54-
f"{klass.__name__} Gateway message handler tried to be processed more than once."
55-
f"\n-Message Handler - {method}."
56-
f"\n-Gateway message handler can not be reused once its under a `@Gateway` decorator."
57-
)
49+
for method in cls._get_message_handler(klass):
50+
results.append(method)
5851

59-
results.append(method)
60-
61-
reflect.define_metadata(CONTROLLER_CLASS_KEY, klass, method)
62-
reflect.define_metadata(
63-
GATEWAY_MESSAGE_HANDLER_KEY, [method], klass
64-
)
52+
reflect.define_metadata(GATEWAY_MESSAGE_HANDLER_KEY, [method], klass)
6553

6654
reflect.define_metadata(GATEWAY_METADATA.PROCESSED, True, klass)
6755
return results
@@ -91,14 +79,18 @@ def build(
9179
is_disconnection_handler = reflect.get_metadata(DISCONNECT_EVENT, handler)
9280

9381
if is_connection_handler:
94-
SocketOperationConnection(CONNECTION_EVENT, socket_server, handler)
82+
SocketOperationConnection(
83+
controller_type, CONNECTION_EVENT, socket_server, handler
84+
)
9585
elif is_disconnection_handler:
96-
SocketOperationConnection(DISCONNECT_EVENT, socket_server, handler)
86+
SocketOperationConnection(
87+
controller_type, DISCONNECT_EVENT, socket_server, handler
88+
)
9789
else:
9890
message = reflect.get_metadata_or_raise_exception(
9991
MESSAGE_METADATA, handler
10092
)
101-
SocketMessageOperation(message, socket_server, handler)
93+
SocketMessageOperation(controller_type, message, socket_server, handler)
10294

10395
return Mount(app=SocketIOASGIApp(socket_server), path=path, name=name)
10496

ellar/socket_io/gateway.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
serialize_object,
99
)
1010
from ellar.common.constants import (
11-
CONTROLLER_CLASS_KEY,
1211
CONTROLLER_OPERATION_HANDLER_KEY,
1312
EXTRA_ROUTE_ARGS_KEY,
1413
NOT_SET,
@@ -49,17 +48,19 @@ class SocketOperationConnection:
4948
)
5049

5150
def __init__(
52-
self, event: str, server: AsyncServer, message_handler: t.Callable
51+
self,
52+
controller_type: t.Type[GatewayBase],
53+
event: str,
54+
server: AsyncServer,
55+
message_handler: t.Callable,
5356
) -> None:
5457
self._event = event
5558
self._server = server
5659
self.endpoint = message_handler
5760
self._name = get_name(self.endpoint)
5861
self._is_coroutine = inspect.iscoroutinefunction(message_handler)
5962
self.endpoint_parameter_model = NOT_SET
60-
self._controller_type: t.Type[GatewayBase] = reflect.get_metadata( # type: ignore[assignment]
61-
CONTROLLER_CLASS_KEY, self.endpoint
62-
)
63+
self._controller_type = controller_type
6364
self._load_model()
6465
self._register_handler()
6566
reflect.define_metadata(CONTROLLER_OPERATION_HANDLER_KEY, self, self.endpoint)
@@ -192,12 +193,6 @@ def get_controller_type(self) -> t.Type[GatewayBase]:
192193
For operation under ModuleRouter, this will return a unique type created for the router for tracking some properties
193194
:return: a type that wraps the operation
194195
"""
195-
if not self._controller_type:
196-
_controller_type = reflect.get_metadata(CONTROLLER_CLASS_KEY, self.endpoint)
197-
if _controller_type is None or not isinstance(_controller_type, type):
198-
raise Exception("Operation must have a single control type.")
199-
self._controller_type = t.cast(t.Type[GatewayBase], _controller_type)
200-
201196
return self._controller_type
202197

203198
def _get_gateway_instance(self, ctx: IExecutionContext) -> GatewayBase:

ellar/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def build_init_kwargs(obj: t.Type, init_kwargs: t.Dict) -> t.Dict:
7676

7777
def get_type_of_base(
7878
base_type: t.Type[t.Any], reference_type: t.Type[t.Any]
79-
) -> t.Iterable[t.Type[t.Any]]:
79+
) -> t.Iterable[t.Type[t.Any]]: # pragma: no cover
8080
for base in inspect.getmro(reference_type):
8181
if issubclass(base, base_type):
8282
yield base

tests/test_socket_io/test_decorators/test_gateway.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import pytest
22
from ellar.auth.guards import GuardHttpBearerAuth
33
from ellar.common import UseGuards
4-
from ellar.common.constants import CONTROLLER_CLASS_KEY, GUARDS_KEY
5-
from ellar.common.exceptions import ImproperConfiguration
4+
from ellar.common.constants import CONTROLLER_OPERATION_HANDLER_KEY, GUARDS_KEY
65
from ellar.reflect import reflect
76
from ellar.socket_io import GatewayRouterFactory, WebSocketGateway, subscribe_message
87
from ellar.socket_io.constants import (
@@ -97,28 +96,27 @@ def a_message(self):
9796
assert message == "a_message"
9897
assert (
9998
reflect.get_metadata_or_raise_exception(
100-
CONTROLLER_CLASS_KEY, SampleAGateway().a_message
101-
)
99+
CONTROLLER_OPERATION_HANDLER_KEY, SampleAGateway().a_message
100+
).get_controller_type()
102101
is SampleAGateway
103102
)
104103

105104

106-
def test_sub_message_building_fails():
107-
with pytest.raises(Exception) as ex:
108-
109-
@WebSocketGateway(path="/ws", namespace="/some-namespace")
110-
class SampleBGateway(GatewayBase):
111-
@subscribe_message
112-
@reflect.metadata(CONTROLLER_CLASS_KEY, "b_message")
113-
def b_message(self):
114-
pass
115-
116-
GatewayRouterFactory.build(SampleBGateway)
117-
118-
assert (
119-
"SampleBGateway Gateway message handler tried to be processed more than once"
120-
in str(ex.value)
121-
)
105+
# def test_sub_message_building_fails():
106+
# with pytest.raises(Exception) as ex:
107+
#
108+
# @WebSocketGateway(path="/ws", namespace="/some-namespace")
109+
# class SampleBGateway(GatewayBase):
110+
# @subscribe_message
111+
# def b_message(self):
112+
# pass
113+
#
114+
# GatewayRouterFactory.build(SampleBGateway)
115+
#
116+
# assert (
117+
# "SampleBGateway Gateway message handler tried to be processed more than once"
118+
# in str(ex.value)
119+
# )
122120

123121

124122
def test_cant_use_gateway_decorator_on_function():
@@ -131,13 +129,16 @@ def sample_c_gateway():
131129
assert "WebSocketGateway is a class decorator" in str(ex.value)
132130

133131

134-
def test_inheritance_fails():
135-
with pytest.raises(ImproperConfiguration) as ex:
136-
137-
@WebSocketGateway(path="/ws", namespace="/some-namespace")
138-
class InheritanceNotSupported(SampleWithoutGateway):
132+
def test_inheritance_works():
133+
@WebSocketGateway(path="/ws", namespace="/some-namespace")
134+
class InheritanceSupported(SampleWithoutGateway):
135+
@subscribe_message
136+
def b_message(self):
139137
pass
140138

141-
assert "`@WebSocketGateway` decorated classes does not support inheritance." in str(
142-
ex.value
139+
GatewayRouterFactory.build(InheritanceSupported)
140+
141+
message_handlers = reflect.get_metadata_or_raise_exception(
142+
GATEWAY_MESSAGE_HANDLER_KEY, InheritanceSupported
143143
)
144+
assert len(message_handlers) == 1

tests/test_socket_io/test_gateway.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments

Comments
 (0)