1010
1111import pkg_resources
1212import uvicorn
13- from fastapi import Depends , FastAPI
13+ from fastapi import Depends , FastAPI , HTTPException
1414from fastapi .responses import HTMLResponse
1515from sqlmodel import select
1616
2222from database .model .platform .platform_names import PlatformName
2323from database .session import EngineSingleton , DbSession
2424from database .setup import create_database , database_exists
25+ from error_handling import http_exception_handler
2526from routers import resource_routers , parent_routers , enum_routers , uploader_routers
2627from routers import search_routers
2728from setup_logger import setup_logger
@@ -101,17 +102,33 @@ def create_app() -> FastAPI:
101102 """Create the FastAPI application, complete with routes."""
102103 setup_logger ()
103104 args = _parse_args ()
105+ if args .build_db == "never" :
106+ if not database_exists ():
107+ logging .warning (
108+ "AI-on-Demand database does not exist on the MySQL server, "
109+ "but `build_db` is set to 'never'. If you are not creating the "
110+ "database through other means, such as MySQL group replication, "
111+ "this likely means that you will get errors or undefined behavior."
112+ )
113+ else :
114+ build_database (args )
115+
104116 pyproject_toml = pkg_resources .get_distribution ("aiod_metadata_catalogue" )
117+ app = build_app (args .url_prefix , pyproject_toml .version )
118+ return app
119+
120+
121+ def build_app (url_prefix : str = "" , version : str = "dev" ):
105122 app = FastAPI (
106- openapi_url = f"{ args . url_prefix } /openapi.json" ,
107- docs_url = f"{ args . url_prefix } /docs" ,
123+ openapi_url = f"{ url_prefix } /openapi.json" ,
124+ docs_url = f"{ url_prefix } /docs" ,
108125 title = "AIoD Metadata Catalogue" ,
109126 description = "This is the Swagger documentation of the AIoD Metadata Catalogue. For the "
110127 "Changelog, refer to "
111128 '<a href="https://github.com/aiondemand/AIOD-rest-api/releases">https'
112129 "://github.com/aiondemand/AIOD-rest-api/releases</a>." ,
113- version = pyproject_toml . version ,
114- swagger_ui_oauth2_redirect_url = f"{ args . url_prefix } /docs/oauth2-redirect" ,
130+ version = version ,
131+ swagger_ui_oauth2_redirect_url = f"{ url_prefix } /docs/oauth2-redirect" ,
115132 swagger_ui_init_oauth = {
116133 "clientId" : KEYCLOAK_CONFIG .get ("client_id_swagger" ),
117134 "realm" : KEYCLOAK_CONFIG .get ("realm" ),
@@ -120,32 +137,25 @@ def create_app() -> FastAPI:
120137 "scopes" : KEYCLOAK_CONFIG .get ("scopes" ),
121138 },
122139 )
123- if args .build_db == "never" :
124- if not database_exists ():
125- logging .warning (
126- "AI-on-Demand database does not exist on the MySQL server, "
127- "but `build_db` is set to 'never'. If you are not creating the "
128- "database through other means, such as MySQL group replication, "
129- "this likely means that you will get errors or undefined behavior."
130- )
131- else :
132-
133- drop_database = args .build_db == "drop-then-build"
134- create_database (delete_first = drop_database )
135- AIoDConcept .metadata .create_all (EngineSingleton ().engine , checkfirst = True )
136- with DbSession () as session :
137- triggers = create_delete_triggers (AIoDConcept )
138- for trigger in triggers :
139- session .execute (trigger )
140- existing_platforms = session .scalars (select (Platform )).all ()
141- if not any (existing_platforms ):
142- session .add_all ([Platform (name = name ) for name in PlatformName ])
143- session .commit ()
144-
145- add_routes (app , url_prefix = args .url_prefix )
140+ add_routes (app , url_prefix = url_prefix )
141+ app .add_exception_handler (HTTPException , http_exception_handler )
146142 return app
147143
148144
145+ def build_database (args ):
146+ drop_database = args .build_db == "drop-then-build"
147+ create_database (delete_first = drop_database )
148+ AIoDConcept .metadata .create_all (EngineSingleton ().engine , checkfirst = True )
149+ with DbSession () as session :
150+ triggers = create_delete_triggers (AIoDConcept )
151+ for trigger in triggers :
152+ session .execute (trigger )
153+ existing_platforms = session .scalars (select (Platform )).all ()
154+ if not any (existing_platforms ):
155+ session .add_all ([Platform (name = name ) for name in PlatformName ])
156+ session .commit ()
157+
158+
149159def main ():
150160 """Run the application. Placed in a separate function, to avoid having global variables"""
151161 args = _parse_args ()
0 commit comments