
Commit d688d60

ehhuang authored and iamemilio committed
chore: refactor server.main (llamastack#3462)
# What does this PR do?

As shown in llamastack#3421, we can scale the stack to handle more RPS with k8s replicas. This PR enables a multi-process stack with `uvicorn --workers`, so that we can achieve the same scaling without being in k8s.

To achieve that, we refactor `main` to split out the app construction logic; that method needs to be non-async. We created a new `Stack` class to house impls and a `start()` method, called from the lifespan handler, to start background tasks instead of starting them in the old `construct_stack`. This way we avoid having to manage an event loop manually.

## Test Plan

CI.

> uv run --with llama-stack python -m llama_stack.core.server.server benchmarking/k8s-benchmark/stack_run_config.yaml

works.

> LLAMA_STACK_CONFIG=benchmarking/k8s-benchmark/stack_run_config.yaml uv run uvicorn llama_stack.core.server.server:create_app --port 8321 --workers 4

works.
1 parent 44eea8c commit d688d60

File tree

7 files changed: +225 −138 lines
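As background for the diffs below, here is a minimal sketch of the uvicorn factory pattern the PR adopts. The module name `app_factory_demo` and the `/health` route are illustrative, not llama-stack code. With `--factory` (or `factory=True`), uvicorn imports the factory in each worker process and calls it synchronously to build a fresh app per worker, which is why the app construction method must be non-async.

```python
# app_factory_demo.py (hypothetical filename)
import uvicorn
from fastapi import FastAPI


def create_app() -> FastAPI:
    # Called once per worker process, before that worker's event loop starts,
    # so it must be a plain synchronous callable.
    app = FastAPI()

    @app.get("/health")
    def health() -> dict[str, str]:
        return {"status": "ok"}

    return app


if __name__ == "__main__":
    # Equivalent to: uvicorn app_factory_demo:create_app --port 8321 --workers 4 --factory
    uvicorn.run("app_factory_demo:create_app", port=8321, workers=4, factory=True)
```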

benchmarking/k8s-benchmark/apply.sh

Lines changed: 1 addition & 4 deletions

```diff
@@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
 export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
 
-export MOCK_INFERENCE_MODEL=mock-inference
-
-export MOCK_INFERENCE_URL=openai-mock-service:8080
-
 export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
+export LLAMA_STACK_WORKERS=4
 
 set -euo pipefail
 set -x
```

benchmarking/k8s-benchmark/stack-configmap.yaml

Lines changed: 9 additions & 0 deletions

```diff
@@ -5,6 +5,7 @@ data:
     image_name: kubernetes-benchmark-demo
     apis:
     - agents
+    - files
    - inference
     - files
     - safety
@@ -23,6 +24,14 @@ data:
       - provider_id: sentence-transformers
         provider_type: inline::sentence-transformers
         config: {}
+      files:
+      - provider_id: meta-reference-files
+        provider_type: inline::localfs
+        config:
+          storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+          metadata_store:
+            type: sqlite
+            db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
       vector_io:
       - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
         provider_type: remote::chromadb
```
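The `${env.VAR:=default}` placeholders above are expanded at startup by the stack's `replace_env_vars`. Below is a simplified, hedged stand-in that handles only the `:=` default form; the regex and error behavior are illustrative, not the real implementation.

```python
import os
import re

# Matches ${env.NAME} or ${env.NAME:=default} (the ":+" conditional form is not handled here)
_ENV_PATTERN = re.compile(r"\$\{env\.(?P<name>[A-Za-z0-9_]+)(?::=(?P<default>[^}]*))?\}")


def substitute_env(value: str) -> str:
    def _sub(match: re.Match) -> str:
        name = match.group("name")
        default = match.group("default")
        result = os.environ.get(name, default)
        if result is None:
            raise ValueError(f"environment variable {name} is not set and has no default")
        return result

    return _ENV_PATTERN.sub(_sub, value)


# With SQLITE_STORE_DIR unset, this falls back to the inline default:
print(substitute_env("${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db"))
```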

benchmarking/k8s-benchmark/stack-k8s.yaml.template

Lines changed: 12 additions & 1 deletion

```diff
@@ -52,9 +52,20 @@ spec:
           value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
         - name: VLLM_TLS_VERIFY
           value: "false"
-        command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
+        - name: LLAMA_STACK_LOGGING
+          value: "all=WARNING"
+        - name: LLAMA_STACK_CONFIG
+          value: "/etc/config/stack_run_config.yaml"
+        - name: LLAMA_STACK_WORKERS
+          value: "${LLAMA_STACK_WORKERS}"
+        command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
         ports:
         - containerPort: 8323
+        resources:
+          requests:
+            cpu: "${LLAMA_STACK_WORKERS}"
+          limits:
+            cpu: "${LLAMA_STACK_WORKERS}"
         volumeMounts:
         - name: llama-storage
           mountPath: /root/.llama
```

llama_stack/core/library_client.py

Lines changed: 6 additions & 2 deletions

```diff
@@ -40,7 +40,7 @@
 from llama_stack.core.resolver import ProviderRegistry
 from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
 from llama_stack.core.stack import (
-    construct_stack,
+    Stack,
     get_stack_run_config_from_distro,
     replace_env_vars,
 )
@@ -252,7 +252,10 @@ async def initialize(self) -> bool:
 
         try:
             self.route_impls = None
-            self.impls = await construct_stack(self.config, self.custom_provider_registry)
+
+            stack = Stack(self.config, self.custom_provider_registry)
+            await stack.initialize()
+            self.impls = stack.impls
         except ModuleNotFoundError as _e:
             cprint(_e.msg, color="red", file=sys.stderr)
             cprint(
@@ -289,6 +292,7 @@ async def initialize(self) -> bool:
             )
             raise _e
 
+        assert self.impls is not None
         if Api.telemetry in self.impls:
             setup_logger(self.impls[Api.telemetry])
```
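The calling convention this diff introduces (construct `Stack` synchronously, `await stack.initialize()`, then read `stack.impls`) in a runnable toy model; the class body below is a stand-in, not the real `llama_stack.core.stack.Stack`:

```python
import asyncio


class Stack:
    """Toy stand-in mirroring the interface the library client now uses."""

    def __init__(self, config: dict, custom_provider_registry: dict | None = None):
        self.config = config
        self.custom_provider_registry = custom_provider_registry or {}
        self.impls: dict[str, object] | None = None

    async def initialize(self) -> None:
        # Stand-in for async provider resolution; fills self.impls.
        await asyncio.sleep(0)
        self.impls = {"inference": object(), "telemetry": object()}

    async def shutdown(self) -> None:
        self.impls = None


async def main() -> None:
    stack = Stack({"image_name": "demo"})
    await stack.initialize()
    assert stack.impls is not None
    print(sorted(stack.impls))


asyncio.run(main())
```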

llama_stack/core/server/server.py

Lines changed: 96 additions & 56 deletions

```diff
@@ -6,6 +6,7 @@
 
 import argparse
 import asyncio
+import concurrent.futures
 import functools
 import inspect
 import json
@@ -50,17 +51,15 @@
     request_provider_data_context,
     user_from_scope,
 )
-from llama_stack.core.resolver import InvalidProviderError
 from llama_stack.core.server.routes import (
     find_matching_route,
     get_all_api_routes,
     initialize_route_impls,
 )
 from llama_stack.core.stack import (
+    Stack,
     cast_image_name_to_string,
-    construct_stack,
     replace_env_vars,
-    shutdown_stack,
     validate_env_pair,
 )
 from llama_stack.core.utils.config import redact_sensitive_fields
@@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
     )
 
 
-async def shutdown(app):
-    """Initiate a graceful shutdown of the application.
-
-    Handled by the lifespan context manager. The shutdown process involves
-    shutting down all implementations registered in the application.
+class StackApp(FastAPI):
+    """
+    A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
+    start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
     """
-    await shutdown_stack(app.__llama_stack_impls__)
+
+    def __init__(self, config: StackRunConfig, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stack: Stack = Stack(config)
+
+        # This code is called from a running event loop managed by uvicorn so we cannot simply call
+        # asyncio.run() to initialize the stack. We cannot await either since this is not an async
+        # function.
+        # As a workaround, we use a thread pool executor to run the initialize() method
+        # in a separate thread.
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(asyncio.run, self.stack.initialize())
+            future.result()
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: StackApp):
     logger.info("Starting up")
+    assert app.stack is not None
+    app.stack.create_registry_refresh_task()
     yield
     logger.info("Shutting down")
-    await shutdown(app)
+    await app.stack.shutdown()
 
 
 def is_streaming_request(func_name: str, request: Request, **kwargs):
```
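The `ThreadPoolExecutor` workaround in `StackApp.__init__` above deserves a closer look: `asyncio.run()` raises `RuntimeError` when called on a thread that already has a running event loop (uvicorn's), but runs fine on a fresh thread. A self-contained sketch of the same pattern, with a stand-in `initialize()` coroutine:

```python
import asyncio
import concurrent.futures


async def initialize() -> str:
    await asyncio.sleep(0.1)  # stand-in for async stack construction
    return "initialized"


def blocking_init() -> str:
    # Direct asyncio.run() here would raise:
    #   RuntimeError: asyncio.run() cannot be called from a running event loop
    # so hand the coroutine to a fresh thread, where no loop is running yet.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(asyncio.run, initialize())
        return future.result()  # blocks until the thread's loop completes


async def main() -> None:
    # Simulates uvicorn invoking the synchronous app factory from its loop.
    print(blocking_init())


asyncio.run(main())
```

The remaining hunks of `server.py` follow.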
```diff
@@ -386,73 +398,61 @@ async def send_version_error(send):
         return await self.app(scope, receive, send)
 
 
-def main(args: argparse.Namespace | None = None):
-    """Start the LlamaStack server."""
-    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+def create_app(
+    config_file: str | None = None,
+    env_vars: list[str] | None = None,
+) -> StackApp:
+    """Create and configure the FastAPI application.
 
-    add_config_distro_args(parser)
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
-        help="Port to listen on",
-    )
-    parser.add_argument(
-        "--env",
-        action="append",
-        help="Environment variables in KEY=value format. Can be specified multiple times.",
-    )
+    Args:
+        config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
+        env_vars: List of environment variables in KEY=value format.
+        disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
 
-    # Determine whether the server args are being passed by the "run" command, if this is the case
-    # the args will be passed as a Namespace object to the main function, otherwise they will be
-    # parsed from the command line
-    if args is None:
-        args = parser.parse_args()
+    Returns:
+        Configured StackApp instance.
+    """
+    config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
+    if config_file is None:
+        raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
 
-    config_or_distro = get_config_from_args(args)
-    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+    config_file = resolve_config_or_distro(config_file, Mode.RUN)
 
+    # Load and process configuration
     logger_config = None
     with open(config_file) as fp:
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
     logger = get_logger(name=__name__, category="core::server", config=logger_config)
-    if args.env:
-        for env_pair in args.env:
+
+    if env_vars:
+        for env_pair in env_vars:
             try:
                 key, value = validate_env_pair(env_pair)
-                logger.info(f"Setting CLI environment variable {key} => {value}")
+                logger.info(f"Setting environment variable {key} => {value}")
                 os.environ[key] = value
             except ValueError as e:
                 logger.error(f"Error: {str(e)}")
-                sys.exit(1)
+                raise ValueError(f"Invalid environment variable format: {env_pair}") from e
+
     config = replace_env_vars(config_contents)
     config = StackRunConfig(**cast_image_name_to_string(config))
 
     _log_run_config(run_config=config)
 
-    app = FastAPI(
+    app = StackApp(
         lifespan=lifespan,
         docs_url="/docs",
         redoc_url="/redoc",
         openapi_url="/openapi.json",
+        config=config,
     )
 
     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)
 
-    try:
-        # Create and set the event loop that will be used for both construction and server runtime
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
-        # Construct the stack in the persistent event loop
-        impls = loop.run_until_complete(construct_stack(config))
-
-    except InvalidProviderError as e:
-        logger.error(f"Error: {str(e)}")
-        sys.exit(1)
+    impls = app.stack.impls
 
     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
@@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
 
-    app.__llama_stack_impls__ = impls
     app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
 
+    return app
+
+
+def main(args: argparse.Namespace | None = None):
+    """Start the LlamaStack server."""
+    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+
+    add_config_distro_args(parser)
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
+        help="Port to listen on",
+    )
+    parser.add_argument(
+        "--env",
+        action="append",
+        help="Environment variables in KEY=value format. Can be specified multiple times.",
+    )
+
+    # Determine whether the server args are being passed by the "run" command, if this is the case
+    # the args will be passed as a Namespace object to the main function, otherwise they will be
+    # parsed from the command line
+    if args is None:
+        args = parser.parse_args()
+
+    config_or_distro = get_config_from_args(args)
+
+    try:
+        app = create_app(
+            config_file=config_or_distro,
+            env_vars=args.env,
+        )
+    except Exception as e:
+        logger.error(f"Error creating app: {str(e)}")
+        sys.exit(1)
+
+    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+    with open(config_file) as fp:
+        config_contents = yaml.safe_load(fp)
+        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
+            logger_config = LoggingConfig(**cfg)
+        else:
+            logger_config = None
+        config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
+
     import uvicorn
 
     # Configure SSL if certificates are provided
@@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
     if ssl_config:
         uvicorn_config.update(ssl_config)
 
-    # Run uvicorn in the existing event loop to preserve background tasks
     # We need to catch KeyboardInterrupt because uvicorn's signal handling
     # re-raises SIGINT signals using signal.raise_signal(), which Python
     # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
@@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None):
     # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
     # signal handling but this is quite intrusive and not worth the effort.
     try:
-        loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
+        asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
     except (KeyboardInterrupt, SystemExit):
         logger.info("Received interrupt signal, shutting down gracefully...")
-    finally:
-        if not loop.is_closed():
-            logger.debug("Closing event loop")
-            loop.close()
 
 
 def _log_run_config(run_config: StackRunConfig):
```
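With stack construction moved into `create_app`, `main` no longer needs a manually managed event loop; `asyncio.run()` owns loop creation and teardown. A minimal sketch of the resulting serve-and-catch-interrupt shape, with a placeholder app and port:

```python
import asyncio

import uvicorn
from fastapi import FastAPI

app = FastAPI()


def serve() -> None:
    config = uvicorn.Config(app, host="0.0.0.0", port=8321)
    try:
        # asyncio.run() creates and closes the loop itself, replacing the
        # new_event_loop()/run_until_complete()/close() bookkeeping removed above.
        asyncio.run(uvicorn.Server(config).serve())
    except (KeyboardInterrupt, SystemExit):
        print("Received interrupt signal, shutting down gracefully...")


if __name__ == "__main__":
    serve()
```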
