1+ import asyncio
12import atexit
23import datetime
34import importlib .util
45import os
6+ import signal
57import sys
8+ import threading
69import types
10+ from types import FrameType
711from typing import Any
812
913import click
14+ import watchfiles
1015from dotenv import find_dotenv , load_dotenv
1116from rich .console import Console
1217from rich .panel import Panel
1318from rich .table import Table
1419
1520from . import flow , lib , setting
21+ from .runtime import execution_context
1622from .setup import apply_setup_changes , drop_setup , flow_names_with_setup , sync_setup
1723
1824# Create ServerSettings lazily upon first call, as environment variables may be loaded from files, etc.
@@ -116,6 +122,12 @@ def _load_user_app(app_target: str) -> types.ModuleType:
116122 )
117123
118124
125+ def _initialize_cocoindex_in_process () -> None :
126+ settings = setting .Settings .from_env ()
127+ lib .init (settings )
128+ atexit .register (lib .stop )
129+
130+
119131@click .group ()
120132@click .version_option (package_name = "cocoindex" , message = "%(prog)s version %(version)s" )
121133@click .option (
@@ -139,9 +151,7 @@ def cli(env_file: str | None = None) -> None:
139151 click .echo (f"Loaded environment variables from: { loaded_env_path } " , err = True )
140152
141153 try :
142- settings = setting .Settings .from_env ()
143- lib .init (settings )
144- atexit .register (lib .stop )
154+ _initialize_cocoindex_in_process ()
145155 except Exception as e :
146156 raise click .ClickException (f"Failed to initialize CocoIndex library: { e } " )
147157
@@ -485,6 +495,14 @@ def evaluate(
485495 default = False ,
486496 help = "Avoid printing anything to the standard output, e.g. statistics." ,
487497)
498+ @click .option (
499+ "-r" ,
500+ "--reload" ,
501+ is_flag = True ,
502+ show_default = True ,
503+ default = False ,
504+ help = "Enable auto-reload on code changes." ,
505+ )
488506def server (
489507 app_target : str ,
490508 address : str | None ,
@@ -493,6 +511,7 @@ def server(
493511 cors_origin : str | None ,
494512 cors_cocoindex : bool ,
495513 cors_local : int | None ,
514+ reload : bool ,
496515) -> None :
497516 """
498517 Start a HTTP server providing REST APIs.
@@ -502,6 +521,65 @@ def server(
502521 APP_TARGET: path/to/app.py or installed_module.
503522 """
504523 app_ref = _get_app_ref_from_specifier (app_target )
524+
525+ if reload :
526+ watch_paths = {os .getcwd ()}
527+ if os .path .isfile (app_ref ):
528+ watch_paths .add (os .path .dirname (os .path .abspath (app_ref )))
529+ else :
530+ try :
531+ spec = importlib .util .find_spec (app_ref )
532+ if spec and spec .origin :
533+ watch_paths .add (os .path .dirname (os .path .abspath (spec .origin )))
534+ except ImportError :
535+ pass
536+
537+ watchfiles .run_process (
538+ * watch_paths ,
539+ target = _reloadable_server_target ,
540+ args = (
541+ app_ref ,
542+ address ,
543+ cors_origin ,
544+ cors_cocoindex ,
545+ cors_local ,
546+ live_update ,
547+ quiet ,
548+ ),
549+ watch_filter = watchfiles .PythonFilter (),
550+ callback = lambda changes : click .secho (
551+ f"\n Detected changes in { len (changes )} file(s), reloading server...\n " ,
552+ fg = "cyan" ,
553+ ),
554+ )
555+ else :
556+ _run_server (
557+ app_ref ,
558+ address = address ,
559+ cors_origin = cors_origin ,
560+ cors_cocoindex = cors_cocoindex ,
561+ cors_local = cors_local ,
562+ live_update = live_update ,
563+ quiet = quiet ,
564+ )
565+
566+
567+ def _reloadable_server_target (* args : Any , ** kwargs : Any ) -> None :
568+ """Reloadable target for the watchfiles process."""
569+ _initialize_cocoindex_in_process ()
570+ _run_server (* args , ** kwargs )
571+
572+
573+ def _run_server (
574+ app_ref : str ,
575+ address : str | None = None ,
576+ cors_origin : str | None = None ,
577+ cors_cocoindex : bool = False ,
578+ cors_local : int | None = None ,
579+ live_update : bool = False ,
580+ quiet : bool = False ,
581+ ) -> None :
582+ """Helper function to run the server with specified settings."""
505583 _load_user_app (app_ref )
506584
507585 server_settings = setting .ServerSettings .from_env ()
@@ -525,7 +603,20 @@ def server(
525603 if live_update :
526604 options = flow .FlowLiveUpdaterOptions (live_mode = True , print_stats = not quiet )
527605 flow .update_all_flows (options )
528- input ("Press Enter to stop..." )
606+
607+ click .secho ("Press Ctrl+C to stop the server." , fg = "yellow" )
608+
609+ shutdown_event = threading .Event ()
610+
611+ def handle_signal (signum : int , frame : FrameType | None ) -> None :
612+ shutdown_event .set ()
613+
614+ async def _wait_for_shutdown_signal () -> None :
615+ await asyncio .to_thread (shutdown_event .wait )
616+
617+ signal .signal (signal .SIGINT , handle_signal )
618+ signal .signal (signal .SIGTERM , handle_signal )
619+ execution_context .run (_wait_for_shutdown_signal ())
529620
530621
531622def _flow_name (name : str | None ) -> str :
0 commit comments