-
Notifications
You must be signed in to change notification settings - Fork 79
Expand file tree
/
Copy pathmain.py
More file actions
268 lines (231 loc) · 10.2 KB
/
main.py
File metadata and controls
268 lines (231 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
"""
Defines Rest API endpoints.
Note: order matters for overloaded paths
(https://fastapi.tiangolo.com/tutorial/path-params/#order-matters).
"""
import uuid
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import Request
class CorrelationIdMiddleware(BaseHTTPMiddleware):
    """Tag every request/response cycle with a freshly generated correlation id.

    The id is a random UUID4 rendered as a string. It is stored on
    ``request.state.correlation_id`` so downstream code can include it in log
    messages. It is also echoed back to the client in the ``X-Correlation-ID``
    response header, so a request can be matched with its log entries.
    """

    async def dispatch(self, request: Request, call_next):
        # A new id is generated for every request; an incoming
        # X-Correlation-ID header (if any) is not reused.
        request.state.correlation_id = str(uuid.uuid4())

        # Hand the request on to the rest of the middleware stack / endpoint.
        response = await call_next(request)

        # Stamp the id on the outgoing response. The "not-started" fallback is
        # only used if the attribute is no longer present on request.state.
        stamped_id = getattr(request.state, "correlation_id", "not-started")
        response.headers["X-Correlation-ID"] = stamped_id
        return response
import argparse
import logging
from pathlib import Path
from importlib.metadata import version as pkg_version, PackageNotFoundError
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from starlette.exceptions import HTTPException as StarletteHTTPException
from fastapi.exceptions import HTTPException as FastAPIHTTPException
from error_handling.error_handling import http_exception_handler
from fastapi.responses import HTMLResponse
from sqlmodel import select, SQLModel
from starlette.requests import Request
from authentication import get_user_or_raise, KeycloakUser, assert_required_settings_configured
from config import KEYCLOAK_CONFIG, DB_CONFIG, DEV_CONFIG
from database.deletion.triggers import (
create_delete_triggers,
create_identifier_synchronization_triggers,
)
import database.authorization # noqa # Trigger registration of User, Permission -> likely obsolete when couple with aiod_entry is done
from database.model.concept.concept import AIoDConcept
from database.model.platform.platform import Platform
from database.model.platform.platform_names import PlatformName
from database.session import EngineSingleton, DbSession
from database.setup import create_database, database_exists
from routers.resource_routers import versioned_routers
from setup_logger import setup_logger
from taxonomies.synchronize_taxonomy import synchronize_taxonomy_from_file
from triggers import disable_review_process, enable_review_process
from error_handling.error_handling import http_exception_handler
from routers import (
resource_routers,
parent_routers,
enum_routers,
search_routers,
review_router,
user_router,
bookmark_router,
asset_router,
)
from prometheus_fastapi_instrumentator import Instrumentator
from middleware.access_log import AccessLogMiddleware
from routers.access_stats_router import create as create_access_stats_router
from versioning import (
versions,
add_version_to_openapi,
add_deprecation_and_sunset_middleware,
Version,
)
import logging
import sys
# Configure the root logger at import time: records of level INFO and above go
# to stdout, formatted as "LEVEL: message". force=True first removes any
# handlers already attached to the root logger, so this configuration replaces
# whatever was set up before this module was imported.
# NOTE(review): create_app() later calls setup_logger(); confirm which of the
# two logging configurations is meant to be in effect for the running service.
logging.basicConfig(
    force=True,
    stream=sys.stdout,
    format="%(levelname)s: %(message)s",
    level=logging.INFO,
)
def add_routes(app: FastAPI, version: Version, url_prefix=""):
    """Add routes to the FastAPI application.

    Registers, for the given API ``version``:
    a landing page at ``/`` that redirects to the docs, ``/authorization_test``,
    ``/counts``, the version-specific resource routers, the shared routers and
    the access-stats router.

    :param app: the FastAPI application to add the routes to.
    :param version: the API version whose routers are registered.
    :param url_prefix: prefix handed to every router's ``create`` call.
    """
    import html  # only needed to escape the landing-page URL

    @app.get("/", include_in_schema=False, response_class=HTMLResponse)
    def home(request: Request) -> str:
        """Provides a redirect page to the docs."""
        # x-forwarded-prefix is a request header and can therefore be supplied
        # by the client; escape it before placing it inside HTML attributes so
        # it cannot inject markup or script.
        # NOTE(review): escaping does not restrict the URL scheme; ideally the
        # trusted reverse proxy always sets/overwrites this header.
        proxy_prefix = request.headers.get("x-forwarded-prefix", "")
        prefix = html.escape(proxy_prefix + version.prefix, quote=True)
        return f"""
            <!DOCTYPE html>
            <html>
            <head>
                <meta http-equiv="refresh" content="0; url='{prefix}/docs'" />
            </head>
            <body>
                <p>The REST API documentation is <a href="{prefix}/docs">here</a>.</p>
            </body>
            </html>
            """

    @app.get("/authorization_test")
    def test_authorization(user: KeycloakUser = Depends(get_user_or_raise)) -> KeycloakUser:  # noqa: B008
        """
        Returns the user, if authenticated correctly.
        """
        return user

    @app.get("/counts")
    def counts() -> dict:
        """Return the detailed count per AIoD-concept resource type.

        Resource types whose count is falsy (e.g. an empty result) are omitted.
        """
        return {
            router.resource_name_plural: count
            for router in resource_routers.versioned_routers.get(version, [])
            if issubclass(router.resource_class, AIoDConcept)
            and (count := router.get_resource_count_func()(detailed=True))
        }

    # Version-specific resource routers.
    for router in versioned_routers.get(version, []):
        app.include_router(router.create(url_prefix, version))

    # Routers shared by every version.
    for router in (
        parent_routers.router_list
        + enum_routers.router_list
        + search_routers.router_list
        + [review_router, user_router, bookmark_router, asset_router]
        + resource_routers.router_list
    ):
        app.include_router(router.create(url_prefix, version))

    app.include_router(create_access_stats_router(url_prefix))
def create_app() -> FastAPI:
    """Create the FastAPI application, complete with routes.

    Before building the app this configures logging, checks that required
    settings are present, and handles the database according to
    ``DB_CONFIG["build_database"]`` (default ``"never"``). If
    ``DEV_CONFIG["taxonomy"]`` is set, it also synchronizes taxonomies from
    that file.

    :returns: the application built by ``build_app``.
    :raises ValueError: if ``dev.taxonomy`` is set but is not a path to a file.
    """
    setup_logger()
    assert_required_settings_configured()

    build_database_setting = DB_CONFIG.get("build_database", "never")
    if build_database_setting == "never":
        if not database_exists():
            logging.warning(
                "AI-on-Demand database does not exist on the MySQL server, "
                "but `build_database` is set to 'never'. If you are not creating the "
                "database through other means, such as MySQL group replication, "
                "this likely means that you will get errors or undefined behavior."
            )
    else:
        # Every value other than "never" builds the database; only
        # "drop-then-build" asks for the existing database to be deleted first.
        drop_database = build_database_setting == "drop-then-build"
        build_database(drop_database=drop_database)

    if taxonomy_path := DEV_CONFIG.get("taxonomy"):
        if not (taxonomy_file := Path(taxonomy_path)).is_file():
            raise ValueError(f"dev.taxonomy must be a path to a file, but is {taxonomy_path!r}.")
        synchronize_taxonomy_from_file(taxonomy_file)

    # Fall back to "dev" when the distribution is not installed (e.g. running from a checkout).
    try:
        dist_version = pkg_version("aiod_metadata_catalogue")
    except PackageNotFoundError:
        dist_version = "dev"

    app = build_app(url_prefix=DEV_CONFIG.get("url_prefix", ""), version=dist_version)
    return app
def build_app(*, url_prefix: str = "", version: str = "dev"):
    """Assemble the top-level app plus one mounted sub-app per non-retired API version.

    The top-level app ("latest") receives the correlation-id middleware, the
    Prometheus ``/metrics`` endpoint and the access-log middleware. Every app,
    top-level and versioned, gets its routes, HTTP exception handlers,
    deprecation/sunset middleware and versioned OpenAPI. Each versioned app is
    mounted on the top-level app under ``/<version>``.

    NOTE(review): neither ``url_prefix`` nor ``version`` is read in this
    function. Routes are added without the prefix, and app versions come from
    ``versions`` / "latest". Confirm whether that is intentional.
    """
    description = (
        "This is the REST API documentation of the AIoD Metadata Catalogue. "
        "See also our general "
        '<a href="https://aiondemand.github.io/AIOD-rest-api/">metadata catalogue documentation</a>, '  # noqa: E501
        "and our "
        '<a href="https://github.com/aiondemand/AIOD-rest-api/releases">changelog</a>.'
    )
    swagger_oauth = {
        "clientId": KEYCLOAK_CONFIG.get("client_id_swagger"),
        "realm": KEYCLOAK_CONFIG.get("realm"),
        "appName": "AIoD Metadata Catalogue",
        "usePkceWithAuthorizationCodeGrant": True,
        "scopes": KEYCLOAK_CONFIG.get("scopes"),
    }
    # Settings shared by the top-level app and every versioned app.
    shared_settings = {
        "docs_url": None,  # the default docs pages are replaced by custom HTML
        "redoc_url": None,
        "description": description,
        "swagger_ui_oauth2_redirect_url": "/docs/oauth2-redirect",
        "swagger_ui_init_oauth": swagger_oauth,
    }

    main_app = FastAPI(
        title="AI-on-Demand Metadata Catalogue REST API",
        version="latest",
        **shared_settings,
    )
    main_app.add_middleware(CorrelationIdMiddleware)

    versioned_apps = []
    for api_version, info in versions.items():
        if info.retired:
            continue
        sub_app = FastAPI(
            title=f"AIoD Metadata Catalogue {api_version}",
            version=f"{api_version}",
            **shared_settings,
        )
        versioned_apps.append((sub_app, api_version))

    for app, api_version in [(main_app, Version.LATEST), *versioned_apps]:
        add_routes(app, version=api_version)
        app.exception_handlers[FastAPIHTTPException] = http_exception_handler
        # Starlette raises its own HTTPException (e.g. 404 for unknown paths),
        # which a handler keyed on FastAPI's subclass does not catch.
        app.exception_handlers[StarletteHTTPException] = http_exception_handler
        app.add_exception_handler(404, http_exception_handler)
        add_deprecation_and_sunset_middleware(app)
        add_version_to_openapi(app)

    Instrumentator().instrument(main_app).expose(main_app, endpoint="/metrics", include_in_schema=False)

    # The versioned apps are mounted on main_app, so all traffic passes through
    # it; the access log therefore only needs to be registered here.
    main_app.add_middleware(AccessLogMiddleware)

    for sub_app, _ in versioned_apps:
        main_app.mount(f"/{sub_app.version}", sub_app)
    return main_app
def build_database(drop_database: bool = False):
    """Create the database, its tables and triggers, and the default Platform rows.

    :param drop_database: passed to ``create_database`` as ``delete_first``.
    """
    create_database(delete_first=drop_database)
    SQLModel.metadata.create_all(EngineSingleton().engine, checkfirst=True)
    with DbSession() as session:
        triggers = create_delete_triggers(AIoDConcept)
        sync_triggers = create_identifier_synchronization_triggers()
        for trigger in triggers + sync_triggers:
            session.execute(trigger)

        if DEV_CONFIG.get("disable_reviews", False):
            disable_review_process(session)
        else:
            enable_review_process(session)

        # Add a Platform row for every PlatformName that does not have one yet.
        # Iterate the enum in definition order instead of a set difference:
        # string hashing is randomized per process, so set order (and thus the
        # insertion order of new rows) would otherwise vary between runs.
        existing_names = {p.name for p in session.scalars(select(Platform)).all()}
        missing_platforms = [name for name in PlatformName if name not in existing_names]
        # Plain emptiness check; any() would inspect the truthiness of each
        # enum member rather than whether anything is missing.
        if missing_platforms:
            session.add_all([Platform(name=name) for name in missing_platforms])
        session.commit()
def main():
    """Parse the (option-less) command line, then serve the app with uvicorn.

    Kept in a function so that nothing here ends up as a module global.
    """
    # TODO: unify configuration and environment file? GH#82
    # No options are defined: parsing exists so `--help` shows where
    # configuration lives, and unexpected (e.g. legacy) arguments are rejected.
    parser = argparse.ArgumentParser(
        description=(
            "Configuration options can be set in the configuration file. "
            "Please refer to the documentation pages."
        )
    )
    parser.parse_args()

    uvicorn.run(
        "main:create_app",
        factory=True,
        host="0.0.0.0",  # noqa: S104 # required to make the interface available outside of docker
        reload=DEV_CONFIG.get("reload", False),
    )
# Run the server when this file is executed directly as a script.
if __name__ == "__main__":
    main()