Skip to content

Commit d2a9326

Browse files
committed
feat(cli): add auto-reload for server command
1 parent 9c590be commit d2a9326

File tree

2 files changed

+96
-4
lines changed

2 files changed

+96
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies = [
1414
"click>=8.1.8",
1515
"rich>=14.0.0",
1616
"python-dotenv>=1.1.0",
17+
"watchfiles>=1.1.0",
1718
]
1819
license = "Apache-2.0"
1920
urls = { Homepage = "https://cocoindex.io/" }

python/cocoindex/cli.py

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
1+
import asyncio
12
import atexit
23
import datetime
34
import importlib.util
45
import os
6+
import signal
57
import sys
8+
import threading
69
import types
10+
from types import FrameType
711
from typing import Any
812

913
import click
14+
import watchfiles
1015
from dotenv import find_dotenv, load_dotenv
1116
from rich.console import Console
1217
from rich.panel import Panel
1318
from rich.table import Table
1419

1520
from . import flow, lib, setting
21+
from .runtime import execution_context
1622
from .setup import apply_setup_changes, drop_setup, flow_names_with_setup, sync_setup
1723

1824
# Create ServerSettings lazily upon first call, as environment variables may be loaded from files, etc.
@@ -116,6 +122,12 @@ def _load_user_app(app_target: str) -> types.ModuleType:
116122
)
117123

118124

125+
def _initialize_cocoindex_in_process() -> None:
126+
settings = setting.Settings.from_env()
127+
lib.init(settings)
128+
atexit.register(lib.stop)
129+
130+
119131
@click.group()
120132
@click.version_option(package_name="cocoindex", message="%(prog)s version %(version)s")
121133
@click.option(
@@ -139,9 +151,7 @@ def cli(env_file: str | None = None) -> None:
139151
click.echo(f"Loaded environment variables from: {loaded_env_path}", err=True)
140152

141153
try:
142-
settings = setting.Settings.from_env()
143-
lib.init(settings)
144-
atexit.register(lib.stop)
154+
_initialize_cocoindex_in_process()
145155
except Exception as e:
146156
raise click.ClickException(f"Failed to initialize CocoIndex library: {e}")
147157

@@ -485,6 +495,14 @@ def evaluate(
485495
default=False,
486496
help="Avoid printing anything to the standard output, e.g. statistics.",
487497
)
498+
@click.option(
499+
"-r",
500+
"--reload",
501+
is_flag=True,
502+
show_default=True,
503+
default=False,
504+
help="Enable auto-reload on code changes.",
505+
)
488506
def server(
489507
app_target: str,
490508
address: str | None,
@@ -493,6 +511,7 @@ def server(
493511
cors_origin: str | None,
494512
cors_cocoindex: bool,
495513
cors_local: int | None,
514+
reload: bool,
496515
) -> None:
497516
"""
498517
Start a HTTP server providing REST APIs.
@@ -502,6 +521,65 @@ def server(
502521
APP_TARGET: path/to/app.py or installed_module.
503522
"""
504523
app_ref = _get_app_ref_from_specifier(app_target)
524+
525+
if reload:
526+
watch_paths = {os.getcwd()}
527+
if os.path.isfile(app_ref):
528+
watch_paths.add(os.path.dirname(os.path.abspath(app_ref)))
529+
else:
530+
try:
531+
spec = importlib.util.find_spec(app_ref)
532+
if spec and spec.origin:
533+
watch_paths.add(os.path.dirname(os.path.abspath(spec.origin)))
534+
except ImportError:
535+
pass
536+
537+
watchfiles.run_process(
538+
*watch_paths,
539+
target=_reloadable_server_target,
540+
args=(
541+
app_ref,
542+
address,
543+
cors_origin,
544+
cors_cocoindex,
545+
cors_local,
546+
live_update,
547+
quiet,
548+
),
549+
watch_filter=watchfiles.PythonFilter(),
550+
callback=lambda changes: click.secho(
551+
f"\nDetected changes in {len(changes)} file(s), reloading server...\n",
552+
fg="cyan",
553+
),
554+
)
555+
else:
556+
_run_server(
557+
app_ref,
558+
address=address,
559+
cors_origin=cors_origin,
560+
cors_cocoindex=cors_cocoindex,
561+
cors_local=cors_local,
562+
live_update=live_update,
563+
quiet=quiet,
564+
)
565+
566+
567+
def _reloadable_server_target(*args: Any, **kwargs: Any) -> None:
568+
"""Reloadable target for the watchfiles process."""
569+
_initialize_cocoindex_in_process()
570+
_run_server(*args, **kwargs)
571+
572+
573+
def _run_server(
574+
app_ref: str,
575+
address: str | None = None,
576+
cors_origin: str | None = None,
577+
cors_cocoindex: bool = False,
578+
cors_local: int | None = None,
579+
live_update: bool = False,
580+
quiet: bool = False,
581+
) -> None:
582+
"""Helper function to run the server with specified settings."""
505583
_load_user_app(app_ref)
506584

507585
server_settings = setting.ServerSettings.from_env()
@@ -525,7 +603,20 @@ def server(
525603
if live_update:
526604
options = flow.FlowLiveUpdaterOptions(live_mode=True, print_stats=not quiet)
527605
flow.update_all_flows(options)
528-
input("Press Enter to stop...")
606+
607+
click.secho("Press Ctrl+C to stop the server.", fg="yellow")
608+
609+
shutdown_event = threading.Event()
610+
611+
def handle_signal(signum: int, frame: FrameType | None) -> None:
612+
shutdown_event.set()
613+
614+
async def _wait_for_shutdown_signal() -> None:
615+
await asyncio.to_thread(shutdown_event.wait)
616+
617+
signal.signal(signal.SIGINT, handle_signal)
618+
signal.signal(signal.SIGTERM, handle_signal)
619+
execution_context.run(_wait_for_shutdown_signal())
529620

530621

531622
def _flow_name(name: str | None) -> str:

0 commit comments

Comments
 (0)