Skip to content

Commit 664a8f5

Browse files
committed
added support for openapi document computation for nested route mount
1 parent d9ca03f commit 664a8f5

File tree

4 files changed

+64
-23
lines changed

4 files changed

+64
-23
lines changed

ellar/core/routing/mount.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22
import typing as t
3-
import uuid
43

54
from ellar.common import ControllerBase, ModuleRouter
65
from ellar.common.constants import (
@@ -11,6 +10,7 @@
1110
from ellar.common.types import TReceive, TScope, TSend
1211
from ellar.core.router_builders.base import get_controller_builder_factory
1312
from ellar.reflect import reflect
13+
from ellar.utils import get_unique_type
1414
from starlette.middleware import Middleware
1515
from starlette.routing import BaseRoute, Match, Route, Router
1616
from starlette.routing import Mount as StarletteMount
@@ -38,8 +38,7 @@ def __init__(
3838
app.routes = RouteCollection(routes) # type:ignore
3939
super().__init__(path=path, app=app, name=name, middleware=[])
4040
self.include_in_schema = include_in_schema
41-
self._current_found_route_key = f"{uuid.uuid4().hex:4}_EllarMountRoute"
42-
self._control_type = control_type
41+
self._control_type = control_type or get_unique_type("EllarMountDynamicType")
4342

4443
self.user_middleware = [] if middleware is None else list(middleware)
4544
self._middleware_stack: t.Optional[ASGIApp] = None
@@ -50,7 +49,7 @@ def build_middleware_stack(self) -> ASGIApp:
5049
app = cls(app, *args, **kwargs)
5150
return app
5251

53-
def get_control_type(self) -> t.Optional[t.Type[t.Any]]:
52+
def get_control_type(self) -> t.Type[t.Any]:
5453
return self._control_type
5554

5655
def add_route(
@@ -69,9 +68,7 @@ def add_route(
6968

7069
if not isinstance(route, BaseRoute) and self.get_control_type():
7170
reflect.define_metadata(
72-
CONTROLLER_CLASS_KEY,
73-
route,
74-
self.get_control_type(), # type:ignore[arg-type]
71+
CONTROLLER_CLASS_KEY, route, self.get_control_type()
7572
)
7673

7774
self.routes.append(route) # type:ignore[arg-type]
@@ -93,7 +90,7 @@ def matches(self, scope: TScope) -> t.Tuple[Match, TScope]:
9390
match, child_scope = route.matches(scope_copy)
9491
if match == Match.FULL:
9592
_child_scope.update(child_scope)
96-
_child_scope[self._current_found_route_key] = route
93+
_child_scope[str(self.get_control_type())] = route
9794
return Match.FULL, _child_scope
9895
elif (
9996
match == Match.PARTIAL
@@ -105,7 +102,7 @@ def matches(self, scope: TScope) -> t.Tuple[Match, TScope]:
105102
partial_scope.update(child_scope)
106103

107104
if partial:
108-
partial_scope[self._current_found_route_key] = partial
105+
partial_scope[str(self.get_control_type())] = partial
109106
return Match.PARTIAL, partial_scope
110107

111108
return Match.NONE, {}
@@ -114,9 +111,9 @@ async def _app_handler(self, scope: TScope, receive: TReceive, send: TSend) -> N
114111
request_logger.debug(
115112
f"Executing Matched URL Handler, path={scope['path']} - '{self.__class__.__name__}'"
116113
)
117-
route = t.cast(t.Optional[Route], scope.get(self._current_found_route_key))
114+
route = t.cast(t.Optional[Route], scope.get(str(self.get_control_type())))
118115
if route:
119-
del scope[self._current_found_route_key]
116+
del scope[str(self.get_control_type())]
120117
await route.handle(scope, receive, send)
121118
else:
122119
mount_router = t.cast(Router, self.app)

ellar/openapi/builder.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import typing as t
23

34
from ellar.common import IIdentitySchemes
@@ -37,6 +38,16 @@ class DocumentOpenAPIFactory:
3738
def __init__(self, document_dict: t.Dict) -> None:
3839
self._build = document_dict
3940

41+
def _update_openapi_route_models(
42+
self,
43+
route_models: t.List[OpenAPIRoute],
44+
model: OpenAPIMountDocumentation,
45+
openapi_tags: AttributeDict,
46+
) -> None:
47+
route_models.append(model)
48+
if openapi_tags:
49+
self._build.setdefault("tags", []).append(openapi_tags)
50+
4051
def _get_openapi_route_document_models(self, app: "App") -> t.List[OpenAPIRoute]:
4152
openapi_route_models: t.List = []
4253
reflector = app.reflector
@@ -65,6 +76,9 @@ def _get_openapi_route_document_models(self, app: "App") -> t.List[OpenAPIRoute]
6576
mount=route,
6677
global_guards=guards or app_guards,
6778
name=openapi_tags.name,
79+
global_route_models_update=functools.partial(
80+
self._update_openapi_route_models, openapi_route_models
81+
),
6882
)
6983
)
7084
if openapi_tags:

ellar/openapi/route_doc_models.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import typing as t
22
from abc import ABC, abstractmethod
33

4-
from ellar.common.compatible import cached_property
4+
from ellar.common.compatible import AttributeDict, cached_property
55
from ellar.common.constants import (
66
GUARDS_KEY,
77
METHODS_WITH_BODY,
@@ -19,7 +19,11 @@
1919
from ellar.common.shortcuts import normalize_path
2020
from ellar.core.routing import EllarMount, RouteOperation
2121
from ellar.core.services.reflector import reflector
22-
from ellar.openapi.constants import IGNORE_CONTROLLER_TYPE, OPENAPI_OPERATION_KEY
22+
from ellar.openapi.constants import (
23+
IGNORE_CONTROLLER_TYPE,
24+
OPENAPI_OPERATION_KEY,
25+
OPENAPI_TAG,
26+
)
2327
from ellar.pydantic import (
2428
JsonSchemaValue,
2529
ModelField,
@@ -55,27 +59,33 @@ class OpenAPIMountDocumentation(OpenAPIRoute):
5559
def __init__(
5660
self,
5761
mount: t.Union[EllarMount, Mount],
62+
global_route_models_update: t.Callable[
63+
["OpenAPIMountDocumentation", t.Dict], t.Any
64+
],
5865
name: t.Optional[str] = None,
5966
global_guards: t.Optional[
6067
t.List[t.Union[t.Type["GuardCanActivate"], "GuardCanActivate"]]
6168
] = None,
69+
path_prefix: t.Optional[str] = None,
6270
) -> None:
6371
self.tag = name
6472
self.mount = mount
6573
self.path_regex, self.path_format, self.param_convertors = compile_path(
66-
self.mount.path
74+
normalize_path(f"/{path_prefix}/{mount.path}")
75+
if path_prefix
76+
else mount.path
6777
)
6878
# if there is some convertor on ModuleMount Object, then we need to convert it to ModelField
6979
self.global_route_parameters: t.List[ModelField] = [
7080
EndpointArgsModel.get_convertor_model_field(name, convertor)
7181
for name, convertor in self.param_convertors.items()
7282
]
7383
self.global_guards = global_guards or []
74-
75-
self._routes: t.List["OpenAPIRouteDocumentation"] = self._build_routes()
84+
self._global_route_models_update = global_route_models_update
85+
self.routes: t.List["OpenAPIRouteDocumentation"] = self._build_routes()
7686

7787
def _build_routes(self) -> t.List["OpenAPIRouteDocumentation"]:
78-
_routes: t.List[OpenAPIRouteDocumentation] = []
88+
routes: t.List[OpenAPIRouteDocumentation] = []
7989

8090
for route in self.mount.routes:
8191
if isinstance(route, RouteOperation) and route.include_in_schema:
@@ -85,20 +95,40 @@ def _build_routes(self) -> t.List["OpenAPIRouteDocumentation"]:
8595
if not openapi.get("tags", False):
8696
openapi.update(tags=[self.tag] if self.tag else ["default"])
8797

88-
_routes.append(
98+
routes.append(
8999
OpenAPIRouteDocumentation(
90100
route=route,
91101
global_route_parameters=self.global_route_parameters,
92102
guards=guards or self.global_guards,
93103
**openapi,
94104
)
95105
)
96-
return _routes
106+
elif isinstance(route, EllarMount):
107+
openapi_tags = AttributeDict(
108+
reflector.get(OPENAPI_TAG, route.get_control_type()) or {}
109+
)
110+
111+
if route.name:
112+
openapi_tags.setdefault("name", route.name)
113+
114+
guards = reflector.get(GUARDS_KEY, route.get_control_type())
115+
116+
self._global_route_models_update(
117+
OpenAPIMountDocumentation(
118+
mount=route,
119+
global_guards=guards or self.global_guards,
120+
name=openapi_tags.name,
121+
global_route_models_update=self._global_route_models_update,
122+
path_prefix=self.path_format,
123+
),
124+
openapi_tags,
125+
)
126+
return routes
97127

98128
@cached_property
99129
def _openapi_models(self) -> t.List[ModelField]:
100130
_models = []
101-
for route in self._routes:
131+
for route in self.routes:
102132
_models.extend(route.get_route_models())
103133
return _models
104134

@@ -121,7 +151,7 @@ def get_openapi_path(
121151
if path_prefix
122152
else self.path_format
123153
)
124-
for openapi_route in self._routes:
154+
for openapi_route in self.routes:
125155
openapi_route.get_openapi_path(
126156
paths=paths,
127157
security_schemes=security_schemes,

ellar/utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def generate_controller_operation_unique_id(
4141
return hash(path + _methods + _versioning + extra_string)
4242

4343

44-
def get_unique_type() -> t.Type:
45-
return type(f"DynamicType{uuid.uuid4().hex[:6]}", (), {})
44+
def get_unique_type(prefix: str = "DynamicType") -> t.Type:
45+
return type(f"{prefix}{uuid.uuid4().hex[:6]}", (), {})
4646

4747

4848
def get_name(endpoint: t.Union[t.Callable, t.Type, object]) -> str:

0 commit comments

Comments
 (0)