66
77import argparse
88import asyncio
9+ import concurrent .futures
910import functools
1011import inspect
1112import json
5051 request_provider_data_context ,
5152 user_from_scope ,
5253)
53- from llama_stack .core .resolver import InvalidProviderError
5454from llama_stack .core .server .routes import (
5555 find_matching_route ,
5656 get_all_api_routes ,
5757 initialize_route_impls ,
5858)
5959from llama_stack .core .stack import (
60+ Stack ,
6061 cast_image_name_to_string ,
61- construct_stack ,
6262 replace_env_vars ,
63- shutdown_stack ,
6463 validate_env_pair ,
6564)
6665from llama_stack .core .utils .config import redact_sensitive_fields
@@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
156155 )
157156
158157
class StackApp(FastAPI):
    """FastAPI application that owns the Stack it serves.

    Keeping the Stack on the application object lets the lifespan context manager reach it,
    e.g. to start background tasks such as the periodic registry refresh and to shut the
    stack down again.
    """

    def __init__(self, config: StackRunConfig, *args, **kwargs):
        """Build the FastAPI app, then create and initialize the Stack for ``config``.

        Args:
            config: Run configuration handed to ``Stack``.
            *args: Forwarded unchanged to ``FastAPI.__init__``.
            **kwargs: Forwarded unchanged to ``FastAPI.__init__``.

        Raises:
            Exception: Whatever ``Stack.initialize()`` raises is re-raised here by
                ``Future.result()``.
        """
        super().__init__(*args, **kwargs)
        self.stack: Stack = Stack(config)

        # The constructor is synchronous, so it cannot await initialize(), and it can be
        # invoked while an event loop is already running (e.g. one managed by uvicorn), in
        # which case asyncio.run() would fail in this thread. Running asyncio.run() on a
        # worker thread gives the coroutine its own event loop; .result() blocks until the
        # initialization has finished and surfaces any exception it raised.
        with concurrent.futures.ThreadPoolExecutor() as pool:
            pool.submit(asyncio.run, self.stack.initialize()).result()
166176
167177
@asynccontextmanager
async def lifespan(app: StackApp):
    """Run startup and shutdown work for the stack attached to ``app``.

    On entry, the stack's registry refresh task is created. On exit, the stack is shut
    down. The shutdown lives in a ``finally`` clause so that it also runs when an
    exception or cancellation is raised into this generator at the ``yield``; otherwise
    the stack would be left running in that case.

    Args:
        app: The application whose ``stack`` attribute is started and stopped here.

    Raises:
        RuntimeError: If ``app.stack`` is ``None``.
    """
    logger.info("Starting up")
    # Explicit check rather than ``assert``: assertions are stripped under ``python -O``.
    if app.stack is None:
        raise RuntimeError("StackApp.stack must be set before the application starts")
    app.stack.create_registry_refresh_task()
    try:
        yield
    finally:
        logger.info("Shutting down")
        await app.stack.shutdown()
