44import uvicorn
55from dataclasses import asdict , is_dataclass
66from enum import Enum
7- from fastapi import FastAPI , Depends , HTTPException , Request
7+ from fastapi import FastAPI , Depends , HTTPException , Request , status
88from fastapi .responses import StreamingResponse , FileResponse
9+ from fastapi .security .api_key import APIKeyQuery
10+ from fastapi import Security
911from typing import Annotated
1012
1113from .socket_address import parse_host_and_port
2325
2426
2527class HttpServer (InfoContext ):
26- def __init__ (self , listen_address = f"{ DEFAULT_HOST } :{ DEFAULT_PORT } " , app : FastAPI = None ):
28+ def __init__ (self , listen_address = f"{ DEFAULT_HOST } :{ DEFAULT_PORT } " , app : FastAPI = None , api = None ):
2729 from ..logging .logging import get_logger
2830
2931 self .logger = get_logger ("http" )
3032 self .node = None
3133 self .host = DEFAULT_HOST
3234 self .port = DEFAULT_PORT
3335 self .set_host_and_port (listen_address )
36+ self .api = api
3437
3538 if not app :
36- self .app = FastAPI ()
39+ self .app = FastAPI (dependencies = [ Depends ( validate_api_key )] )
3740 else :
3841 self .app = app
39-
4042 self .app .middleware ("http" )(self .inject_node )
4143 self ._setup_routes ()
4244
@@ -232,6 +234,11 @@ def set_host_and_port(self, listen_address):
232234
233235 async def start (self , node ):
234236 self .node = node
237+ # mount some additional if provided custom routes
238+ if self .api :
239+ self .api .routes (self .node )
240+ self .app .mount ("/" , self .api .api )
241+
235242 asyncio .create_task (self .serve ())
236243
237244
@@ -250,3 +257,10 @@ def default(obj):
250257 if isinstance (obj , Enum ):
251258 return obj .value
252259 raise TypeError (f"Object of type { obj .__class__ .__name__ } is not JSON serializable" )
260+
261+
262+ def validate_api_key (api_key : str = Security (APIKeyQuery (name = "api_key" , auto_error = False ))) -> str :
263+ expected = os .environ .get ("API_KEY" )
264+ if api_key != expected :
265+ raise HTTPException (status_code = status .HTTP_401_UNAUTHORIZED , detail = "Invalid or missing API key header" )
266+ return expected
0 commit comments