diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py
index e3215720a..a191d931a 100644
--- a/src/forge/controller/__init__.py
+++ b/src/forge/controller/__init__.py
@@ -6,11 +6,10 @@
 from .actor import ForgeActor
 from .proc_mesh import get_proc_mesh, spawn_actors
 from .recoverable_mesh import RecoverableProcMesh
-from .service import AutoscalingConfig, Service, ServiceConfig
+from .service import Service, ServiceConfig
 from .spawn import spawn_service

 __all__ = [
-    "AutoscalingConfig",
     "Service",
     "ServiceConfig",
     "spawn_service",
diff --git a/src/forge/controller/service.md b/src/forge/controller/service.md
index 1b6cb558a..3a8134fa3 100644
--- a/src/forge/controller/service.md
+++ b/src/forge/controller/service.md
@@ -1,19 +1,13 @@
 # Service - Distributed Actor Service Controller

-A robust service orchestration system for managing distributed actor-based workloads with automatic scaling, fault tolerance, and intelligent load balancing.
+A robust service orchestration system for managing distributed actor-based workloads with fault tolerance and intelligent load balancing.

 ## Overview

-The Service class provides a unified interface for deploying and managing multiple replicas of actor-based services across distributed compute resources. It automatically handles replica lifecycle, request routing, session management, and resource scaling based on real-time metrics.
+The Service class provides a unified interface for deploying and managing multiple replicas of actor-based services across distributed compute resources. It automatically handles replica lifecycle, request routing, and session management.

 ## Key Features

-### **Automatic Scaling**
-- **Scale Up**: Automatically adds replicas based on queue depth, capacity utilization, and request rate
-- **Scale Down**: Intelligently removes underutilized replicas while preserving workload
-- **Emergency Scaling**: Rapid scale-up during traffic spikes
-- **Configurable Thresholds**: Customizable scaling triggers and cooldown periods
-
 ### **Fault Tolerance**
 - **Health Monitoring**: Continuous health checks with automatic replica recovery
 - **Request Migration**: Seamless migration of requests from failed replicas
@@ -61,7 +55,7 @@ The Service class provides a unified interface for deploying and managing multip
 ### Basic Service Setup

 ```python
-from forge.controller.service import Service, ServiceConfig, AutoscalingConfig
+from forge.controller.service import Service, ServiceConfig

 # Configure service parameters
 config = ServiceConfig(
@@ -70,11 +64,6 @@ config = ServiceConfig(
     max_replicas=10,
     default_replicas=3,
     replica_max_concurrent_requests=10,
-    autoscaling=AutoscalingConfig(
-        scale_up_queue_depth_threshold=5.0,
-        scale_up_capacity_threshold=0.8,
-        scale_down_capacity_threshold=0.3
-    )
 )

 # Create service with your actor definition
diff --git a/src/forge/controller/service.py b/src/forge/controller/service.py
index 88808d724..2b41b05b1 100644
--- a/src/forge/controller/service.py
+++ b/src/forge/controller/service.py
@@ -21,9 +21,7 @@
     >>> config = ServiceConfig(
     ...     gpus_per_replica=1,
-    ...     min_replicas=2,
-    ...     max_replicas=10,
-    ...     default_replicas=3
+    ...     num_replicas=3
     ... )
     >>> service = Service(config, MyActorClass, *args, **kwargs)
     >>> await service.__initialize__()
@@ -40,7 +38,6 @@
 import logging
 import pprint
 import time
-import traceback
 import uuid
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
@@ -52,12 +49,7 @@
 from forge.controller import RecoverableProcMesh

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-# Global context variable for session state
-_session_context: contextvars.ContextVar[dict | None] = contextvars.ContextVar(
-    "session_context", default=None
-)
+logger.setLevel(logging.DEBUG)


 # TODO - tie this into metric logger when it exists
@@ -195,65 +187,16 @@ def get_sessions_per_replica(self) -> float:
         return self.total_sessions / self.healthy_replicas


-@dataclass
-class AutoscalingConfig:
-    """Configuration for autoscaling behavior."""
-
-    # Autoscaling control
-    enabled: bool = False  # Whether autoscaling is enabled (disabled by default)
-
-    # Scale up thresholds
-    scale_up_queue_depth_threshold: float = (
-        5.0  # Average queue depth to trigger scale up
-    )
-    scale_up_capacity_threshold: float = 0.8  # Capacity utilization to trigger scale up
-    scale_up_request_rate_threshold: float = 10.0  # Requests/sec to trigger scale up
-
-    # Scale down thresholds
-    scale_down_capacity_threshold: float = (
-        0.3  # Capacity utilization to trigger scale down
-    )
-    scale_down_queue_depth_threshold: float = (
-        1.0  # Average queue depth to trigger scale down
-    )
-    scale_down_idle_time_threshold: float = (
-        300.0  # Seconds of low utilization before scale down
-    )
-
-    # Timing controls
-    min_time_between_scale_events: float = (
-        60.0  # Minimum seconds between scaling events
-    )
-    scale_up_cooldown: float = 30.0  # Cooldown after scale up
-    scale_down_cooldown: float = 120.0  # Cooldown after scale down
-
-    # Scaling behavior
-    scale_up_step_size: int = 1  # How many replicas to add at once
-    scale_down_step_size: int = 1  # How many replicas to remove at once
-
-    # Safety limits
-    max_queue_depth_emergency: float = 20.0  # Emergency scale up threshold
-    min_healthy_replicas_ratio: float = 0.5  # Minimum ratio of healthy replicas
-
-
 @dataclass
 class ServiceConfig:
     procs_per_replica: int
-    min_replicas: int
-    max_replicas: int
-    default_replicas: int
-    autoscaling: AutoscalingConfig = field(default_factory=AutoscalingConfig)
+    num_replicas: int
     health_poll_rate: float = 0.2
     replica_max_concurrent_requests: int = 10
     return_first_rank_result: bool = (
         True  # Auto-unwrap ValueMesh to first rank's result
     )

-    def validate(self):
-        assert self.min_replicas <= self.max_replicas
-        assert self.min_replicas <= self.default_replicas
-        assert self.default_replicas <= self.max_replicas
-

 @dataclass
 class Replica:
@@ -272,6 +215,13 @@ class Session:
     session_id: str


+# Global context variable for session state
+# This is used to propagate session state across async tasks
+_session_context: contextvars.ContextVar[dict | None] = contextvars.ContextVar(
+    "session_context", default=None
+)
+
+
 class SessionContext:
     """Context manager for service sessions using context variables."""
@@ -307,13 +257,12 @@ class Service:
     The Service acts as a unified interface for distributed workloads, automatically handling:
     - **Fault Tolerance**: Health monitoring, automatic replica recovery, request migration
-    - **Autoscaling**: Dynamic scaling based on queue depth, capacity, and request rate
     - **Load Balancing**: Round-robin, least-loaded, and session-affinity routing
     - **Session Management**: Stateful session handling with context propagation
     - **Metrics Collection**: Comprehensive performance and health monitoring

     Args:
-        cfg: Service configuration including scaling limits and autoscaling parameters
+        cfg: Service configuration including the number of replicas, processes per replica, and health polling rate
         actor_def: Actor class definition to instantiate on each replica
         *actor_args: Positional arguments passed to actor constructor
         **actor_kwargs: Keyword arguments passed to actor constructor
@@ -323,13 +272,7 @@ class Service:
         >>> config = ServiceConfig(
         ...     gpus_per_replica=1,
-        ...     min_replicas=2,
-        ...     max_replicas=10,
-        ...     default_replicas=3,
-        ...     autoscaling=AutoscalingConfig(
-        ...         scale_up_capacity_threshold=0.8,
-        ...         scale_down_capacity_threshold=0.3
-        ...     )
+        ...     num_replicas=3,
         ... )
         >>> service = Service(config, MyActorClass, model_path="/path/to/model")
         >>> await service.__initialize__()
@@ -354,8 +297,6 @@ def __init__(self, cfg: ServiceConfig, actor_def, *actor_args, **actor_kwargs):
         self._cfg = cfg
-        self._cfg.validate()
-
         self._replicas = []
         self._actor_def = actor_def
         self._actor_args = actor_args
@@ -390,14 +331,41 @@ def __init__(self, cfg: ServiceConfig, actor_def, *actor_args, **actor_kwargs):
             self._add_endpoint_method(func_name)

     async def __initialize__(self):
-        await self._scale_up(self._cfg.default_replicas)
+        logger.debug("Starting service up with %d replicas.", self._cfg.num_replicas)
+        replicas = []
+        num_replicas = self._cfg.num_replicas
+        for i in range(num_replicas):
+            mesh = RecoverableProcMesh(
+                self._cfg.procs_per_replica,
+            )
+            replica = Replica(
+                proc_mesh=mesh,
+                actor=None,
+                idx=len(self._replicas) + i,
+                max_concurrent_requests=self._cfg.replica_max_concurrent_requests,
+            )
+            replicas.append(replica)
+
+        # Initializing should only happen in the health_loop
+        # and during the first initialization.
+        # If multiple parts of the code try to initialize replicas at
+        # the same time, it can cause nasty race conditions
+        # (e.g., double initialization, inconsistent state, or resource conflicts).
+        # By funneling all replica initialization through a single queue and the
+        # health loop, we ensure safe, serialized initialization.
+        logger.debug(
+            "Queued %d replicas for initialization. Total replicas: %d",
Total replicas: %d", + num_replicas, + len(self._replicas), + ) + self._replicas_to_init.extend(replicas) await self._maybe_init_replicas() + self._replicas.extend(replicas) # Start the health loop in the background self._health_task = asyncio.create_task( self._health_loop(poll_rate_s=self._cfg.health_poll_rate) ) - # call setup on all replicas, if it exists def _add_endpoint_method(self, endpoint_name: str): """Dynamically adds an endpoint method to this Service instance.""" @@ -748,173 +716,6 @@ def get_metrics_summary(self) -> dict: return summary - def _should_scale_up(self) -> tuple[bool, str]: - """Determines if the service should scale up and why.""" - # Skip if shutdown is requested - if self._shutdown_requested: - return False, "Service shutdown requested" - - now = time.time() - cfg = self._cfg.autoscaling - - # Check cooldown periods - if now - self._last_scale_up_time < cfg.scale_up_cooldown: - return False, f"Scale up cooldown active ({cfg.scale_up_cooldown}s)" - - if ( - now - max(self._last_scale_up_time, self._last_scale_down_time) - < cfg.min_time_between_scale_events - ): - return ( - False, - f"Minimum time between scale events not met ({cfg.min_time_between_scale_events}s)", - ) - - # Check if we're already at max replicas - if len(self._replicas) >= self._cfg.max_replicas: - return False, f"Already at max replicas ({self._cfg.max_replicas})" - - # Get current metrics - self._update_service_metrics() - avg_queue_depth = self._metrics.get_avg_queue_depth() - avg_capacity = self._metrics.get_avg_capacity_utilization(self._replicas) - request_rate = self._metrics.get_total_request_rate() - - # Emergency scale up - very high queue depth - if avg_queue_depth >= cfg.max_queue_depth_emergency: - return ( - True, - f"Emergency scale up: queue depth {avg_queue_depth:.1f} >= {cfg.max_queue_depth_emergency}", - ) - - # Scale up conditions - reasons = [] - - if avg_queue_depth >= cfg.scale_up_queue_depth_threshold: - reasons.append( - f"queue depth {avg_queue_depth:.1f} >= {cfg.scale_up_queue_depth_threshold}" - ) - - if avg_capacity >= cfg.scale_up_capacity_threshold: - reasons.append( - f"capacity utilization {avg_capacity:.2f} >= {cfg.scale_up_capacity_threshold}" - ) - - if request_rate >= cfg.scale_up_request_rate_threshold and avg_queue_depth > 0: - reasons.append( - f"high request rate {request_rate:.1f} req/s with queued requests" - ) - - # Need at least one strong signal to scale up - if reasons: - return True, f"Scale up triggered: {', '.join(reasons)}" - - return False, "No scale up conditions met" - - def _should_scale_down(self) -> tuple[bool, str]: - """Determines if the service should scale down and why.""" - now = time.time() - cfg = self._cfg.autoscaling - - # Check cooldown periods - if now - self._last_scale_down_time < cfg.scale_down_cooldown: - return False, f"Scale down cooldown active ({cfg.scale_down_cooldown}s)" - - if ( - now - max(self._last_scale_up_time, self._last_scale_down_time) - < cfg.min_time_between_scale_events - ): - return ( - False, - f"Minimum time between scale events not met ({cfg.min_time_between_scale_events}s)", - ) - - # Check if we're already at min replicas - if len(self._replicas) <= self._cfg.min_replicas: - return False, f"Already at min replicas ({self._cfg.min_replicas})" - - # Check minimum healthy replicas ratio - healthy_replicas = [r for r in self._replicas if r.proc_mesh.healthy] - healthy_count = len(healthy_replicas) - if healthy_count <= len(self._replicas) * cfg.min_healthy_replicas_ratio: - return ( - False, - f"Too 
-            )
-
-        # Get current metrics
-        self._update_service_metrics()
-        avg_queue_depth = self._metrics.get_avg_queue_depth()
-        avg_capacity = self._metrics.get_avg_capacity_utilization(healthy_replicas)
-
-        # Check if conditions are met for scale down
-        low_utilization = (
-            avg_capacity <= cfg.scale_down_capacity_threshold
-            and avg_queue_depth <= cfg.scale_down_queue_depth_threshold
-        )
-
-        if low_utilization:
-            # Track how long we've been in low utilization
-            if self._low_utilization_start_time is None:
-                self._low_utilization_start_time = now
-                return False, "Low utilization detected, starting timer"
-
-            # Check if we've been in low utilization long enough
-            low_util_duration = now - self._low_utilization_start_time
-            if low_util_duration >= cfg.scale_down_idle_time_threshold:
-                return (
-                    True,
-                    f"Scale down: low utilization for {low_util_duration:.1f}s "
-                    f"(capacity: {avg_capacity:.2f}, queue: {avg_queue_depth:.1f})",
-                )
-            else:
-                return (
-                    False,
-                    f"Low utilization for {low_util_duration:.1f}s, need {cfg.scale_down_idle_time_threshold}s",
-                )
-        else:
-            # Reset low utilization timer
-            self._low_utilization_start_time = None
-            return (
-                False,
-                f"Utilization too high for scale down (capacity: {avg_capacity:.2f}, queue: {avg_queue_depth:.1f})",
-            )
-
-    async def _execute_autoscaling(self):
-        """Executes autoscaling decisions based on current metrics."""
-        # Skip autoscaling if shutdown is requested
-        if self._shutdown_requested:
-            return
-
-        # Skip autoscaling if disabled
-        if not self._cfg.autoscaling.enabled:
-            return
-
-        # Check scale up first (higher priority)
-        should_scale_up, scale_up_reason = self._should_scale_up()
-        if should_scale_up:
-            logger.debug("🔼 AUTOSCALING: %s", scale_up_reason)
-            await self._scale_up(self._cfg.autoscaling.scale_up_step_size)
-            self._last_scale_up_time = time.time()
-            self._metrics.last_scale_event = self._last_scale_up_time
-            return
-
-        # Check scale down
-        should_scale_down, scale_down_reason = self._should_scale_down()
-        if should_scale_down:
-            logger.debug("🔽 AUTOSCALING: %s", scale_down_reason)
-            await self._scale_down_replicas(self._cfg.autoscaling.scale_down_step_size)
-            self._last_scale_down_time = time.time()
-            self._metrics.last_scale_event = self._last_scale_down_time
-            return
-
-        # Log why we're not scaling (for debugging)
-        logger.debug(
-            "No autoscaling action: Scale up - %s, Scale down - %s",
-            scale_up_reason,
-            scale_down_reason,
-        )
-
     async def terminate_session(self, sess_id: str):
         """
         Terminates an active session and cleans up associated resources.
@@ -946,6 +747,13 @@ async def terminate_session(self, sess_id: str):
         self._update_service_metrics()

     async def _health_loop(self, poll_rate_s: float):
+        """Runs the health loop to monitor and recover replicas.
+
+        This loop continuously checks the health of replicas and recovers
+        failed replicas by reinitializing their proc_meshes. It also
+        periodically updates service metrics to reflect the current state.
+ + """ while not self._shutdown_requested: # Process any replicas that need initialization await self._maybe_init_replicas() @@ -964,16 +772,6 @@ async def _health_loop(self, poll_rate_s: float): ) self._replicas_to_init.extend(failed_replicas) - # Execute autoscaling logic - try: - await self._execute_autoscaling() - except Exception as e: - logger.error( - "Error in autoscaling: %s\nTraceback:\n%s", - e, - traceback.format_exc(), - ) - await asyncio.sleep(poll_rate_s) async def _custom_replica_routing( diff --git a/tests/test_service.py b/tests/test_service.py index ab83d301e..7283aeeee 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -12,7 +12,7 @@ import logging import pytest -from forge.controller.service import AutoscalingConfig, ServiceConfig +from forge.controller.service import ServiceConfig from forge.controller.spawn import spawn_service from monarch.actor import Actor, endpoint @@ -43,7 +43,7 @@ async def fail_me(self): @endpoint async def slow_incr(self): - """Slow increment to test queueing and autoscaling.""" + """Slow increment to test queueing.""" await asyncio.sleep(1.0) self.v += 1 @@ -55,9 +55,7 @@ async def slow_incr(self): @pytest.mark.asyncio async def test_basic_service_operations(): """Test basic service creation, sessions, and endpoint calls.""" - cfg = ServiceConfig( - procs_per_replica=1, min_replicas=1, max_replicas=2, default_replicas=1 - ) + cfg = ServiceConfig(procs_per_replica=1, num_replicas=1) service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) try: @@ -87,9 +85,7 @@ async def test_basic_service_operations(): @pytest.mark.asyncio async def test_sessionless_calls(): """Test sessionless calls with round-robin load balancing.""" - cfg = ServiceConfig( - procs_per_replica=1, min_replicas=2, max_replicas=2, default_replicas=2 - ) + cfg = ServiceConfig(procs_per_replica=1, num_replicas=2) service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) try: @@ -119,9 +115,7 @@ async def test_sessionless_calls(): @pytest.mark.asyncio async def test_session_context_manager(): """Test session context manager functionality.""" - cfg = ServiceConfig( - procs_per_replica=1, min_replicas=1, max_replicas=1, default_replicas=1 - ) + cfg = ServiceConfig(procs_per_replica=1, num_replicas=1) service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) try: @@ -162,9 +156,7 @@ async def worker(increments: int): @pytest.mark.asyncio async def test_replica_failure_and_recovery(): """Test replica failure handling and automatic recovery.""" - cfg = ServiceConfig( - procs_per_replica=1, min_replicas=2, max_replicas=2, default_replicas=2 - ) + cfg = ServiceConfig(procs_per_replica=1, num_replicas=2) service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) try: @@ -197,157 +189,6 @@ async def test_replica_failure_and_recovery(): await service.stop() -# Autoscaling Tests - - -@pytest.mark.timeout(20) -@pytest.mark.asyncio -async def test_autoscaling_scale_up(): - """Test automatic scale up under high load.""" - autoscaling_cfg = AutoscalingConfig( - enabled=True, - scale_up_capacity_threshold=0.5, - scale_up_queue_depth_threshold=2.0, - scale_up_cooldown=0.5, - min_time_between_scale_events=0.5, - ) - - cfg = ServiceConfig( - procs_per_replica=1, - min_replicas=1, - max_replicas=3, - default_replicas=1, - autoscaling=autoscaling_cfg, - replica_max_concurrent_requests=1, # Force queueing - ) - service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) - - try: - initial_replica_count = 
-        assert initial_replica_count == 1
-
-        # Create high load with slow operations
-        sessions = [await service.start_session() for _ in range(3)]
-        tasks = [service.slow_incr(session) for session in sessions for _ in range(2)]
-
-        # Start tasks and wait for scaling
-        await asyncio.gather(*tasks)
-
-        # Check if scaling occurred
-        scaled_up = False
-        for _ in range(10):
-            await asyncio.sleep(0.5)
-            if len(service._replicas) > initial_replica_count:
-                scaled_up = True
-                break
-
-        assert (
-            scaled_up
-        ), f"Expected scale up, but replicas remained at {len(service._replicas)}"
-
-    finally:
-        await service.stop()
-
-
-@pytest.mark.timeout(25)
-@pytest.mark.asyncio
-async def test_autoscaling_scale_down():
-    """Test automatic scale down when idle."""
-    autoscaling_cfg = AutoscalingConfig(
-        enabled=True,
-        scale_down_capacity_threshold=0.2,
-        scale_down_queue_depth_threshold=0.5,
-        scale_down_idle_time_threshold=3.0,
-        scale_down_cooldown=1.0,
-        min_time_between_scale_events=1.0,
-    )
-
-    cfg = ServiceConfig(
-        procs_per_replica=1,
-        min_replicas=1,
-        max_replicas=3,
-        default_replicas=2,  # Start with 2 replicas
-        autoscaling=autoscaling_cfg,
-    )
-    service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0)
-
-    try:
-        initial_replica_count = len(service._replicas)
-        assert initial_replica_count == 2
-
-        # Make minimal requests to establish baseline
-        session = await service.start_session()
-        await service.incr(session)
-
-        # Wait for scale down
-        max_wait = 10.0
-        waited = 0.0
-        scaled_down = False
-
-        while waited < max_wait:
-            await asyncio.sleep(1.0)
-            waited += 1.0
-            if len(service._replicas) < initial_replica_count:
-                scaled_down = True
-                break
-
-        assert (
-            scaled_down
-        ), f"Expected scale down, but replicas remained at {len(service._replicas)}"
-        assert len(service._replicas) >= cfg.min_replicas
-
-    finally:
-        await service.stop()
-
-
-@pytest.mark.timeout(10)
-@pytest.mark.asyncio
-async def test_autoscaling_limits():
-    """Test that autoscaling respects min/max limits."""
-    autoscaling_cfg = AutoscalingConfig(
-        enabled=True,
-        scale_up_queue_depth_threshold=1.0,
-        scale_down_capacity_threshold=0.9,  # High threshold to prevent scale down
-    )
-
-    cfg = ServiceConfig(
-        procs_per_replica=1,
-        min_replicas=1,
-        max_replicas=2,  # Tight limit
-        default_replicas=1,
-        autoscaling=autoscaling_cfg,
-    )
-    service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0)
-
-    try:
-        # Test max limit
-        should_scale_up, reason = service._should_scale_up()
-
-        # Manually scale to max
-        await service._scale_up(1)
-        assert len(service._replicas) == 2
-
-        # Should not scale beyond max
-        should_scale_up, reason = service._should_scale_up()
-        assert not should_scale_up
-        assert "max replicas" in reason.lower()
-
-        # Test min limit
-        should_scale_down, reason = service._should_scale_down()
-
-        # Scale down to min
-        await service._scale_down_replicas(1)
-        assert len(service._replicas) == 1
-
-        # Should not scale below min
-        should_scale_down, reason = service._should_scale_down()
-        assert not should_scale_down
-        assert "min replicas" in reason.lower()
-
-    finally:
-        await service.stop()
-
-
 # Metrics and Monitoring Tests


@@ -355,9 +196,7 @@ async def test_autoscaling_limits():
 @pytest.mark.asyncio
 async def test_metrics_collection():
     """Test comprehensive metrics collection."""
-    cfg = ServiceConfig(
-        procs_per_replica=1, min_replicas=2, max_replicas=2, default_replicas=2
-    )
+    cfg = ServiceConfig(procs_per_replica=1, num_replicas=2)
     service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0)
 
     try:
@@ -410,9 +249,7 @@ async def test_metrics_collection():
 @pytest.mark.asyncio
 async def test_session_stickiness():
     """Test that sessions stick to the same replica."""
-    cfg = ServiceConfig(
-        procs_per_replica=1, min_replicas=2, max_replicas=2, default_replicas=2
-    )
+    cfg = ServiceConfig(procs_per_replica=1, num_replicas=2)
     service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0)

     try:
@@ -441,9 +278,7 @@ async def test_session_stickiness():
 @pytest.mark.asyncio
 async def test_load_balancing_multiple_sessions():
     """Test load balancing across multiple sessions using least-loaded assignment."""
-    cfg = ServiceConfig(
-        procs_per_replica=1, min_replicas=2, max_replicas=2, default_replicas=2
-    )
+    cfg = ServiceConfig(procs_per_replica=1, num_replicas=2)
     service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0)

     try:
@@ -487,9 +322,7 @@ async def test_load_balancing_multiple_sessions():
 @pytest.mark.asyncio
 async def test_concurrent_operations():
     """Test concurrent operations across sessions and sessionless calls."""
-    cfg = ServiceConfig(
-        procs_per_replica=1, min_replicas=2, max_replicas=2, default_replicas=2
-    )
+    cfg = ServiceConfig(procs_per_replica=1, num_replicas=2)
     service = await spawn_service(
         service_cfg=cfg, actor_def=Counter, name="counter", v=0
     )
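For reference, a minimal usage sketch of the configuration surface after this change. It is modeled on the updated tests rather than taken verbatim from the repo: the `Counter` actor, its constructor, and the `incr`/`start_session`/`terminate_session`/`stop` calls are assumptions inferred from `tests/test_service.py` and `service.py` in this diff.

```python
import asyncio

from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service
from monarch.actor import Actor, endpoint


class Counter(Actor):
    """Toy actor modeled on the Counter used in tests/test_service.py (assumed shape)."""

    def __init__(self, v: int = 0):
        self.v = v

    @endpoint
    async def incr(self):
        self.v += 1


async def main():
    # num_replicas is the single replica knob now; min/max/default replicas
    # and AutoscalingConfig no longer exist.
    cfg = ServiceConfig(procs_per_replica=1, num_replicas=2)
    service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0)
    try:
        session = await service.start_session()  # session-affine routing
        await service.incr(session)              # endpoint method generated from the actor
        await service.terminate_session(session)
    finally:
        await service.stop()


if __name__ == "__main__":
    asyncio.run(main())
```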