From b2fecae6e65cba2a47b18de2c66dc77c2bec438c Mon Sep 17 00:00:00 2001 From: Alexander Goscinski Date: Sat, 29 Nov 2025 07:10:20 +0100 Subject: [PATCH] Implementation of supervisor service infrastructure --- CONTRIBUTING.md | 100 +- pyproject.toml | 19 +- scripts/cmd_airflow.py | 1 - scripts/cmd_daemon_start.py | 106 ++ scripts/cmd_daemon_status.py | 133 +++ scripts/cmd_daemon_stop.py | 54 + scripts/setup_test_profile.py | 94 ++ .../aiida_core/engine/daemon/_supervisor.py | 1061 +++++++++++++++++ .../engine/daemon/airflow_daemon.py | 113 ++ .../engine/daemon/triggerer_service.py | 197 +++ 10 files changed, 1873 insertions(+), 5 deletions(-) create mode 100644 scripts/cmd_daemon_start.py create mode 100644 scripts/cmd_daemon_status.py create mode 100644 scripts/cmd_daemon_stop.py create mode 100644 scripts/setup_test_profile.py create mode 100644 src/airflow_provider_aiida/aiida_core/engine/daemon/_supervisor.py create mode 100644 src/airflow_provider_aiida/aiida_core/engine/daemon/airflow_daemon.py create mode 100644 src/airflow_provider_aiida/aiida_core/engine/daemon/triggerer_service.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 22c6c0c..23cf5ae 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,6 +15,22 @@ hatch run hatch-test.py3.11:setup-profile hatch run hatch-test.py3.11:unit-tests ``` +To run the integration tests we need to start the Airflow services before. + +```bash +# 1. Start PostgreSQL +hatch run hatch-test.py3.11:start-psql-service + +# 2. Create AiiDA profile and databases +hatch run hatch-test.py3.11:setup-profile + +# 3. Start Airflow services +hatch run hatch-test.py3.11:daemon-start + +# 4. 
Run integration tests +hatch run hatch-test.py3.11:integration-tests +``` + ## Unit tests ### Setup test environment @@ -84,6 +100,22 @@ postgres (admin) The profile will be located at `.pytest/.aiida/test/` (or `$AIIDA_PATH/.aiida/test/` if `AIIDA_PATH` is set) with the following structure: ``` .pytest/.aiida/test/ +├── daemon/ # Daemon process management +│ ├── services/ # Per-service directories +│ │ ├── scheduler/ +│ │ │ ├── state.json # PID, state, timestamps +│ │ │ ├── stdout.log # Service output +│ │ │ └── stderr.log # Service errors +│ │ ├── triggerer/ +│ │ │ ├── state.json +│ │ │ ├── stdout.log +│ │ │ └── stderr.log +│ │ ├── dag-processor/ +│ │ │ └── ... +│ │ └── api-server/ +│ │ └── ... +│ ├── daemon.pid # Daemon PID file +│ └── daemon.log # Daemon output (background mode) └── airflow/ # Airflow files ├── dags/ # DAG files └── airflow.cfg # Airflow configuration @@ -113,6 +145,66 @@ hatch test -- -m 'not integration' Integration tests require running Airflow services (scheduler, triggerer, dag-processor). These tests are marked with `@pytest.mark.integration`. 
+ +### Start Airflow in background with daemon + +Use the daemon manager to start all services at once: + +```bash +# Start daemon in background (default) +hatch run hatch-test.py3.11:daemon-start +``` + +The daemon will: +- Check if daemon is already running +- Start all Airflow services (scheduler, triggerer, dag-processor, api-server) +- Run a health monitor thread that tracks service status every 5 seconds +- Detach and run in background + +**Background mode** (default): +- Daemon detaches and runs in background +- Services continue running after terminal closes +- Use `daemon-stop` to stop all services +- Use `daemon-status` to check service status + +**Foreground mode** (for debugging): +```bash +# Run daemon in foreground with --foreground flag +python scripts/cmd_daemon_start.py --profile-name test --foreground +``` +- Keeps daemon running in your terminal +- Press Ctrl+C to gracefully stop all services +- Useful for interactive testing and debugging + +Check service status in another terminal: + + ```bash +hatch run hatch-test.py3.11:daemon-status +``` + +Expected output: +``` +=== Airflow Test Services Status === + +Health monitor daemon: RUNNING (PID: 12345) + +Service Status: +-------------------------------------------------------------------------------- +Service State PID Uptime Last Check Failures +-------------------------------------------------------------------------------- +scheduler ✓ RUNNING 12346 2m 15s 3s ago 0 +triggerer ✓ RUNNING 12347 2m 15s 3s ago 0 +dag-processor ✓ RUNNING 12348 2m 15s 3s ago 0 +api-server ✓ RUNNING 12349 2m 15s 3s ago 0 +-------------------------------------------------------------------------------- + +Airflow home: .pytest/.aiida/test/airflow +DAGs folder: .pytest/.aiida/test/airflow/dags +Logs: .pytest/.aiida/test/airflow/logs +``` + + + ### Start Airflow services in foreground (recommended for debugging) If you prefer to start services manually in separate terminals: @@ -157,10 +249,16 @@ hatch test Be sure that 
all airflow services and the docker service have been stopped. The docker service can be stopped with. -``` + +```bash hatch run hatch-test.py3.11:stop-psql-service ``` +The Airflow services can be stopped with +```bash +hatch run hatch-test.py3.11:daemon-stop +``` + To clean test artifacts including the PostgreSQL databases cluster, as well as the aiida and the airflow config. ```bash hatch run hatch-test.py3.11:clean diff --git a/pyproject.toml b/pyproject.toml index e987e01..5e4b482 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,9 @@ Source = "..." # TODO file = "README.md" content-type = "text/markdown" +[project.scripts] +airflow-provider-aiida-triggerer-service = "airflow_provider_aiida.aiida_core.engine.daemon.triggerer_service:main" + [project.entry-points.apache_airflow_provider] provider_info = "airflow_provider_aiida.__init__:get_provider_info" @@ -54,6 +57,7 @@ python = "3.11" [tool.hatch.envs.hatch-test] default-args = [] +installer = "uv" dependencies = [ "pytest>=7.0", "pytest-mock>=3.10", @@ -83,9 +87,6 @@ AIRFLOW_PROVIDER_AIIDA__TESTS__AIRFLOW_POSTGRES_PASSWORD = "password" setup-profile = [ "python scripts/cmd_profile_create.py --postgres-host {env:POSTGRES_HOST} --postgres-port {env:POSTGRES_HOST_PORT} --postgres-user {env:POSTGRES_USER} --postgres-password {env:POSTGRES_PASSWORD} --profile-name {env:AIRFLOW_PROVIDER_AIIDA__TESTS__AIIDA_PROFILE} --aiida-password {env:AIRFLOW_PROVIDER_AIIDA__TESTS__AIRFLOW_POSTGRES_PASSWORD}", ] -reserialize = [ - "python scripts/cmd_airflow.py --profile-name {env:AIRFLOW_PROVIDER_AIIDA__TESTS__AIIDA_PROFILE} dags reserialize --bundle-name aiida_dags", -] teardown-profile = [ "python scripts/cmd_profile_delete.py --profile-name {env:AIRFLOW_PROVIDER_AIIDA__TESTS__AIIDA_PROFILE} --postgres-user {env:POSTGRES_USER} --postgres-password {env:POSTGRES_PASSWORD}" @@ -96,6 +97,11 @@ start-psql-service = "docker compose -f docker-compose.test.yml up -d" stop-psql-service = "docker compose -f docker-compose.test.yml down -v" 
status-psql-service = "docker ps" +# Daemon commands +daemon-start = "python scripts/cmd_daemon_start.py --profile-name {env:AIRFLOW_PROVIDER_AIIDA__TESTS__AIIDA_PROFILE}" +daemon-stop = "python scripts/cmd_daemon_stop.py --profile-name {env:AIRFLOW_PROVIDER_AIIDA__TESTS__AIIDA_PROFILE}" +daemon-status = "python scripts/cmd_daemon_status.py --profile-name {env:AIRFLOW_PROVIDER_AIIDA__TESTS__AIIDA_PROFILE}" + # Start individual Airflow services (for manual testing) scheduler = "python scripts/cmd_airflow.py --profile-name {env:AIRFLOW_PROVIDER_AIIDA__TESTS__AIIDA_PROFILE} scheduler" api-server = "python scripts/cmd_airflow.py --profile-name {env:AIRFLOW_PROVIDER_AIIDA__TESTS__AIIDA_PROFILE} api-server" @@ -108,6 +114,13 @@ integration-tests = "pytest -m 'integration' {args:{root}/tests}" # Default test runner - excludes integration tests, only runs tests/ directory run = "pytest {args:{root}/tests}" +kill-zombies = [ + "pkill -f 'airflow dag-processor'", + "pkill -f 'airflow api-server'", + "pkill -f 'airflow scheduler'", + "pkill -f 'airflow triggerer'", +] + # Clean test artifacts # NOTE: Only use after services have been stopped to clean = [ diff --git a/scripts/cmd_airflow.py b/scripts/cmd_airflow.py index 95c5ac2..43f0029 100755 --- a/scripts/cmd_airflow.py +++ b/scripts/cmd_airflow.py @@ -13,7 +13,6 @@ """ import sys -import os import subprocess from airflow_provider_aiida.aiida_core import load_profile diff --git a/scripts/cmd_daemon_start.py b/scripts/cmd_daemon_start.py new file mode 100644 index 0000000..efbbc4b --- /dev/null +++ b/scripts/cmd_daemon_start.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +""" +Start Airflow services for testing using the new Airflow Daemon. 
+ +This script is designed to be run via hatch: + hatch run hatch-test.py3.11:start-airflow-services +""" + +import sys +import os +import argparse + + +def start_services(profile_name: str, num_sync_workers: int, num_async_workers: int, foreground: bool): + """Start all Airflow services using the daemon.""" + from airflow_provider_aiida.aiida_core.engine.daemon.airflow_daemon import AirflowDaemon + + print("=== Starting Airflow Test Services ===\n") + print(f"Sync workers (scheduler): {num_sync_workers}") + print(f"Async workers (triggerer): {num_async_workers}\n") + + try: + # Load AiiDA profile + from airflow_provider_aiida.aiida_core import load_profile + aiida_profile = load_profile(profile_name) + AirflowDaemon(aiida_profile).start( + num_workers=num_sync_workers, + num_triggerers=num_async_workers, + foreground=foreground + ) + + return 0 + + except Exception as e: + print(f"\n✗ Error: {e}") + import traceback + traceback.print_exc() + return 1 + + +def main(): + """Start Airflow services.""" + parser = argparse.ArgumentParser( + description="Start Airflow test services using daemon architecture", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Start with default (1 sync worker, 1 async worker) + python cmd_daemon_start.py + + # Start with 2 sync workers and 3 async workers + python cmd_daemon_start.py 2 3 + + # Start with 4 sync and 4 async workers, with specific profile + python cmd_daemon_start.py 4 4 --profile-name myprofile + + # Start in foreground mode with defaults + python cmd_daemon_start.py --foreground + """ + ) + + parser.add_argument( + 'num_sync_workers', + type=int, + nargs='?', + default=1, + help='Number of sync workers (scheduler parallelism, default: 1)' + ) + + parser.add_argument( + 'num_async_workers', + type=int, + nargs='?', + default=1, + help='Number of async workers (triggerer instances, default: 1)' + ) + + parser.add_argument( + '--foreground', '-f', + action='store_true', + help='Run daemon 
in foreground (default: background)' + ) + + parser.add_argument( + '--profile-name', + default=os.getenv('AIIDA_PROFILE'), + help='AiiDA profile name (default: from AIIDA_PROFILE env)' + ) + + args = parser.parse_args() + + if args.num_sync_workers < 1: + parser.error("Number of sync workers must be at least 1") + if args.num_async_workers < 1: + parser.error("Number of async workers must be at least 1") + + return start_services( + args.profile_name, + args.num_sync_workers, + args.num_async_workers, + args.foreground + ) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/cmd_daemon_status.py b/scripts/cmd_daemon_status.py new file mode 100644 index 0000000..60327dc --- /dev/null +++ b/scripts/cmd_daemon_status.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +""" +Check status of Airflow test services. + +This script is designed to be run via hatch: + hatch run hatch-test.py3.11:status-services +""" + +import sys +import os +import argparse +import time + + +def format_status(status_dict: dict) -> str: + """Format status dictionary as a readable table.""" + lines = [] + + # Header + lines.append("=" * 100) + lines.append(f"Daemon Status - Session: {status_dict['session']}") + lines.append("=" * 100) + + # Supervisor info + lines.append("\nSupervisor Process:") + supervisor = status_dict.get('supervisor') + if supervisor: + if 'error' in supervisor: + lines.append(f" Error: {supervisor['error']}") + else: + lines.append(f" PID: {supervisor['pid']}") + lines.append(f" Status: {supervisor['status']}") + lines.append(f" Started: {time.ctime(supervisor['started'])}") + lines.append(f" Log: {supervisor['log']}") + else: + lines.append(" No supervisor info available") + + # Services table + lines.append("\n" + "-" * 100) + lines.append("Services:") + lines.append("-" * 100) + + # Check if there are any errors at the top level + if 'error' in status_dict: + lines.append(f"\nError: {status_dict['error']}") + return "\n".join(lines) + + services = 
status_dict.get('services', {}) + if not services: + lines.append("\nNo services configured") + return "\n".join(lines) + + # Table header + lines.append(f"\n{'Service':<30} {'Type':<10} {'Workers':<8} {'PID':<10} {'Status':<10} {'Failures':<10}") + lines.append("-" * 100) + + # Table rows + for service_name, service_info in services.items(): + svc_type = service_info['type'] + + if svc_type == 'service': + # Non-worker service (single instance) + pid = str(service_info['pid']) if service_info['pid'] is not None else "-" + status = service_info['status'] if service_info['status'] is not None else "ERROR" + failures = str(service_info['failures']) if service_info['failures'] is not None else "-" + + if service_info.get('error'): + status = f"ERROR: {service_info['error']}" + + num_workers = service_info.get('num_workers', '-') + lines.append(f"{service_name:<30} {svc_type:<10} {num_workers:<8} {pid:<10} {status:<10} {failures:<10}") + + elif svc_type == 'worker': + # Worker service (multiple instances) + for worker_num, worker_info in service_info['workers'].items(): + worker_str = f"#{worker_num}" + pid = str(worker_info['pid']) if worker_info['pid'] is not None else "-" + status = worker_info['status'] if worker_info['status'] is not None else "ERROR" + failures = str(worker_info['failures']) if worker_info['failures'] is not None else "-" + + if worker_info.get('error'): + status = f"ERROR: {worker_info['error']}" + + lines.append(f"{service_name:<30} {svc_type:<10} {worker_str:<8} {pid:<10} {status:<10} {failures:<10}") + + lines.append("\n" + "=" * 100) + + return "\n".join(lines) + + +def check_status(profile_name: str): + """Check and display status of Airflow daemon and services.""" + from airflow_provider_aiida.aiida_core.engine.daemon.airflow_daemon import AirflowDaemon + + try: + # Load AiiDA profile + from airflow_provider_aiida.aiida_core import load_profile + aiida_profile = load_profile(profile_name) + + # Get status dictionary + status_dict = 
AirflowDaemon(aiida_profile).status() + + # Format and print + print(format_status(status_dict)) + + return 0 + + except Exception as e: + print(f"\n✗ Error: {e}") + import traceback + traceback.print_exc() + return 1 + + +def main(): + """Parse arguments and check status.""" + parser = argparse.ArgumentParser( + description="Check status of Airflow test services", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + '--profile-name', + default=os.getenv('AIIDA_PROFILE'), + help='AiiDA profile name (default: from AIIDA_PROFILE env)' + ) + + args = parser.parse_args() + return check_status(args.profile_name) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/cmd_daemon_stop.py b/scripts/cmd_daemon_stop.py new file mode 100644 index 0000000..d788686 --- /dev/null +++ b/scripts/cmd_daemon_stop.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +""" +Stop Airflow daemon and all its services. + +This script is designed to be run via hatch: + hatch run hatch-test.py3.11:stop-airflow-services +""" + +import sys +import os +import signal +import time +import argparse +from pathlib import Path + + +def stop_services(profile_name: str): + """Stop the Airflow daemon and all its services.""" + from airflow_provider_aiida.aiida_core.engine.daemon.airflow_daemon import AirflowDaemon + + try: + # Load AiiDA profile + from airflow_provider_aiida.aiida_core import load_profile + aiida_profile = load_profile(profile_name) + AirflowDaemon(aiida_profile).stop() + + return 0 + + except Exception as e: + print(f"\n✗ Error: {e}") + import traceback + traceback.print_exc() + return 1 + + +def main(): + """Parse arguments and stop services.""" + parser = argparse.ArgumentParser( + description="Stop Airflow test services", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + '--profile-name', + default=os.getenv('AIIDA_PROFILE'), + help='AiiDA profile name (default: from AIIDA_PROFILE env)' + ) + + args = 
parser.parse_args() + return stop_services(args.profile_name) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/setup_test_profile.py b/scripts/setup_test_profile.py new file mode 100644 index 0000000..d84f1e0 --- /dev/null +++ b/scripts/setup_test_profile.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +""" +Setup AiiDA test profile using the same PostgreSQL as the test environment. + +This script is designed to be run via hatch: + hatch run hatch-test.py3.11:setup-profile +""" + +import sys +import os +import argparse +from airflow_provider_aiida.utils.profile_creation import setup_aiida_profile + + +def parse_arguments(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Setup AiiDA and Airflow test profiles using PostgreSQL", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + '--postgres-host', + default=os.getenv('POSTGRES_HOST', '127.0.0.1'), + help='PostgreSQL server hostname (default: from POSTGRES_HOST env or 127.0.0.1)' + ) + + parser.add_argument( + '--postgres-port', + type=int, + default=int(os.getenv('POSTGRES_HOST_PORT', '5432')), + help='PostgreSQL server port (default: from POSTGRES_HOST_PORT env or 5434)' + ) + + parser.add_argument( + '--postgres-user', + default=os.getenv('POSTGRES_USER', 'postgres'), + help='PostgreSQL admin username (default: from POSTGRES_USER env or postgres)' + ) + + parser.add_argument( + '--postgres-password', + default=os.getenv('POSTGRES_PASSWORD', 'postgres'), + help='PostgreSQL admin password (default: from POSTGRES_PASSWORD env or empty)' + ) + + parser.add_argument( + '--aiida-password', + default=os.getenv('AIIDA_PASSWORD', 'password'), + help='PostgreSQL aiida user password (default: just password)' + ) + + parser.add_argument( + '--profile-name', + default=os.getenv('AIIDA_PROFILE', 'presto'), + help='AiiDA profile name (default: from AIIDA_PROFILE env or test)' + ) + + return parser.parse_args() + +def main(): + """Main 
function.""" + try: + args = parse_arguments() + + profile_parameters = { + 'pg_host': args.postgres_host, + 'pg_port': args.postgres_port, + 'pg_user': args.postgres_user, + 'pg_password': args.postgres_password, + 'profile_name': args.profile_name, + 'overwrite': True + } + + profile = setup_aiida_profile(**profile_parameters) + + if profile is None: + return 1 + + # Show how to use with process manager + print("\nTo start Airflow services with this profile:") + print(" hatch run hatch-test.py3.11:start-services") + + return 0 + + except Exception as e: + print(f"\n✗ Error: {e}") + import traceback + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/airflow_provider_aiida/aiida_core/engine/daemon/_supervisor.py b/src/airflow_provider_aiida/aiida_core/engine/daemon/_supervisor.py new file mode 100644 index 0000000..16ff13c --- /dev/null +++ b/src/airflow_provider_aiida/aiida_core/engine/daemon/_supervisor.py @@ -0,0 +1,1061 @@ +""" +Service Daemon - Generic process manager for services with health monitoring. 
+ +Directory Structure: +``` +daemon/ +├── profile-/ +│ ├── / +│ │ ├── worker_service_config.json # config how to start the worker service +│ │ ├── 1/ +│ │ │ ├── info.json # PID, state, timestamps, failure count +│ │ │ ├── stdout.log # Service stdout +│ │ │ └── stderr.log # Service stderr +│ │ └── 2/ +│ │ ├── info.json # PID, state, timestamps, failure count +│ │ ├── stdout.log # Service stdout +│ │ └── stderr.log # Service stderr +│ └── / +│ ├── service_config.json # config how to start the service +│ ├── info.json # PID, state, timestamps, failure count +│ ├── stdout.log # Service stdout +│ └── stderr.log # Service stderr +├── supervisor_info.json # Single daemon PID file +├── supervisor_config.json # Supervisor service configuration +└── supervisor.log # Daemon output (background mode only) +``` + +Key Features: +- Generic daemon works with any service type +- Single daemon.pid file +- Health monitor runs as thread (not separate process) +- Services are children of daemon (proper parent-child relationship) +- When daemon terminates, all children are cleaned up automatically +- Foreground mode: daemon runs in terminal, Ctrl+C stops everything +- Background mode: daemon detaches, use stop script to send SIGTERM +""" + +# Design choices +# Services need to be able to be started by the command line +# Worker and regular services are conceptually separated + +from dataclasses import dataclass, asdict +import os +import re +import sys +import subprocess +import time +import signal +import threading +import json +import psutil +from typing import List, Dict, Type, assert_never, Self, ClassVar +from abc import abstractmethod, ABC +from pathlib import Path +import enum +import logging + +# TODO public API + +# TODO logger as in aiida +logger = logging.getLogger() + +class ServiceState(enum.Enum): + ALIVE = "ALIVE" + DEAD = "DEAD" + +class FolderIdentifier(ABC): + + def __init__(self, identifier: str): + if not self.is_valid_identifier(identifier): + #TODO polish error msg + raise 
ValueError( + f"Invalid service identifier {identifier}. May only contain alphanumeric characters, underscores, hyphens" + "Must start with a letter or number" + "Length: 1-255 characters" + ) + self._identifier = identifier + + @property + def identifier(self) -> str: + return self._identifier + + def __str__(self) -> str: + """String representation returns full identifier.""" + return self._identifier + + def __repr__(self) -> str: + """Developer-friendly representation.""" + return f"ServiceIdentifier('{self._identifier}')" + + def __hash__(self) -> int: + """Hash based on the full identifier string.""" + return hash(self._identifier) + + def __eq__(self, other) -> bool: + """Equality based on the full identifier string.""" + if not isinstance(other, ServiceIdentifier): + return False + return self._identifier == other._identifier + + @staticmethod + def is_valid_identifier(name: str) -> bool: + """ + Check if a string is a valid service identifier name. + + Allows: alphanumeric characters, underscores, hyphens + Must start with a letter or number + Length: 1-255 characters + """ + if not name: + return False + + # Only allow safe characters: letters, numbers, underscore, hyphen + # Must start with alphanumeric + pattern = r'^[a-zA-Z0-9][a-zA-Z0-9_-]{0,254}$' + + return bool(re.match(pattern, name)) + +# TODO needed? 
kind of it checks if it is a folder identifier +class ServiceIdentifier(FolderIdentifier): + pass + +@dataclass +class JsonSerialization: + @classmethod + def from_file(cls, path: Path) -> Self: + with open(path, 'r') as f: + return cls(**json.load(f)) + + def to_file(self, path: Path): + with open(path, 'w') as f: + json.dump(asdict(self), f) + +@dataclass +class ProcessInfo: + pid: int + create_time: float + +# TODO MonitoredProcessInfo -> WorkerInfo + +@dataclass +class ServiceInfo(ProcessInfo, JsonSerialization): + service_name: str + command: str + state: str + last_check: float + failures: int + +# We need auto-register pattern because we need to be able to load the specialized ServiceConfigs from the json file, so we potentially need to be able to create a ServiceConfig just from the service_name + +SERVICE_CONFIG_REGISTRY: Dict[str, Type["ServiceConfig"]] = {} + +@dataclass +class ServiceConfig(ABC): + # TODO rename to service_identifier? + service_name: ClassVar[str] + command: ClassVar[str] + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # Auto-register any subclass that defines service_name + name = getattr(cls, "service_name", None) + if name is not None: + SERVICE_CONFIG_REGISTRY[name] = cls + + @abstractmethod + def create_unique_env(self) -> dict[str, str]: + raise NotImplementedError() + + def to_dict(self): + values = asdict(self) + # adding class attributes that are not included + values["service_name"] = self.service_name + values["command"] = self.command + return values + +@dataclass +class NonWorkerServiceConfig(ServiceConfig): + pass + +@dataclass +class WorkerServiceConfig(ServiceConfig): + num_workers: int + + # TODO need to rethink since for airflow it makes more sense to + # start the sync workers all at once + @abstractmethod + def create_unique_env(self) -> dict[str, str]: + raise NotImplementedError() + +@dataclass +class AiidaWorkerConfig(WorkerServiceConfig): + service_name: ClassVar[str] = 
"aiida_worker" + command: ClassVar[str] = "verdi daemon worker" + + def create_unique_env(self) -> dict[str, str]: + aiida_path = os.environ.get("AIIDA_PATH") + return {} if aiida_path is None else {"AIIDA_PATH": aiida_path} + +# TODO make commands tunable, not sure how without complicated interface, maybe just comamnd_base and args +@dataclass +class SleepServiceConfig(NonWorkerServiceConfig): + service_name: ClassVar[str] = "sleep10" + command: ClassVar[str] = "sleep 10" + + def create_unique_env(self) -> dict[str, str]: + return {} + +class ServiceConfigFactory: + + @staticmethod + def from_file(path: Path) -> ServiceConfig: + with open(path, "r") as f: + values = json.load(f) + return ServiceConfigFactory.from_dict(values) + + @staticmethod + def from_dict(values: dict) -> ServiceConfig: + if not (service_name := values.pop("service_name")): + raise ValueError("Missing 'service_name' in config") + if not (cls := SERVICE_CONFIG_REGISTRY.get(service_name)): + raise ValueError(f"Unknown service_name {service_name!r}") + + # have to remove class variables before init + values.pop("command") + return cls(**values) + +class ServiceConfigMap: + """ + Collection of service configurations with dict-like access. + + This class manages multiple ServiceConfig instances and provides + methods to query and access them by ServiceIdentifier. + + Supports dict-like access: collection[service_id] returns the ServiceConfig. 
+ """ + + def __init__(self, configs: List[ServiceConfig]): + # enforce uniqueness + configs_counts = {} + configs_counts = {config.service_name: 1 + configs_counts.get(config.service_name, 0) for config in configs} + duplicate_configs = list(filter(lambda key: configs_counts[key] > 1 , configs_counts.keys())) + if len(duplicate_configs) > 0: + raise ValueError(f"Found serivce names {duplicate_configs} multiple times in the list of service configs") + self._configs = {ServiceIdentifier(config.service_name): config for config in configs} + + def __getitem__(self, identifier: str | ServiceIdentifier) -> ServiceConfig: + # Convert string to ServiceIdentifier if needed + if isinstance(identifier, str): + identifier = ServiceIdentifier(identifier) + + if identifier not in self._configs: + available = [sid for sid in self._configs.keys()] + raise KeyError( + f"Service '{identifier}' not found. " + f"Available services: {available}" + ) + return self._configs[identifier] + + def __contains__(self, identifier: str | ServiceIdentifier) -> bool: + if isinstance(identifier, str): + identifier = ServiceIdentifier(identifier) + return identifier in self._configs + + def __len__(self) -> int: + return len(self._configs) + + def __iter__(self): + return iter(self._configs) + + def keys(self): + return self._configs.keys() + + def values(self): + return self._configs.values() + + def items(self): + return self._configs.items() + + def to_file(self, path: Path): + with open(path, "w") as f: + json.dump({key.identifier: config.to_dict() for key, config in self._configs.items()}, f) + + @classmethod + def from_file(cls, path: Path) -> Self: + with open(path, "r") as f: + values = json.load(f) + return cls([ServiceConfigFactory.from_dict(value) for value in values.values()]) + +@dataclass +class SupervisorInfo(ProcessInfo, JsonSerialization): + pass + +class ServiceSupervisorCommon: + SUPERVISOR_INFO_FILE = "supervisor_info.json" + SUPERVISOR_CONFIG_FILE = "supervisor_config.json" + # 
TODO split log + SUPERVISOR_LOG_FILE = "supervisor.log" + PROCESS_INFO_FILE = "info.json" + KILL_TIMEOUT = 10.0 + + @staticmethod + def _start_service_process(service_dir: Path, config: ServiceConfig, info: ServiceInfo | None = None): + ServiceSupervisorCommon._start_process(service_dir / config.service_name, config, info) + + @staticmethod + def _start_worker_service_process(service_dir: Path, config: ServiceConfig, worker_num: int, info: ServiceInfo | None = None): + ServiceSupervisorCommon._start_process(service_dir / config.service_name / str(worker_num), config, info) + + @staticmethod + def _start_process(process_dir: Path, config: ServiceConfig, info: ServiceInfo | None = None): + process_dir.mkdir(parents=True, exist_ok=True) + + # Open log files in service directory + stdout_log = process_dir / "stdout.log" + stderr_log = process_dir / "stderr.log" + + stdout_file = open(stdout_log, 'a', buffering=1) + stderr_file = open(stderr_log, 'a', buffering=1) + + # Get unique environment for this service + service_env = config.create_unique_env() + + process = subprocess.Popen( + config.command.split(), + stdout=stdout_file, + stderr=stderr_file, + env=os.environ | service_env, + start_new_session=True # Create new process group + ) + create_time = psutil.Process(process.pid).create_time() + + service_info = ServiceInfo( + service_name=config.service_name, + command=config.command, + pid=process.pid, + state=ServiceState.ALIVE.value, + create_time=create_time, + last_check=create_time, + failures=0 if info is None else info.failures+1 + ) + service_info.to_file(process_dir / ServiceSupervisorCommon.PROCESS_INFO_FILE) + + # TODO consider adding the create_time, then we maybe don't need do check if is alive but directly send _kill_service + @staticmethod + def _kill_service(pid: int) -> bool: + # TODO add create_time + # Kill the process with timeout of 5 seconds, force kill if it doesn't stop + logger.debug(f"Stopping process with PID {pid}") + kill_successful = 
True + try: + # Send SIGTERM for graceful shutdown + os.kill(pid, signal.SIGTERM) + + # Wait up to 5 seconds for the process to terminate + start_time = time.time() + while time.time() - start_time < ServiceSupervisorCommon.KILL_TIMEOUT: + os.kill(pid, 0) + time.sleep(0.1) + else: + # Timeout - force kill + print(f"⚠ Process {pid} did not stop gracefully, force killing...") + os.kill(pid, signal.SIGKILL) + start_time = time.time() + while time.time() - start_time < ServiceSupervisorCommon.KILL_TIMEOUT: + os.kill(pid, 0) + time.sleep(0.1) + kill_successful = False + logger.info(f"Force killing process {pid} failed.") + except ProcessLookupError: + # Process already gone + logger.debug(f"Process {pid} stopped.") + pass + except PermissionError: + kill_successful = False + logger.debug(f"No permission to kill process {pid}") + except Exception as e: + kill_successful = False + logger.warning(f"Error stopping process {pid}: {e}") + return kill_successful + + @staticmethod + def _is_alive(pid: int, create_time: float) -> bool: + """ + Check if process is alive with PID reuse protection. + + This function verifies: + 1. Process with the PID exists + 2. Process creation time matches what we stored + + This prevents accidentally treating a different process + (that reused the PID) as our service. + + Args: + service_info: Service information from state file + + Returns: + True if process exists and is confirmed to be our service + """ + try: + process = psutil.Process(pid) + + # Get actual process creation time + actual_create_time = process.create_time() + + # Compare with stored creation time + # Allow 1 second tolerance for timing differences + time_diff = abs(actual_create_time - create_time) + + if time_diff > 1.0: + # PID has been reused by a different process! 
+ print(f"WARNING: PID {pid} reused by different process!") + print(f" Expected create time: {create_time}") + print(f" Actual create time: {actual_create_time}") + print(f" Difference: {time_diff} seconds") + return False + + # Process exists and creation time matches + return True + + except psutil.NoSuchProcess: + # Process doesn't exist + return False + except psutil.AccessDenied: + # Process exists but we don't have permission to access it + # This could mean: + # 1. It's running as a different user + # 2. System restrictions prevent access + # Conservative approach: assume it's alive + print(f"WARNING: Cannot access PID {pid} (permission denied)") + return True + + # TODO add for _cleanup_last_session a boolean return value so we can say it worked + @staticmethod + def stop(session_dir: Path): + service_configs = ServiceConfigMap.from_file(session_dir / ServiceSupervisorCommon.SUPERVISOR_CONFIG_FILE) + + supervisor_info_file = session_dir / ServiceSupervisorCommon.SUPERVISOR_INFO_FILE + if (supervisor_info_file := supervisor_info_file).exists(): + service_info = None + try: + service_info = SupervisorInfo.from_file(supervisor_info_file) + except (FileNotFoundError, json.JSONDecodeError, KeyError, TypeError) as e: + logger.warning(f"Skipping invalid or corrupted supervisor info file {supervisor_info_file}: {e}. 
") + + if (service_info is not None and + ServiceSupervisorCommon._is_alive(service_info.pid, service_info.create_time) and + not ServiceSupervisorCommon._kill_service(service_info.pid)): + pass + # TODO but need to include state to ServiceInfo + #service_info.state = ServiceState.DEAD.value + #service_info.to_file(supervisor_info_file) + + for config in service_configs.values(): + if isinstance(config, NonWorkerServiceConfig): + process_dir = session_dir / config.service_name + if (info_file := process_dir / ServiceSupervisorCommon.PROCESS_INFO_FILE).exists(): + try: + info = ServiceInfo.from_file(info_file) + except (FileNotFoundError, json.JSONDecodeError, KeyError, TypeError) as e: + logger.warning(f"Skipping invalid or corrupted info file {info_file}: {e}.") + continue + + # TODO should I consider here info.state? _is_alive is safer + if ServiceSupervisorCommon._is_alive(info.pid, info.create_time): + logger.info(f"Terminating service {info.service_name} process with pid {info.pid}") + if ServiceSupervisorCommon._kill_service(info.pid): + info.state = ServiceState.DEAD.value + info.to_file(process_dir / ServiceSupervisorCommon.PROCESS_INFO_FILE) + else: + info.state = ServiceState.DEAD.value + info.to_file(process_dir / ServiceSupervisorCommon.PROCESS_INFO_FILE) + + elif isinstance(config, WorkerServiceConfig): + for i in range(config.num_workers): + process_dir = session_dir / config.service_name / str(i) + if (info_file := process_dir / ServiceSupervisorCommon.PROCESS_INFO_FILE).exists(): + try: + info = ServiceInfo.from_file(info_file) + except (FileNotFoundError, json.JSONDecodeError, KeyError, TypeError) as e: + logger.warning(f"Skipping invalid or corrupted info file {info_file}: {e}. 
") + continue + + if ServiceSupervisorCommon._is_alive(info.pid, info.create_time): + logger.info(f"Terminating service {info.service_name} worker {i} with process with pid {info.pid}") + if ServiceSupervisorCommon._kill_service(info.pid): + info.state = ServiceState.DEAD.value + info.to_file(process_dir / ServiceSupervisorCommon.PROCESS_INFO_FILE) + else: + info.state = ServiceState.DEAD.value + info.to_file(process_dir / ServiceSupervisorCommon.PROCESS_INFO_FILE) + else: + assert_never(config) + + +class ServiceSupervisorProcess: + """ + The supervisor process that manages service processes. + + This class runs as the main daemon process and: + - Monitors child service processes via a health monitor thread + - Restarts failed services automatically + - Handles graceful shutdown on SIGTERM/SIGINT + """ + + def __init__(self, session_dir: Path, foreground: bool): + self._session_dir = session_dir + self._log_fd = None + + if not foreground: + self._daemonize() + + # Configure logger to write to supervisor log file + self._setup_logging() + + # TODO remove prints + print(f"[{time.ctime()}] Supervisor started (PID: {os.getpid()})") + self._save_supervisor_info() + + # Setup signal handlers for graceful shutdown + signal.signal(signal.SIGTERM, self._signal_handler) + signal.signal(signal.SIGINT, self._signal_handler) + + # Setup zombie reaper BEFORE starting any child processes + self._setup_child_reaper() + + # Start health monitor thread + self.monitor_thread = threading.Thread(target=self._health_monitor, daemon=False, name="HealthMonitor") + self.monitor_thread.start() + print("✓ Health monitor started") + + def _daemonize(self): + # TODO add stackoverflow why double fork needed to decouple from terminal session + """Double fork to become a daemon process.""" + # First fork + try: + pid = os.fork() + if pid > 0: + # Exit first parent + sys.exit(0) + except OSError as e: + sys.stderr.write(f"Fork #1 failed: {e}\n") + sys.exit(1) + + # Decouple from parent 
environment + os.chdir("/") + os.setsid() + os.umask(0) + + # Second fork + try: + pid = os.fork() + if pid > 0: + # Exit second parent + sys.exit(0) + except OSError as e: + sys.stderr.write(f"Fork #2 failed: {e}\n") + sys.exit(1) + + print(f"[{time.ctime()}] Daemon started (PID: {os.getpid()})") + + def _save_supervisor_info(self): + """Save the supervisor's PID and create time to a file for later reference.""" + pid = os.getpid() + supervisor_info = SupervisorInfo(pid, psutil.Process(pid).create_time()) + supervisor_info.to_file(self._session_dir / ServiceSupervisorCommon.SUPERVISOR_INFO_FILE) + + def _setup_logging(self): + """ + Configure logging to write to the supervisor log file. + + This redirects both: + 1. stdout/stderr (for print() statements) + 2. Python logging (for logger.* calls) + + to the supervisor log file. + """ + log_file = self._session_dir / ServiceSupervisorCommon.SUPERVISOR_LOG_FILE + + # Redirect stdout and stderr to log file + sys.stdout.flush() + sys.stderr.flush() + self._log_fd = open(log_file, 'a') + os.dup2(self._log_fd.fileno(), sys.stdout.fileno()) + os.dup2(self._log_fd.fileno(), sys.stderr.fileno()) + + # Configure Python logging to also write to the same file + file_handler = logging.FileHandler(log_file, mode='a') + file_handler.setLevel(logging.DEBUG) + + # Create formatter with timestamp + formatter = logging.Formatter( + fmt='[%(asctime)s] %(levelname)s: %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + file_handler.setFormatter(formatter) + + # Add handler to the logger + logger.addHandler(file_handler) + logger.setLevel(logging.DEBUG) + + logger.info("Logger configured to write to supervisor log file") + + def _setup_child_reaper(self): + """ + Setup SIGCHLD handler to automatically reap zombie processes. + + When child processes terminate, they become zombies until the parent + calls wait(). This handler automatically reaps terminated children + to prevent zombie accumulation. 
+ """ + def handle_sigchld(signum, frame): + """Non-blocking reaper for terminated child processes.""" + while True: + try: + # Reap any terminated children (non-blocking with WNOHANG) + pid, status = os.waitpid(-1, os.WNOHANG) + if pid == 0: # No more children to reap + break + logger.debug(f"Reaped child process {pid} with exit status {status}") + except ChildProcessError: + # No more children exist + break + except Exception as e: + # Log unexpected errors but don't crash + logger.warning(f"Error in SIGCHLD handler: {e}") + break + + # Register the handler + signal.signal(signal.SIGCHLD, handle_sigchld) + logger.debug("SIGCHLD handler installed for zombie reaping") + + def _check_service_process_health(self, config: ServiceConfig): + process_dir = self._session_dir / config.service_name + self._check_process_health(process_dir, config) + + def _check_worker_process_health(self, config: ServiceConfig, worker_num: int): + process_dir = self._session_dir / config.service_name / str(worker_num) + self._check_process_health(process_dir, config) + + @staticmethod + def _check_process_health(process_dir, config: ServiceConfig): + info_path = process_dir / ServiceSupervisorCommon.PROCESS_INFO_FILE + try: + info = ServiceInfo.from_file(info_path) + except (FileNotFoundError, json.JSONDecodeError, KeyError, TypeError) as e: + logger.warning(f"Skipping invalid or corrupted info file {info_path}: {e}. ") + return + + is_alive = ServiceSupervisorCommon._is_alive(info.pid, info.create_time) + info.last_check = time.time() + if is_alive: + info.last_check = time.time() + info.state = ServiceState.ALIVE.value + try: + info.to_file(info_path) + except: + logger.error(f"Unable to update service {info.service_name} info with PID {info.pid} after restart. 
Continuing.")
+        else:
+            # TODO add traceback into logger errors
+            logger.info(f"[{time.ctime(info.last_check)}] Service {info.service_name!r} with PID {info.pid} died")
+            info.state = ServiceState.DEAD.value
+            try:
+                info.to_file(info_path)
+            except Exception:
+                logger.error(f"Unable to update service {info.service_name} info with PID {info.pid} after it died. Continuing.")
+
+            logger.info(f"Restarting service {info.service_name!r}...")
+            try:
+                ServiceSupervisorCommon._start_process(process_dir, config)
+            except Exception:
+                logger.error(f"Unable to restart service {info.service_name}. Continuing.")
+                return
+            else:
+                logger.info(f"Restarting service {info.service_name!r} was successful.")
+
+    # TODO rename to keep_alive_monitor
+    def _health_monitor(self):
+        """Health monitor thread - checks child processes every 5s."""
+        logger.info(f"[{time.ctime()}] Health monitor thread started")
+        self.running = True
+        while self.running:
+            # TODO global config
+            time.sleep(5)
+
+            service_configs = ServiceConfigMap.from_file(self._session_dir / ServiceSupervisorCommon.SUPERVISOR_CONFIG_FILE)
+            for config in service_configs.values():
+                if isinstance(config, WorkerServiceConfig):
+                    for i in range(config.num_workers):
+                        self._check_worker_process_health(config, i)
+                elif isinstance(config, NonWorkerServiceConfig):
+                    self._check_service_process_health(config)
+                else:
+                    assert_never(config)
+        print(f"[{time.ctime()}] Health monitor thread shutting down...")
+        self._shutdown()
+
+    def _shutdown(self):
+        ServiceSupervisorCommon.stop(self._session_dir)
+
+        # Close log file descriptor before exiting
+        if self._log_fd is not None:
+            try:
+                self._log_fd.close()
+            except Exception as e:
+                # Best effort - don't fail shutdown
+                print(f"Warning: Failed to close log file descriptor: {e}")
+
+
+    def _signal_handler(self, signum, frame):
+        """Handle SIGTERM/SIGINT."""
+        print(f"\n[{time.ctime()}] Received signal {signum}, shutting down...")
+        self._shutdown()
+        sys.exit(0)
+
+class 
ServiceSupervisorController: + + # TODO move to common + class SessionDirUtils: + SESSION_DIR_TIMESTAMP_FORMAT = "%Y-%m-%d_%H-%M-%S" + SESSION_DIR_PATTERN = r'^\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}-\d{6}$' + + @staticmethod + def generate_dirname() -> str: + from datetime import datetime + + now = datetime.now() + dirname = now.strftime(ServiceSupervisorController.SessionDirUtils.SESSION_DIR_TIMESTAMP_FORMAT) + f"-{now.microsecond:06d}" + # NOTE: raise value for internal consistency + if not ServiceSupervisorController.SessionDirUtils.match_dirname(dirname): + raise RuntimeError(f"The created timestamp {dirname} does not match pattern {ServiceSupervisorController.SessionDirUtils.SESSION_DIR_PATTERN}. Please contact a developer.") + return dirname + + @staticmethod + def match_dirname(dirname: str) -> bool: + import re + # NOTE: The regex expression has to match the timestamp format + return bool(re.compile(ServiceSupervisorController.SessionDirUtils.SESSION_DIR_PATTERN).match(dirname)) + + + @staticmethod + def _is_running(session_dir: Path) -> bool: + """ + Check if daemon is running by checking PID file. + + Returns: + True if daemon is running, False otherwise + """ + if not session_dir.exists() or not session_dir.is_dir(): + return False + + supervisor_info_file = session_dir / ServiceSupervisorCommon.SUPERVISOR_INFO_FILE + if not supervisor_info_file.exists(): + return False + + try: + info = SupervisorInfo.from_file(supervisor_info_file) + except: + logger.warning(f"Could not read supervisor info file {supervisor_info_file}. 
Assuming it is not running.") + return False + else: + return ServiceSupervisorCommon._is_alive(info.pid, info.create_time) + + @staticmethod + def _create_new_session_dir(supervisor_dir: Path) -> Path: + timestamp = ServiceSupervisorController.SessionDirUtils.generate_dirname() + daemon_current_session_dir = supervisor_dir / timestamp + daemon_current_session_dir.mkdir(parents=False, exist_ok=False) + + return daemon_current_session_dir + + @staticmethod + def _get_latest_session_dir(supervisor_dir: Path) -> Path | None: + """ + Get the most recent daemon session directory based on timestamp. + + Scans the daemon base directory for session directories with + timestamp format YYYY-MM-DD_HH-MM-SS-mmmmmm and returns the most recent one. + + Returns: + Path to the latest session directory, or None if no sessions exist + """ + ServiceSupervisorController._validate_supervisor_dir(supervisor_dir) + + # Pattern to match timestamp directories: YYYY-MM-DD_HH-MM-SS-mmmmmm + # Example: 2025-11-30_16-46-25-324746 + + session_dirs = [] + + # Find all directories matching the timestamp pattern + for path in supervisor_dir.iterdir(): + if path.is_dir() and ServiceSupervisorController.SessionDirUtils.match_dirname(path.name): + session_dirs.append(path) + + # No session directories found + if not session_dirs: + return None + + # Sort by directory name (timestamp) - latest will be last + # Because YYYY-MM-DD_HH-MM-SS-mmmmmm is lexicographically sortable + session_dirs.sort(key=lambda p: p.name) + + return session_dirs[-1] # Return most recent + + @staticmethod + def _validate_supervisor_dir(supervisor_dir: Path): + # Check if base directory exists + if not supervisor_dir.exists(): + # TODO error msg + raise ValueError() + elif not supervisor_dir.is_dir(): + # TODO error msg + raise ValueError() + + @staticmethod + def start(supervisor_dir: Path, service_configs: ServiceConfigMap, foreground: bool = False): + """ + Start the daemon. 
+ + Args: + service_configs: Configuration for all services to manage + foreground: If True, run in foreground. If False, detach to run in background. + + Raises: + RuntimeError: If daemon is already running + """ + ServiceSupervisorController._validate_supervisor_dir(supervisor_dir) + latest_session_dir = ServiceSupervisorController._get_latest_session_dir(supervisor_dir) + # TODO check if service_configs have changed + if latest_session_dir is not None and ServiceSupervisorController._is_running(latest_session_dir): + logger.info("Daemon is already running, continue with last session. If you want to start with new settings please stop and start daemon.") + return + + session_dir = ServiceSupervisorController._create_new_session_dir(supervisor_dir) + + # Start all configured services + for config in service_configs.values(): + ServiceSupervisorController._start_service(session_dir, config) + + service_configs.to_file(session_dir / ServiceSupervisorCommon.SUPERVISOR_CONFIG_FILE) + ServiceSupervisorProcess(session_dir, foreground) + + @staticmethod + def _start_service(session_dir: Path, config: ServiceConfig): + if session_dir is None: + # TODO error message, should not happen so dev error + raise RuntimeError() + # Get config using base identifier + if isinstance(config, NonWorkerServiceConfig): + ServiceSupervisorCommon._start_service_process(session_dir, config) + elif isinstance(config, WorkerServiceConfig): + for i in range(config.num_workers): + ServiceSupervisorCommon._start_worker_service_process(session_dir, config, i) + + else: + assert_never(config) + + @staticmethod + def stop(supervisor_dir: Path): + ServiceSupervisorController._validate_supervisor_dir(supervisor_dir) + + session_dir = ServiceSupervisorController._get_latest_session_dir(supervisor_dir) + if session_dir is None: + raise ValueError(f"No session found in {supervisor_dir}") + + ServiceSupervisorCommon.stop(session_dir) + + @staticmethod + def status(supervisor_dir: Path) -> dict: + """ + 
Get status of all services by reading supervisor config and service info files. + + Reads the supervisor config file to discover all configured services, + then reads each service's info file to return current status information. + + Returns: + dict: Status dictionary with structure: + { + 'session': str, + 'supervisor': { + 'pid': int, + 'status': str, + 'started': float, + 'log': str + } or {'error': str} or None, + 'services': { + 'service-name': { + 'type': 'service', + 'command': str, + 'pid': int or None, + 'status': str, + 'started': float or None, + 'last_check': float or None, + 'failures': int or None, + 'stdout_log': str or None, + 'stderr_log': str or None, + 'error': str or None + }, + 'worker-service-name': { + 'type': 'worker', + 'command': str, + 'workers': { + 0: { + 'pid': int or None, + 'status': str, + 'started': float or None, + 'last_check': float or None, + 'failures': int or None, + 'stdout_log': str or None, + 'stderr_log': str or None, + 'error': str or None + }, + ... 
+ } + } + } + } + """ + ServiceSupervisorController._validate_supervisor_dir(supervisor_dir) + + session_dir = ServiceSupervisorController._get_latest_session_dir(supervisor_dir) + if session_dir is None: + raise ValueError(f"No session found in {supervisor_dir}") + + status_dict = { + 'session': session_dir.name, + 'supervisor': None, + 'services': {} + } + + # Load supervisor config to get all services + config_file = session_dir / ServiceSupervisorCommon.SUPERVISOR_CONFIG_FILE + if not config_file.exists(): + status_dict['error'] = f"No supervisor config file found at {config_file}" + return status_dict + + try: + service_configs = ServiceConfigMap.from_file(config_file) + except Exception as e: + status_dict['error'] = f"Error reading supervisor config: {e}" + return status_dict + + # Check supervisor status + supervisor_info_file = session_dir / ServiceSupervisorCommon.SUPERVISOR_INFO_FILE + if supervisor_info_file.exists(): + try: + supervisor_info = SupervisorInfo.from_file(supervisor_info_file) + is_alive = ServiceSupervisorCommon._is_alive(supervisor_info.pid, supervisor_info.create_time) + log_path = session_dir / ServiceSupervisorCommon.SUPERVISOR_LOG_FILE + status_dict['supervisor'] = { + 'pid': supervisor_info.pid, + 'status': 'RUNNING' if is_alive else 'STOPPED', + 'started': supervisor_info.create_time, + 'log': str(log_path) + } + except Exception as e: + status_dict['supervisor'] = {'error': f"Error reading info - {e}"} + else: + status_dict['supervisor'] = {'error': "No info file found"} + + # Iterate through all configured services + for service_identifier, config in service_configs.items(): + if isinstance(config, NonWorkerServiceConfig): + # Non-worker service (single instance) + service_dir = session_dir / config.service_name + info_file = service_dir / ServiceSupervisorCommon.PROCESS_INFO_FILE + + service_entry = { + 'type': 'service', + 'command': config.command, + 'pid': None, + 'status': None, + 'started': None, + 'last_check': None, + 
'failures': None, + 'stdout_log': None, + 'stderr_log': None, + 'error': None + } + + if info_file.exists(): + try: + info = ServiceInfo.from_file(info_file) + is_alive = ServiceSupervisorCommon._is_alive(info.pid, info.create_time) + service_entry.update({ + 'pid': info.pid, + 'status': 'ALIVE' if is_alive else 'DEAD', + 'started': info.create_time, + 'last_check': info.last_check, + 'failures': info.failures, + 'stdout_log': str(service_dir / 'stdout.log'), + 'stderr_log': str(service_dir / 'stderr.log') + }) + except Exception as e: + service_entry['error'] = f"Could not read info file - {e}" + else: + service_entry['error'] = "No info file found" + + status_dict['services'][config.service_name] = service_entry + + elif isinstance(config, WorkerServiceConfig): + # Worker service (multiple instances) + service_entry = { + 'type': 'worker', + 'command': config.command, + 'workers': {} + } + + for worker_num in range(config.num_workers): + worker_dir = session_dir / config.service_name / str(worker_num) + info_file = worker_dir / ServiceSupervisorCommon.PROCESS_INFO_FILE + + worker_entry = { + 'pid': None, + 'status': None, + 'started': None, + 'last_check': None, + 'failures': None, + 'stdout_log': None, + 'stderr_log': None, + 'error': None + } + + if info_file.exists(): + try: + info = ServiceInfo.from_file(info_file) + is_alive = ServiceSupervisorCommon._is_alive(info.pid, info.create_time) + worker_entry.update({ + 'pid': info.pid, + 'status': 'ALIVE' if is_alive else 'DEAD', + 'started': info.create_time, + 'last_check': info.last_check, + 'failures': info.failures, + 'stdout_log': str(worker_dir / 'stdout.log'), + 'stderr_log': str(worker_dir / 'stderr.log') + }) + except Exception as e: + worker_entry['error'] = f"Could not read info file - {e}" + else: + worker_entry['error'] = "No info file found" + + service_entry['workers'][worker_num] = worker_entry + + status_dict['services'][config.service_name] = service_entry + else: + assert_never(config) + + 
return status_dict + + + @staticmethod + def get_service_configs(supervisor_dir: Path) -> ServiceConfigMap | None: + session_dir = ServiceSupervisorController._get_latest_session_dir(supervisor_dir) + if session_dir is None: + return None + config_file = session_dir / ServiceSupervisorCommon.SUPERVISOR_CONFIG_FILE + try: + return ServiceConfigMap.from_file(config_file) + except Exception as e: + # TODO addd full traceback + logger.warning(f"Could not read supervisor config file {config_file}, due to exception: {e}") diff --git a/src/airflow_provider_aiida/aiida_core/engine/daemon/airflow_daemon.py b/src/airflow_provider_aiida/aiida_core/engine/daemon/airflow_daemon.py new file mode 100644 index 0000000..1f87ae3 --- /dev/null +++ b/src/airflow_provider_aiida/aiida_core/engine/daemon/airflow_daemon.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from airflow_provider_aiida.aiida_core.engine.daemon._supervisor import ( + ServiceSupervisorController, + NonWorkerServiceConfig, + ServiceConfigFactory, + ServiceConfigMap + ) +from pathlib import Path +from dataclasses import dataclass +from typing import ClassVar, TYPE_CHECKING +import os + +from airflow_provider_aiida.aiida_core.manage.configuration.config import get_airflow_home + +if TYPE_CHECKING: + from aiida.manage.configuration import Profile + from aiida.manage.configuration.config import Config + +def get_daemon_dir(profile: Profile, config: Config): + from aiida.manage.configuration.settings import AiiDAConfigPathResolver + config_path_resolver: AiiDAConfigPathResolver = AiiDAConfigPathResolver(Path(config.dirpath)) + daemon_dir = config_path_resolver.daemon_dir + return daemon_dir / f"{profile.name}" + +@dataclass +class AirflowDagProcessorServiceConfig(NonWorkerServiceConfig): + service_name: ClassVar[str] = "airflow-dag-processor" + command: ClassVar[str] = "airflow dag-processor" + airflow_home: str + + def create_unique_env(self) -> dict[str, str]: + return {'AIRFLOW_HOME': self.airflow_home} + 
+@dataclass
+class AirflowSchedulerServiceConfig(NonWorkerServiceConfig):
+    # We do not want any limit on this
+    service_name: ClassVar[str] = "airflow-scheduler"
+    command: ClassVar[str] = "airflow scheduler"
+    airflow_home: str
+    num_workers: int
+
+
+    def create_unique_env(self) -> dict[str, str]:
+        return {'AIRFLOW_HOME': self.airflow_home,
+                'AIRFLOW__CORE__PARALLELISM': str(self.num_workers)}
+
+@dataclass
+class AirflowTriggererServiceConfig(NonWorkerServiceConfig):
+    # We do not want any limit on this
+    service_name: ClassVar[str] = "airflow-triggerer"
+    command: ClassVar[str] = "airflow-provider-aiida-triggerer-service"
+    airflow_home: str
+    num_triggerers: int
+
+
+    def create_unique_env(self) -> dict[str, str]:
+        return {
+            'AIRFLOW_HOME': self.airflow_home,
+            'AIRFLOW__CORE__ASYNC_PARALLELISM': str(self.num_triggerers)}
+
+@dataclass
+class AirflowApiServerServiceConfig(NonWorkerServiceConfig):
+    # We do not want any limit on this
+    service_name: ClassVar[str] = "airflow-api-server"
+    command: ClassVar[str] = "airflow api-server"
+    airflow_home: str
+
+    def create_unique_env(self) -> dict[str, str]:
+        return {'AIRFLOW_HOME': self.airflow_home}
+
+class AirflowDaemon:
+
+    def __init__(self, profile_identifier):
+        from aiida.manage import get_manager
+        manager = get_manager()
+        profile = manager.load_profile() if profile_identifier is None else manager.load_profile(profile_identifier)
+
+        # Validate profile storage backend
+        if profile.storage_backend != 'core.psql_dos':
+            raise ValueError(
+                f"Profile '{profile.name}' uses unsupported storage backend '{profile.storage_backend}'. "
+                f"Only 'core.psql_dos' (PostgreSQL) is supported."
+            )
+        self._daemon_dir = get_daemon_dir(profile, manager.get_config())
+        self._daemon_dir.mkdir(exist_ok=True)
+
+        self._airflow_home = get_airflow_home(profile)
+
+    def start(self, num_workers: int, num_triggerers: int, foreground: bool):
+        scheduler_config = AirflowSchedulerServiceConfig(num_workers=num_workers, airflow_home=str(self._airflow_home))
+        dag_processor_config = AirflowDagProcessorServiceConfig(airflow_home=str(self._airflow_home))
+        api_server_config = AirflowApiServerServiceConfig(airflow_home=str(self._airflow_home))
+        triggerer_config = AirflowTriggererServiceConfig(num_triggerers=num_triggerers, airflow_home=str(self._airflow_home))
+
+        service_configs = ServiceConfigMap([scheduler_config, dag_processor_config, api_server_config, triggerer_config])
+        ServiceSupervisorController.start(self._daemon_dir, service_configs, foreground)
+
+    def stop(self):
+        ServiceSupervisorController.stop(self._daemon_dir)
+
+    def status(self) -> dict:
+        status_report = ServiceSupervisorController.status(self._daemon_dir)
+        if (configs := ServiceSupervisorController.get_service_configs(self._daemon_dir)) is not None:
+            for config in configs.values():
+                if isinstance(config, AirflowSchedulerServiceConfig):
+                    status_report['services'][f'{config.service_name}']['num_workers'] = config.num_workers
+                elif isinstance(config, AirflowTriggererServiceConfig):
+                    status_report['services'][f'{config.service_name}']['num_workers'] = config.num_triggerers
+
+        return status_report
+        # NOTE(review): unreachable leftover debug code removed here (was: breakpoint())
+        # NOTE(review): unreachable duplicate return removed (comments keep hunk line count intact)
diff --git a/src/airflow_provider_aiida/aiida_core/engine/daemon/triggerer_service.py b/src/airflow_provider_aiida/aiida_core/engine/daemon/triggerer_service.py
new file mode 100644
index 0000000..7855aac
--- /dev/null
+++ b/src/airflow_provider_aiida/aiida_core/engine/daemon/triggerer_service.py
@@ -0,0 +1,197 @@
+import signal
+import subprocess
+import sys
+import time
+from pathlib import Path
+
+
+class TriggererService:
+    """Manages multiple Airflow triggerer 
processes.""" + + def __init__(self, num_triggerers: int, airflow_home: str | None = None): + self.num_triggerers = num_triggerers + self.airflow_home = airflow_home + self.processes: list[subprocess.Popen] = [] + self.base_port = self._get_base_port() + + def _get_base_port(self) -> int: + """Get the base trigger log server port from airflow.cfg.""" + from airflow.configuration import AirflowConfigParser + + config = AirflowConfigParser() + + if self.airflow_home: + config_file = Path(self.airflow_home) / 'airflow.cfg' + if config_file.exists(): + config.read(str(config_file)) + + # Get port from config, default to 8794 + return config.getint('logging', 'trigger_log_server_port', fallback=8794) + + def _start_triggerer(self, worker_num: int) -> subprocess.Popen: + """Start a single triggerer process with a unique port.""" + # TODO check if port available if not add +1 + port = self.base_port + worker_num + + # Build environment with unique port + import os + env = os.environ.copy() + env['AIRFLOW__LOGGING__TRIGGER_LOG_SERVER_PORT'] = str(port) + + if self.airflow_home: + env['AIRFLOW_HOME'] = self.airflow_home + + print(f"Starting triggerer #{worker_num} on port {port}") + + # Start triggerer process - output goes to terminal stdout/stderr + process = subprocess.Popen( + ['airflow', 'triggerer'], + env=os.environ | env, + stdout=sys.stdout, + stderr=sys.stderr, + text=True + ) + + return process + + def start(self): + """Start all triggerer processes.""" + print(f"Starting {self.num_triggerers} triggerer(s)") + print(f"Base port: {self.base_port}") + print() + + # Start all triggerers + for i in range(self.num_triggerers): + # TODO check if port available if not add +1 + process = self._start_triggerer(i) + self.processes.append(process) + # Small delay to avoid startup race conditions + time.sleep(0.5) + + print() + print(f"Successfully started {len(self.processes)} triggerer(s)") + print("Press Ctrl+C to stop all triggerers") + + def stop(self): + """Stop all 
triggerer processes.""" + print("\nStopping all triggerers...") + + for i, process in enumerate(self.processes): + if process.poll() is None: # Still running + print(f"Stopping triggerer #{i} (PID {process.pid})") + process.terminate() + + # Wait for graceful shutdown (max 10 seconds) + for process in self.processes: + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + print(f"Force killing triggerer (PID {process.pid})") + process.kill() + + print("All triggerers stopped") + + def monitor(self): + """Monitor triggerer processes and restart if they crash.""" + try: + while True: + # Check if any process has died + for i, process in enumerate(self.processes): + if process.poll() is not None: # Process has exited + returncode = process.returncode + print(f"\nTriggerer #{i} died with exit code {returncode}") + print(f"Restarting triggerer #{i}...") + + # Restart the process + new_process = self._start_triggerer(i) + self.processes[i] = new_process + + # Sleep before next check + time.sleep(5) + + except KeyboardInterrupt: + print("\nReceived interrupt signal") + self.stop() + + def run(self): + """Run the supervisor - start processes and monitor them.""" + # Setup signal handlers + signal.signal(signal.SIGTERM, self._signal_handler) + signal.signal(signal.SIGINT, self._signal_handler) + + # Start all triggerers + self.start() + + # Monitor them + self.monitor() + + def _signal_handler(self, signum, frame): + """Handle termination signals.""" + print(f"\nReceived signal {signum}") + self.stop() + sys.exit(0) + + + + +def main(): + """Main entry point.""" + from airflow_provider_aiida.aiida_core import load_profile + load_profile() + + import argparse + import os + parser = argparse.ArgumentParser( + description='Supervise multiple Airflow triggerer processes', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Start 3 triggerers with default airflow home + python triggerer_supervisor.py 3 + + # Start 5 triggerers with 
custom airflow home + python triggerer_supervisor.py 5 --airflow-home /path/to/airflow + + # Start 2 triggerers (they will use ports 8794 and 8795 by default) + python triggerer_supervisor.py 2 + +The script will: + 1. Read the base port from airflow.cfg (default: 8794) + 2. Start N triggerers on ports: base_port, base_port+1, base_port+2, ... + 3. Redirect all output to logs/triggerers_stdout.log and logs/triggerers_stderr.log + 4. Monitor processes and restart them if they crash + 5. Gracefully stop all triggerers when receiving Ctrl+C or SIGTERM + """ + ) + + parser.add_argument( + 'num_triggerers', + type=int, + nargs='?', + default=None, + help='Number of triggerer processes to start (default: from AIRFLOW__CORE__ASYNC_PARALLELISM env var, or 1)' + ) + + args = parser.parse_args() + + # Determine num_triggerers: CLI arg > env var > default (1) + if args.num_triggerers is not None: + num_triggerers = args.num_triggerers + elif (env_value := os.environ.get("AIRFLOW__CORE__ASYNC_PARALLELISM")) is not None: + num_triggerers = int(env_value) + else: + raise ValueError("Value for number of triggerers was not provided and AIRFLOW__CORE__ASYNC_PARALLELISM is not set.") + + if num_triggerers < 1: + parser.error("Number of triggerers must be at least 1") + + # Create and run supervisor + supervisor = TriggererService( + num_triggerers=num_triggerers, + airflow_home=os.environ.get("AIRFLOW_HOME"), + ) + + supervisor.run() + +if __name__ == '__main__': + main()