Skip to content

Commit 7c88648

Browse files
authored
feat(cli): add auto-reload for server command (#664)
* feat(cli): add auto-reload for server command * refactor(cli): simplify server function argument handling
1 parent a7dfea3 commit 7c88648

File tree

2 files changed

+84
-4
lines changed

2 files changed

+84
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies = [
1414
"click>=8.1.8",
1515
"rich>=14.0.0",
1616
"python-dotenv>=1.1.0",
17+
"watchfiles>=1.1.0",
1718
]
1819
license = "Apache-2.0"
1920
urls = { Homepage = "https://cocoindex.io/" }

python/cocoindex/cli.py

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22
import datetime
33
import importlib.util
44
import os
5+
import signal
56
import sys
7+
import threading
68
import types
9+
from types import FrameType
710
from typing import Any
811

912
import click
13+
import watchfiles
1014
from dotenv import find_dotenv, load_dotenv
1115
from rich.console import Console
1216
from rich.panel import Panel
@@ -116,6 +120,12 @@ def _load_user_app(app_target: str) -> types.ModuleType:
116120
)
117121

118122

123+
def _initialize_cocoindex_in_process() -> None:
124+
settings = setting.Settings.from_env()
125+
lib.init(settings)
126+
atexit.register(lib.stop)
127+
128+
119129
@click.group()
120130
@click.version_option(package_name="cocoindex", message="%(prog)s version %(version)s")
121131
@click.option(
@@ -139,9 +149,7 @@ def cli(env_file: str | None = None) -> None:
139149
click.echo(f"Loaded environment variables from: {loaded_env_path}", err=True)
140150

141151
try:
142-
settings = setting.Settings.from_env()
143-
lib.init(settings)
144-
atexit.register(lib.stop)
152+
_initialize_cocoindex_in_process()
145153
except Exception as e:
146154
raise click.ClickException(f"Failed to initialize CocoIndex library: {e}")
147155

@@ -485,6 +493,14 @@ def evaluate(
485493
default=False,
486494
help="Avoid printing anything to the standard output, e.g. statistics.",
487495
)
496+
@click.option(
497+
"-r",
498+
"--reload",
499+
is_flag=True,
500+
show_default=True,
501+
default=False,
502+
help="Enable auto-reload on code changes.",
503+
)
488504
def server(
489505
app_target: str,
490506
address: str | None,
@@ -493,6 +509,7 @@ def server(
493509
cors_origin: str | None,
494510
cors_cocoindex: bool,
495511
cors_local: int | None,
512+
reload: bool,
496513
) -> None:
497514
"""
498515
Start a HTTP server providing REST APIs.
@@ -502,6 +519,58 @@ def server(
502519
APP_TARGET: path/to/app.py or installed_module.
503520
"""
504521
app_ref = _get_app_ref_from_specifier(app_target)
522+
args = (
523+
app_ref,
524+
address,
525+
cors_origin,
526+
cors_cocoindex,
527+
cors_local,
528+
live_update,
529+
quiet,
530+
)
531+
532+
if reload:
533+
watch_paths = {os.getcwd()}
534+
if os.path.isfile(app_ref):
535+
watch_paths.add(os.path.dirname(os.path.abspath(app_ref)))
536+
else:
537+
try:
538+
spec = importlib.util.find_spec(app_ref)
539+
if spec and spec.origin:
540+
watch_paths.add(os.path.dirname(os.path.abspath(spec.origin)))
541+
except ImportError:
542+
pass
543+
544+
watchfiles.run_process(
545+
*watch_paths,
546+
target=_reloadable_server_target,
547+
args=args,
548+
watch_filter=watchfiles.PythonFilter(),
549+
callback=lambda changes: click.secho(
550+
f"\nDetected changes in {len(changes)} file(s), reloading server...\n",
551+
fg="cyan",
552+
),
553+
)
554+
else:
555+
_run_server(*args)
556+
557+
558+
def _reloadable_server_target(*args: Any, **kwargs: Any) -> None:
559+
"""Reloadable target for the watchfiles process."""
560+
_initialize_cocoindex_in_process()
561+
_run_server(*args, **kwargs)
562+
563+
564+
def _run_server(
565+
app_ref: str,
566+
address: str | None = None,
567+
cors_origin: str | None = None,
568+
cors_cocoindex: bool = False,
569+
cors_local: int | None = None,
570+
live_update: bool = False,
571+
quiet: bool = False,
572+
) -> None:
573+
"""Helper function to run the server with specified settings."""
505574
_load_user_app(app_ref)
506575

507576
server_settings = setting.ServerSettings.from_env()
@@ -525,7 +594,17 @@ def server(
525594
if live_update:
526595
options = flow.FlowLiveUpdaterOptions(live_mode=True, print_stats=not quiet)
527596
flow.update_all_flows(options)
528-
input("Press Enter to stop...")
597+
598+
click.secho("Press Ctrl+C to stop the server.", fg="yellow")
599+
600+
shutdown_event = threading.Event()
601+
602+
def handle_signal(signum: int, frame: FrameType | None) -> None:
603+
shutdown_event.set()
604+
605+
signal.signal(signal.SIGINT, handle_signal)
606+
signal.signal(signal.SIGTERM, handle_signal)
607+
shutdown_event.wait()
529608

530609

531610
def _flow_name(name: str | None) -> str:

0 commit comments

Comments
 (0)