174186
175187
176188def is_streaming_request (func_name : str , request : Request , ** kwargs ):
@@ -386,73 +398,61 @@ async def send_version_error(send):
386398 return await self .app (scope , receive , send )
387399
388400
389- def main (args : argparse .Namespace | None = None ):
390- """Start the LlamaStack server."""
391- parser = argparse .ArgumentParser (description = "Start the LlamaStack server." )
401+ def create_app (
402+ config_file : str | None = None ,
403+ env_vars : list [str ] | None = None ,
404+ ) -> StackApp :
405+ """Create and configure the FastAPI application.
392406
393- add_config_distro_args (parser )
394- parser .add_argument (
395- "--port" ,
396- type = int ,
397- default = int (os .getenv ("LLAMA_STACK_PORT" , 8321 )),
398- help = "Port to listen on" ,
399- )
400- parser .add_argument (
401- "--env" ,
402- action = "append" ,
403- help = "Environment variables in KEY=value format. Can be specified multiple times." ,
404- )
407+ Args:
408+ config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
409+ env_vars: List of environment variables in KEY=value format.
410+ disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
405411
406- # Determine whether the server args are being passed by the "run" command, if this is the case
407- # the args will be passed as a Namespace object to the main function, otherwise they will be
408- # parsed from the command line
409- if args is None :
410- args = parser .parse_args ()
412+ Returns:
413+ Configured StackApp instance.
414+ """
415+ config_file = config_file or os .getenv ("LLAMA_STACK_CONFIG" )
416+ if config_file is None :
417+ raise ValueError ("No config file provided and LLAMA_STACK_CONFIG env var is not set" )
411418
412- config_or_distro = get_config_from_args (args )
413- config_file = resolve_config_or_distro (config_or_distro , Mode .RUN )
419+ config_file = resolve_config_or_distro (config_file , Mode .RUN )
414420
421+ # Load and process configuration
415422 logger_config = None
416423 with open (config_file ) as fp :
417424 config_contents = yaml .safe_load (fp )
418425 if isinstance (config_contents , dict ) and (cfg := config_contents .get ("logging_config" )):
419426 logger_config = LoggingConfig (** cfg )
420427 logger = get_logger (name = __name__ , category = "core::server" , config = logger_config )
421- if args .env :
422- for env_pair in args .env :
428+
429+ if env_vars :
430+ for env_pair in env_vars :
423431 try :
424432 key , value = validate_env_pair (env_pair )
425- logger .info (f"Setting CLI environment variable { key } => { value } " )
433+ logger .info (f"Setting environment variable { key } => { value } " )
426434 os .environ [key ] = value
427435 except ValueError as e :
428436 logger .error (f"Error: { str (e )} " )
429- sys .exit (1 )
437+ raise ValueError (f"Invalid environment variable format: { env_pair } " ) from e
438+
430439 config = replace_env_vars (config_contents )
431440 config = StackRunConfig (** cast_image_name_to_string (config ))
432441
433442 _log_run_config (run_config = config )
434443
435- app = FastAPI (
444+ app = StackApp (
436445 lifespan = lifespan ,
437446 docs_url = "/docs" ,
438447 redoc_url = "/redoc" ,
439448 openapi_url = "/openapi.json" ,
449+ config = config ,
440450 )
441451
442452 if not os .environ .get ("LLAMA_STACK_DISABLE_VERSION_CHECK" ):
443453 app .add_middleware (ClientVersionMiddleware )
444454
445- try :
446- # Create and set the event loop that will be used for both construction and server runtime
447- loop = asyncio .new_event_loop ()
448- asyncio .set_event_loop (loop )
449-
450- # Construct the stack in the persistent event loop
451- impls = loop .run_until_complete (construct_stack (config ))
452-
453- except InvalidProviderError as e :
454- logger .error (f"Error: { str (e )} " )
455- sys .exit (1 )
455+ impls = app .stack .impls
456456
457457 if config .server .auth :
458458 logger .info (f"Enabling authentication with provider: { config .server .auth .provider_config .type .value } " )
@@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
553553 app .exception_handler (RequestValidationError )(global_exception_handler )
554554 app .exception_handler (Exception )(global_exception_handler )
555555
556- app .__llama_stack_impls__ = impls
557556 app .add_middleware (TracingMiddleware , impls = impls , external_apis = external_apis )
558557
558+ return app
559+
560+
561+ def main (args : argparse .Namespace | None = None ):
562+ """Start the LlamaStack server."""
563+ parser = argparse .ArgumentParser (description = "Start the LlamaStack server." )
564+
565+ add_config_distro_args (parser )
566+ parser .add_argument (
567+ "--port" ,
568+ type = int ,
569+ default = int (os .getenv ("LLAMA_STACK_PORT" , 8321 )),
570+ help = "Port to listen on" ,
571+ )
572+ parser .add_argument (
573+ "--env" ,
574+ action = "append" ,
575+ help = "Environment variables in KEY=value format. Can be specified multiple times." ,
576+ )
577+
578+ # Determine whether the server args are being passed by the "run" command, if this is the case
579+ # the args will be passed as a Namespace object to the main function, otherwise they will be
580+ # parsed from the command line
581+ if args is None :
582+ args = parser .parse_args ()
583+
584+ config_or_distro = get_config_from_args (args )
585+
586+ try :
587+ app = create_app (
588+ config_file = config_or_distro ,
589+ env_vars = args .env ,
590+ )
591+ except Exception as e :
592+ logger .error (f"Error creating app: { str (e )} " )
593+ sys .exit (1 )
594+
595+ config_file = resolve_config_or_distro (config_or_distro , Mode .RUN )
596+ with open (config_file ) as fp :
597+ config_contents = yaml .safe_load (fp )
598+ if isinstance (config_contents , dict ) and (cfg := config_contents .get ("logging_config" )):
599+ logger_config = LoggingConfig (** cfg )
600+ else :
601+ logger_config = None
602+ config = StackRunConfig (** cast_image_name_to_string (replace_env_vars (config_contents )))
603+
559604 import uvicorn
560605
561606 # Configure SSL if certificates are provided
@@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
593638 if ssl_config :
594639 uvicorn_config .update (ssl_config )
595640
596- # Run uvicorn in the existing event loop to preserve background tasks
597641 # We need to catch KeyboardInterrupt because uvicorn's signal handling
598642 # re-raises SIGINT signals using signal.raise_signal(), which Python
599643 # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
@@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None):
604648 # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
605649 # signal handling but this is quite intrusive and not worth the effort.
606650 try :
607- loop . run_until_complete (uvicorn .Server (uvicorn .Config (** uvicorn_config )).serve ())
651+ asyncio . run (uvicorn .Server (uvicorn .Config (** uvicorn_config )).serve ())
608652 except (KeyboardInterrupt , SystemExit ):
609653 logger .info ("Received interrupt signal, shutting down gracefully..." )
610- finally :
611- if not loop .is_closed ():
612- logger .debug ("Closing event loop" )
613- loop .close ()
614654
615655
616656def _log_run_config (run_config : StackRunConfig ):
0 commit comments