class DatabaseJobsCoordinator:
    """
    Track when each database background job last ran.

    The coordinator functions below are invoked periodically and ask this
    object whether a given job is due, so that all database jobs are executed
    sequentially from one place instead of as independent scheduler jobs.

    Only ``update_spend`` and ``reset_budget`` get a randomized interval,
    computed once in ``__init__``; every other job's interval is supplied by
    the caller of ``should_run``.
    """

    def __init__(
        self,
        proxy_budget_rescheduler_min_time: int,
        proxy_budget_rescheduler_max_time: int,
        proxy_batch_write_at: int,
    ):
        """
        Args:
            proxy_budget_rescheduler_min_time: Lower bound (seconds, inclusive)
                for the randomized ``reset_budget`` interval.
            proxy_budget_rescheduler_max_time: Upper bound (seconds, inclusive)
                for the randomized ``reset_budget`` interval.
            proxy_batch_write_at: Base interval (seconds) for ``update_spend``;
                the actual interval is drawn from ``base - 3 .. base + 3``.
        """
        # None means "never run yet", which should_run() treats as due.
        self.last_run_times: Dict[str, Optional[datetime]] = {
            "update_spend": None,
            "reset_budget": None,
            "add_deployment": None,
            "get_credentials": None,
            "spend_log_cleanup": None,
            "check_batch_cost": None,
        }

        # Drawn exactly once per instance so the values stay stable for the
        # lifetime of the coordinator while still differing between instances.
        self.task_intervals: Dict[str, int] = {
            "update_spend": random.randint(
                proxy_batch_write_at - 3, proxy_batch_write_at + 3
            ),
            "reset_budget": random.randint(
                proxy_budget_rescheduler_min_time,
                proxy_budget_rescheduler_max_time,
            ),
        }

    def _require_known_job(self, job_name: str) -> None:
        """Raise ValueError unless *job_name* is one of the tracked jobs."""
        if job_name not in self.last_run_times:
            raise ValueError(
                f"Invalid job name: {job_name}. "
                f"Valid jobs: {list(self.last_run_times.keys())}"
            )

    def should_run(
        self,
        job_name: str,
        interval_seconds: int,
        tolerance_seconds: float = 0.0,
    ) -> bool:
        """
        Check if enough time has passed since the last recorded run.

        Args:
            job_name: Name of the job to check
            interval_seconds: Required interval in seconds
            tolerance_seconds: Slack subtracted from the interval before
                comparing. Periodic callers should pass a small positive value
                so that a tick arriving a few milliseconds "early" relative to
                the previous one does not skip the job for a whole period.
                Defaults to 0.0 (strict comparison).

        Returns:
            True if the job has never run or is due, False otherwise
        """
        last_run = self.last_run_times.get(job_name)
        if last_run is None:
            return True

        elapsed = (datetime.now() - last_run).total_seconds()
        return elapsed >= interval_seconds - tolerance_seconds

    def mark_run(self, job_name: str, run_time: Optional[datetime] = None) -> None:
        """
        Record a run of *job_name*.

        Args:
            job_name: Name of the job
            run_time: Timestamp to record. Defaults to the current time.
                Periodic callers should pass the time the tick started rather
                than relying on the default: stamping the completion time
                makes the next tick see ``elapsed < interval`` by the job's
                own duration.
        """
        self.last_run_times[job_name] = (
            datetime.now() if run_time is None else run_time
        )

    def set_last_run_time(self, job_name: str, run_time: Optional[datetime]):
        """
        Set the last run time for a specific job (for testing purposes)

        Args:
            job_name: Name of the job
            run_time: The datetime to set as last run time, or None to reset

        Raises:
            ValueError: If job_name is not valid
        """
        self._require_known_job(job_name)
        self.last_run_times[job_name] = run_time

    def get_last_run_time(self, job_name: str) -> Optional[datetime]:
        """
        Get the last run time for a specific job (for testing purposes)

        Args:
            job_name: Name of the job

        Returns:
            Last run time or None if never run

        Raises:
            ValueError: If job_name is not valid
        """
        self._require_known_job(job_name)
        return self.last_run_times[job_name]

    def get_task_interval(self, job_name: str) -> Optional[int]:
        """
        Get the pre-calculated interval for a specific job (for testing purposes)

        Args:
            job_name: Name of the job

        Returns:
            Pre-calculated interval in seconds, or None if not pre-calculated
        """
        return self.task_intervals.get(job_name)


# Slack (seconds) used by the coordinators when asking whether a job is due.
# It absorbs dispatch jitter between consecutive scheduler ticks; without it a
# job whose interval equals the tick period is skipped whenever a tick is
# dispatched marginally sooner (relative to the previous one) than the period.
_SCHEDULER_JITTER_TOLERANCE_SECONDS = 1.0


async def _run_job_if_due(
    coordinator: DatabaseJobsCoordinator,
    job_name: str,
    interval_seconds: int,
    tick_started_at: datetime,
    job,
    label: str,
) -> bool:
    """
    Run one coordinated job if it is due, isolating its failures.

    Args:
        coordinator: Tracker holding the last run times
        job_name: Key of the job in the coordinator
        interval_seconds: Required spacing between runs
        tick_started_at: Time the enclosing coordinator tick started; recorded
            as the run time on success so job duration does not delay the
            next run
        job: Zero-argument callable returning an awaitable that does the work
        label: Prefix used in log messages ("High-freq" / "Low-freq")

    Returns:
        True if the job ran and succeeded, False if it was not due or failed.
        A failed job is not marked, so it is retried on the next tick.
    """
    if not coordinator.should_run(
        job_name, interval_seconds, _SCHEDULER_JITTER_TOLERANCE_SECONDS
    ):
        return False

    try:
        await job()
    except Exception as e:
        verbose_proxy_logger.error(
            f"✗ {label} coordinator: {job_name} failed - {e}"
        )
        return False

    coordinator.mark_run(job_name, tick_started_at)
    verbose_proxy_logger.debug(f"✓ {label} coordinator: {job_name} completed")
    return True


async def high_frequency_database_jobs_coordinator(
    prisma_client: "PrismaClient",
    db_writer_client: Optional["AsyncHTTPHandler"],
    proxy_logging_obj: "ProxyLogging",
    coordinator: DatabaseJobsCoordinator,
    proxy_config: "ProxyConfig",
    llm_router: Optional["Router"],
    proxy_batch_polling_interval: int = 60,
):
    """
    High-frequency database jobs coordinator.

    Intended to be invoked on a short fixed period (the interval constants
    below assume roughly every 10 seconds). On each invocation it runs, in
    order, whichever of these jobs are due:

    - add_deployment: refresh model configuration from the database
      (10s interval, only when ``store_model_in_db`` is True)
    - get_credentials: refresh credentials from the database
      (10s interval, only when ``store_model_in_db`` is True)
    - update_spend: flush spend updates (pre-calculated randomized interval)
    - check_batch_cost: enterprise batch cost check
      (``proxy_batch_polling_interval``, only when a router is available)

    If one job fails, execution continues to the next job. A job's effective
    period is rounded up to a multiple of the invocation period, because a job
    can only start when this function is called.

    Args:
        prisma_client: Database client
        db_writer_client: Async HTTP handler for batch writes
        proxy_logging_obj: Logging object
        coordinator: Job coordinator tracking last run times
        proxy_config: Proxy configuration object
        llm_router: LLM router instance
        proxy_batch_polling_interval: Seconds between batch cost checks.
            Defaults to 60, the value previously hard-coded here.
    """
    global store_model_in_db  # only read here

    # Stamp successful runs with the tick start (see DatabaseJobsCoordinator.mark_run).
    tick_started_at = datetime.now()

    try:
        # Model/credential refresh only makes sense when models live in the DB.
        if store_model_in_db is True:
            # 1. Add deployment / Refresh model config
            await _run_job_if_due(
                coordinator,
                "add_deployment",
                10,
                tick_started_at,
                lambda: proxy_config.add_deployment(
                    prisma_client=prisma_client,
                    proxy_logging_obj=proxy_logging_obj,
                ),
                "High-freq",
            )

            # 2. Get credentials
            await _run_job_if_due(
                coordinator,
                "get_credentials",
                10,
                tick_started_at,
                lambda: proxy_config.get_credentials(prisma_client=prisma_client),
                "High-freq",
            )

        # 3. Update spend logs. Deliberately NOT gated on store_model_in_db:
        # spend must be flushed regardless of where models are stored.
        await _run_job_if_due(
            coordinator,
            "update_spend",
            coordinator.task_intervals["update_spend"],
            tick_started_at,
            lambda: update_spend(prisma_client, db_writer_client, proxy_logging_obj),
            "High-freq",
        )

        # 4. Check batch cost (enterprise feature, needs a router)
        # NOTE(review): llm_router is whatever the scheduler captured when the
        # job was registered; confirm it cannot be stale/None at that point.
        if llm_router is not None and coordinator.should_run(
            "check_batch_cost",
            proxy_batch_polling_interval,
            _SCHEDULER_JITTER_TOLERANCE_SECONDS,
        ):
            try:
                from litellm_enterprise.proxy.common_utils.check_batch_cost import (
                    CheckBatchCost,
                )
            except ImportError:
                # Enterprise package not installed - expected. Mark the job so
                # the failing import is retried once per polling interval
                # instead of on every tick.
                coordinator.mark_run("check_batch_cost", tick_started_at)
                verbose_proxy_logger.debug(
                    "Checking batch cost for LiteLLM Managed Files is an Enterprise Feature. Skipping..."
                )
            else:
                try:
                    check_batch_cost_job = CheckBatchCost(
                        proxy_logging_obj=proxy_logging_obj,
                        prisma_client=prisma_client,
                        llm_router=llm_router,
                    )
                    await check_batch_cost_job.check_batch_cost()
                    coordinator.mark_run("check_batch_cost", tick_started_at)
                    verbose_proxy_logger.debug(
                        "✓ High-freq coordinator: check_batch_cost completed"
                    )
                except Exception as e:
                    verbose_proxy_logger.error(
                        f"✗ High-freq coordinator: check_batch_cost failed - {e}"
                    )

    except Exception as e:
        # Last-resort guard so an unexpected error never escapes into the scheduler.
        verbose_proxy_logger.error(
            f"High-frequency database jobs coordinator error: {e}\n{traceback.format_exc()}"
        )


async def low_frequency_database_jobs_coordinator(  # noqa: PLR0915
    prisma_client: "PrismaClient",
    db_writer_client: Optional["AsyncHTTPHandler"],
    proxy_logging_obj: "ProxyLogging",
    coordinator: DatabaseJobsCoordinator,
    general_settings: dict,
    llm_router: Optional["Router"],
    proxy_budget_rescheduler_min_time: int,
    proxy_budget_rescheduler_max_time: int,
    proxy_batch_write_at: int,
    proxy_batch_polling_interval: int,
):
    """
    Low-frequency database jobs coordinator.

    Intended to be invoked on a long fixed period. On each invocation it runs,
    in order, whichever of these jobs are due:

    - reset_budget: reset budgets (pre-calculated randomized interval;
      skipped when ``general_settings["disable_reset_budget"]`` is set)
    - spend_log_cleanup: delete old spend logs (interval taken from
      ``maximum_spend_logs_retention_interval``, default "1d"; only when
      ``maximum_spend_logs_retention_period`` is configured)

    If one job fails, execution continues to the next job. A job's effective
    period is rounded up to a multiple of the invocation period.

    Args:
        prisma_client: Database client
        db_writer_client: Async HTTP handler for batch writes (unused here)
        proxy_logging_obj: Logging object
        coordinator: Job coordinator tracking last run times and pre-calculated intervals
        general_settings: General proxy settings
        llm_router: LLM router instance (unused here)
        proxy_budget_rescheduler_min_time: Unused here; the interval is pre-calculated in the coordinator
        proxy_budget_rescheduler_max_time: Unused here; the interval is pre-calculated in the coordinator
        proxy_batch_write_at: Unused here
        proxy_batch_polling_interval: Unused here
    """
    # Stamp successful runs with the tick start (see DatabaseJobsCoordinator.mark_run).
    tick_started_at = datetime.now()

    try:
        # 1. Reset budget (uses pre-calculated randomized interval from initialization)
        if general_settings.get("disable_reset_budget", False) is False:
            await _run_job_if_due(
                coordinator,
                "reset_budget",
                coordinator.task_intervals["reset_budget"],
                tick_started_at,
                lambda: ResetBudgetJob(
                    proxy_logging_obj=proxy_logging_obj,
                    prisma_client=prisma_client,
                ).reset_budget(),
                "Low-freq",
            )

        # 2. Spend log cleanup (runs based on configured retention interval)
        if general_settings.get("maximum_spend_logs_retention_period") is not None:
            retention_interval = general_settings.get(
                "maximum_spend_logs_retention_interval", "1d"
            )

            # Parse interval with fallback to default
            interval_seconds: Optional[int]
            try:
                interval_seconds = duration_in_seconds(retention_interval)
            except ValueError as e:
                # Invalid configuration - use default and log warning
                verbose_proxy_logger.warning(
                    f"✗ Low-freq coordinator: Invalid retention interval '{retention_interval}': {e}. Using default 1 day (86400s)."
                )
                interval_seconds = 86400  # 1 day default
            except Exception as e:
                # Unexpected error - log and skip cleanup for this invocation
                verbose_proxy_logger.error(
                    f"✗ Low-freq coordinator: Unexpected error parsing retention interval '{retention_interval}': {e}"
                )
                interval_seconds = None

            if interval_seconds is not None:
                await _run_job_if_due(
                    coordinator,
                    "spend_log_cleanup",
                    interval_seconds,
                    tick_started_at,
                    lambda: SpendLogCleanup().cleanup_old_spend_logs(prisma_client),
                    "Low-freq",
                )

    except Exception as e:
        # Last-resort guard so an unexpected error never escapes into the scheduler.
        verbose_proxy_logger.error(
            f"Low-frequency database jobs coordinator error: {e}\n{traceback.format_exc()}"
        )
+ """ + global store_model_in_db, llm_router, db_writer_client, proxy_config, _scheduler_instance + scheduler = AsyncIOScheduler() - interval = random.randint( - proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time - ) # random interval, so multiple workers avoid resetting budget at the same time - batch_writing_interval = random.randint( - proxy_batch_write_at - 3, proxy_batch_write_at + 3 - ) # random interval, so multiple workers avoid batch writing at the same time - - ### RESET BUDGET ### - if general_settings.get("disable_reset_budget", False) is False: - budget_reset_job = ResetBudgetJob( - proxy_logging_obj=proxy_logging_obj, - prisma_client=prisma_client, - ) - - scheduler.add_job( - budget_reset_job.reset_budget, - "interval", - seconds=interval, - ) - - ### UPDATE SPEND ### - scheduler.add_job( - update_spend, - "interval", - seconds=batch_writing_interval, - args=[prisma_client, db_writer_client, proxy_logging_obj], + + # Assign to global variable immediately after creation for proper lifecycle management + _scheduler_instance = scheduler + + # Initialize the coordinator state tracker with pre-calculated intervals (shared by both coordinators) + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=proxy_budget_rescheduler_min_time, + proxy_budget_rescheduler_max_time=proxy_budget_rescheduler_max_time, + proxy_batch_write_at=proxy_batch_write_at, ) - - ### ADD NEW MODELS ### + + # Load all existing models on proxy startup if store_model_in_db is enabled store_model_in_db = ( get_secret_bool("STORE_MODEL_IN_DB", store_model_in_db) or store_model_in_db ) - + if store_model_in_db is True: - scheduler.add_job( - proxy_config.add_deployment, - "interval", - seconds=10, - args=[prisma_client, proxy_logging_obj], - ) - - # this will load all existing models on proxy startup + # Load all existing models on proxy startup await proxy_config.add_deployment( prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj ) - - 
### GET STORED CREDENTIALS ### - scheduler.add_job( - proxy_config.get_credentials, - "interval", - seconds=10, - args=[prisma_client], - ) await proxy_config.get_credentials(prisma_client=prisma_client) + + # Add high-frequency coordinator job - runs every 10 seconds + # Handles: add_deployment, get_credentials, update_spend, check_batch_cost + scheduler.add_job( + high_frequency_database_jobs_coordinator, + "interval", + seconds=10, + max_instances=1, + coalesce=False, + misfire_grace_time=5, + args=[ + prisma_client, + db_writer_client, + proxy_logging_obj, + coordinator, + proxy_config, + llm_router, + ], + ) + + # Add low-frequency coordinator job - runs every 30 minutes + # Handles: reset_budget, spend_log_cleanup + scheduler.add_job( + low_frequency_database_jobs_coordinator, + "interval", + seconds=1800, + max_instances=1, + coalesce=False, + misfire_grace_time=1200, + args=[ + prisma_client, + db_writer_client, + proxy_logging_obj, + coordinator, + general_settings, + llm_router, + proxy_budget_rescheduler_min_time, + proxy_budget_rescheduler_max_time, + proxy_batch_write_at, + proxy_batch_polling_interval, + ], + ) + + # Keep other non-database jobs (alerting, etc.) 
if ( proxy_logging_obj is not None and proxy_logging_obj.slack_alerting_instance.alerting is not None @@ -3759,7 +4077,6 @@ async def initialize_scheduled_background_jobs( ): print("Alerting: Initializing Weekly/Monthly Spend Reports") # noqa ### Schedule weekly/monthly spend reports ### - ### Schedule spend reports ### spend_report_frequency: str = ( general_settings.get("spend_report_frequency", "7d") or "7d" ) @@ -3801,49 +4118,6 @@ async def initialize_scheduled_background_jobs( await cls._initialize_spend_tracking_background_jobs(scheduler=scheduler) - ### SPEND LOG CLEANUP ### - if general_settings.get("maximum_spend_logs_retention_period") is not None: - spend_log_cleanup = SpendLogCleanup() - # Get the interval from config or default to 1 day - retention_interval = general_settings.get( - "maximum_spend_logs_retention_interval", "1d" - ) - try: - interval_seconds = duration_in_seconds(retention_interval) - scheduler.add_job( - spend_log_cleanup.cleanup_old_spend_logs, - "interval", - seconds=interval_seconds, - args=[prisma_client], - ) - except ValueError: - verbose_proxy_logger.error( - "Invalid maximum_spend_logs_retention_interval value" - ) - ### CHECK BATCH COST ### - if llm_router is not None: - try: - from litellm_enterprise.proxy.common_utils.check_batch_cost import ( - CheckBatchCost, - ) - - check_batch_cost_job = CheckBatchCost( - proxy_logging_obj=proxy_logging_obj, - prisma_client=prisma_client, - llm_router=llm_router, - ) - scheduler.add_job( - check_batch_cost_job.check_batch_cost, - "interval", - seconds=proxy_batch_polling_interval, # these can run infrequently, as batch jobs take time to complete - ) - - except Exception: - verbose_proxy_logger.debug( - "Checking batch cost for LiteLLM Managed Files is an Enterprise Feature. Skipping..." 
- ) - pass - scheduler.start() @classmethod diff --git a/tests/test_litellm/proxy/test_database_jobs_coordinator.py b/tests/test_litellm/proxy/test_database_jobs_coordinator.py new file mode 100644 index 000000000000..f5a4831e070f --- /dev/null +++ b/tests/test_litellm/proxy/test_database_jobs_coordinator.py @@ -0,0 +1,600 @@ +""" +Tests for the optimized database jobs coordinator scheduling logic + +This test suite validates: +1. Correct interval pre-calculation at initialization +2. Proper timing logic for job execution +3. Independence of different job schedules +4. Consistency with original implementation behavior +""" +import pytest +from datetime import datetime, timedelta +from litellm.proxy.proxy_server import DatabaseJobsCoordinator + + +class TestDatabaseJobsCoordinatorInitialization: + """Test DatabaseJobsCoordinator initialization logic""" + + def test_coordinator_initialization(self): + """Verify coordinator correctly initializes all tasks""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + # Verify all tasks are initialized + expected_jobs = [ + "update_spend", + "reset_budget", + "add_deployment", + "get_credentials", + "spend_log_cleanup", + "check_batch_cost" + ] + + for job in expected_jobs: + assert job in coordinator.last_run_times + assert coordinator.last_run_times[job] is None # Initially None + + def test_precalculated_intervals_in_range(self): + """Verify pre-calculated intervals are in correct range""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + # update_spend: proxy_batch_write_at ± 3 seconds + update_interval = coordinator.get_task_interval("update_spend") + assert update_interval is not None + assert 57 <= update_interval <= 63 + + # reset_budget: proxy_budget_rescheduler_min/max_time range + reset_interval = 
coordinator.get_task_interval("reset_budget") + assert reset_interval is not None + assert 3600 <= reset_interval <= 7200 + + def test_intervals_remain_constant(self): + """Verify interval values remain constant throughout lifecycle""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + # Record initial intervals + initial_update = coordinator.get_task_interval("update_spend") + initial_reset = coordinator.get_task_interval("reset_budget") + + # Simulate multiple calls to should_run and mark_run + for _ in range(100): + coordinator.should_run("update_spend", 60) + coordinator.mark_run("update_spend") + + # Verify interval values haven't changed at all + assert coordinator.get_task_interval("update_spend") == initial_update + assert coordinator.get_task_interval("reset_budget") == initial_reset + + def test_different_instances_have_different_intervals(self): + """Verify different instances have different random intervals (avoid multi-worker collisions)""" + coordinators = [ + DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + for _ in range(20) + ] + + update_intervals = [c.get_task_interval("update_spend") for c in coordinators] + reset_intervals = [c.get_task_interval("reset_budget") for c in coordinators] + + # Should have some variation (randomization effect) + unique_update = len(set(update_intervals)) + unique_reset = len(set(reset_intervals)) + + # With 20 instances, should have at least 2 different values + assert unique_update >= 2, f"update_spend interval lacks randomization: {unique_update} unique values" + assert unique_reset >= 2, f"reset_budget interval lacks randomization: {unique_reset} unique values" + + +class TestDatabaseJobsCoordinatorTiming: + """Test task execution timing logic""" + + def test_should_run_never_executed(self): + """Test that 
never-executed tasks should run immediately""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + # All tasks should return True in initial state + assert coordinator.should_run("update_spend", 60) is True + assert coordinator.should_run("reset_budget", 3600) is True + assert coordinator.should_run("add_deployment", 10) is True + + def test_should_run_just_executed(self): + """Test that just-executed tasks should not run again immediately""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + now = datetime.now() + + # Set as just executed + coordinator.set_last_run_time("update_spend", now) + + # Should not meet interval requirement + assert coordinator.should_run("update_spend", 60) is False + + def test_should_run_after_interval(self): + """Test that tasks should run after interval has passed""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + now = datetime.now() + + # Set as executed 61 seconds ago (exceeds 60 second interval) + coordinator.set_last_run_time("update_spend", now - timedelta(seconds=61)) + assert coordinator.should_run("update_spend", 60) is True + + # Set as executed 59 seconds ago (doesn't meet 60 second interval) + coordinator.set_last_run_time("update_spend", now - timedelta(seconds=59)) + assert coordinator.should_run("update_spend", 60) is False + + def test_should_run_exact_interval_boundary(self): + """Test behavior at exact interval boundary""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + now = datetime.now() + + # Exactly 60 seconds ago + coordinator.set_last_run_time("update_spend", now - 
timedelta(seconds=60)) + assert coordinator.should_run("update_spend", 60) is True + + # Exactly 60.001 seconds ago (just slightly over) + coordinator.set_last_run_time("update_spend", now - timedelta(seconds=60.001)) + assert coordinator.should_run("update_spend", 60) is True + + def test_mark_run_updates_time(self): + """Test mark_run correctly updates execution time""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + # Initial state + assert coordinator.get_last_run_time("update_spend") is None + + # Mark as run + before_mark = datetime.now() + coordinator.mark_run("update_spend") + after_mark = datetime.now() + + # Verify time was updated + last_run = coordinator.get_last_run_time("update_spend") + assert last_run is not None + assert before_mark <= last_run <= after_mark + + def test_independent_job_timing(self): + """Test timing independence of different tasks""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + now = datetime.now() + + # Set different execution times for different tasks + coordinator.set_last_run_time("update_spend", now - timedelta(seconds=70)) + coordinator.set_last_run_time("reset_budget", now - timedelta(seconds=10)) + coordinator.set_last_run_time("add_deployment", now - timedelta(seconds=15)) + + # Verify each task is judged independently + assert coordinator.should_run("update_spend", 60) is True # 70s > 60s + assert coordinator.should_run("reset_budget", 3600) is False # 10s < 3600s + assert coordinator.should_run("add_deployment", 10) is True # 15s > 10s + + +class TestDatabaseJobsCoordinatorEdgeCases: + """Test edge cases and exception handling""" + + def test_set_invalid_job_name(self): + """Test setting invalid job name should raise exception""" + coordinator = DatabaseJobsCoordinator( + 
proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + with pytest.raises(ValueError, match="Invalid job name"): + coordinator.set_last_run_time("invalid_job", datetime.now()) + + def test_get_invalid_job_name(self): + """Test getting invalid job name should raise exception""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + with pytest.raises(ValueError, match="Invalid job name"): + coordinator.get_last_run_time("invalid_job") + + def test_set_none_resets_time(self): + """Test that setting None resets execution time""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + # First set a time + coordinator.set_last_run_time("update_spend", datetime.now()) + assert coordinator.get_last_run_time("update_spend") is not None + + # Reset to None + coordinator.set_last_run_time("update_spend", None) + assert coordinator.get_last_run_time("update_spend") is None + + # Should behave like never executed + assert coordinator.should_run("update_spend", 60) is True + + +class TestDatabaseJobsCoordinatorRealWorldScenarios: + """Test real-world scenarios""" + + def test_typical_update_spend_cycle(self): + """Simulate typical update_spend execution cycle""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + interval = coordinator.get_task_interval("update_spend") + + # First time: should execute + assert coordinator.should_run("update_spend", interval) is True + coordinator.mark_run("update_spend") + + # Check after 10 seconds: should not execute + coordinator.set_last_run_time( + "update_spend", + datetime.now() - timedelta(seconds=10) + ) + assert coordinator.should_run("update_spend", interval) is 
False + + # Check after interval seconds: should execute + coordinator.set_last_run_time( + "update_spend", + datetime.now() - timedelta(seconds=interval) + ) + assert coordinator.should_run("update_spend", interval) is True + + def test_typical_reset_budget_cycle(self): + """Simulate typical reset_budget execution cycle""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + interval = coordinator.get_task_interval("reset_budget") + + # Verify interval is between 1-2 hours + assert 3600 <= interval <= 7200 + + # Simulate execution cycle + assert coordinator.should_run("reset_budget", interval) is True + coordinator.mark_run("reset_budget") + + # After 2.5 hours: definitely should execute (exceeds max interval of 2h) + coordinator.set_last_run_time( + "reset_budget", + datetime.now() - timedelta(hours=2.5) + ) + assert coordinator.should_run("reset_budget", interval) is True + + def test_multiple_jobs_concurrent_scheduling(self): + """Test concurrent scheduling of multiple tasks""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + now = datetime.now() + + # Set different execution times for different tasks + coordinator.set_last_run_time("update_spend", now - timedelta(seconds=65)) + coordinator.set_last_run_time("reset_budget", now - timedelta(seconds=1800)) + coordinator.set_last_run_time("add_deployment", now - timedelta(seconds=5)) + coordinator.set_last_run_time("get_credentials", now - timedelta(seconds=12)) + + # Verify each task is judged independently + update_interval = coordinator.get_task_interval("update_spend") + reset_interval = coordinator.get_task_interval("reset_budget") + + assert coordinator.should_run("update_spend", update_interval) is True # 65s > ~60s + assert coordinator.should_run("reset_budget", reset_interval) is False # 
30min < any random interval (1-2h) + assert coordinator.should_run("add_deployment", 10) is False # 5s < 10s + assert coordinator.should_run("get_credentials", 10) is True # 12s > 10s + + def test_consistent_behavior_across_coordinator_lifecycle(self): + """Test coordinator behavior consistency throughout lifecycle""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + update_interval = coordinator.get_task_interval("update_spend") + + # Simulate multiple execution cycles + for cycle in range(10): + # Simulate time passage + now = datetime.now() + + # update_spend should run every interval seconds + coordinator.set_last_run_time("update_spend", now - timedelta(seconds=update_interval + 1)) + assert coordinator.should_run("update_spend", update_interval) is True + coordinator.mark_run("update_spend") + + # Immediate check should return False + assert coordinator.should_run("update_spend", update_interval) is False + + def test_all_jobs_have_expected_intervals(self): + """Verify all tasks have expected interval settings""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + # update_spend and reset_budget should have pre-calculated intervals + assert coordinator.get_task_interval("update_spend") is not None + assert coordinator.get_task_interval("reset_budget") is not None + + # Other tasks should not have pre-calculated intervals (they use fixed intervals) + assert coordinator.get_task_interval("add_deployment") is None + assert coordinator.get_task_interval("get_credentials") is None + assert coordinator.get_task_interval("spend_log_cleanup") is None + assert coordinator.get_task_interval("check_batch_cost") is None + + +class TestDatabaseJobsCoordinatorConsistencyWithOriginal: + """Verify new implementation consistency with original implementation""" + 
+ def test_update_spend_interval_matches_original(self): + """Verify update_spend interval logic matches original implementation""" + # Original: random.randint(proxy_batch_write_at - 3, proxy_batch_write_at + 3) + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + interval = coordinator.get_task_interval("update_spend") + + # Should be in [57, 63] range (60 ± 3) + assert 57 <= interval <= 63 + + def test_reset_budget_interval_matches_original(self): + """Verify reset_budget interval logic matches original implementation""" + # Original: random.randint(proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time) + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + interval = coordinator.get_task_interval("reset_budget") + + # Should be in [3600, 7200] range + assert 3600 <= interval <= 7200 + + def test_should_run_logic_matches_original(self): + """Verify should_run logic matches original implementation""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + # Original logic: + # if last_run is None: return True + # elapsed = (datetime.now() - last_run).total_seconds() + # return elapsed >= interval_seconds + + # Test None case + assert coordinator.should_run("update_spend", 60) is True + + # Test just executed + coordinator.set_last_run_time("update_spend", datetime.now()) + assert coordinator.should_run("update_spend", 60) is False + + # Test after interval + coordinator.set_last_run_time( + "update_spend", + datetime.now() - timedelta(seconds=61) + ) + assert coordinator.should_run("update_spend", 60) is True + + +class TestDatabaseJobsCoordinatorParameterVariations: + """Test behavior under different parameter configurations""" + + def 
test_different_batch_write_intervals(self): + """Test different proxy_batch_write_at values""" + for batch_write_at in [30, 60, 120, 300]: + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=batch_write_at, + ) + + interval = coordinator.get_task_interval("update_spend") + assert interval is not None + assert batch_write_at - 3 <= interval <= batch_write_at + 3 + + def test_different_budget_rescheduler_ranges(self): + """Test different budget_rescheduler ranges""" + test_cases = [ + (1800, 3600), # 30 minutes to 1 hour + (3600, 7200), # 1 hour to 2 hours + (7200, 14400), # 2 hours to 4 hours + ] + + for min_time, max_time in test_cases: + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=min_time, + proxy_budget_rescheduler_max_time=max_time, + proxy_batch_write_at=60, + ) + + interval = coordinator.get_task_interval("reset_budget") + assert interval is not None + assert min_time <= interval <= max_time + + def test_default_values_from_constants(self): + """Test using default constant values""" + # Simulate default values (from litellm.constants) + # PROXY_BUDGET_RESCHEDULER_MIN_TIME = 3600 + # PROXY_BUDGET_RESCHEDULER_MAX_TIME = 7200 + # PROXY_BATCH_WRITE_AT = 60 + + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + # Verify behavior with default values + update_interval = coordinator.get_task_interval("update_spend") + reset_interval = coordinator.get_task_interval("reset_budget") + + assert 57 <= update_interval <= 63 + assert 3600 <= reset_interval <= 7200 + + +class TestDatabaseJobsCoordinatorHelperMethods: + """Test helper method functionality""" + + def test_get_task_interval_returns_none_for_non_precalculated(self): + """Test get_task_interval returns None for non-pre-calculated tasks""" + coordinator = DatabaseJobsCoordinator( + 
proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + # These tasks don't have pre-calculated intervals + assert coordinator.get_task_interval("add_deployment") is None + assert coordinator.get_task_interval("get_credentials") is None + assert coordinator.get_task_interval("spend_log_cleanup") is None + assert coordinator.get_task_interval("check_batch_cost") is None + + def test_set_and_get_last_run_time(self): + """Test set_last_run_time and get_last_run_time round-trip""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + test_time = datetime.now() - timedelta(hours=1) + + # Set time + coordinator.set_last_run_time("update_spend", test_time) + + # Get time and verify + retrieved_time = coordinator.get_last_run_time("update_spend") + assert retrieved_time == test_time + + def test_helper_methods_for_all_jobs(self): + """Verify helper methods work correctly for all tasks""" + coordinator = DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + + all_jobs = [ + "update_spend", + "reset_budget", + "add_deployment", + "get_credentials", + "spend_log_cleanup", + "check_batch_cost" + ] + + test_time = datetime.now() - timedelta(minutes=30) + + for job in all_jobs: + # Should be able to set and get time + coordinator.set_last_run_time(job, test_time) + assert coordinator.get_last_run_time(job) == test_time + + # Should be able to reset to None + coordinator.set_last_run_time(job, None) + assert coordinator.get_last_run_time(job) is None + + +class TestDatabaseJobsCoordinatorRandomizationBehavior: + """Test randomization behavior (key feature to prevent multi-worker conflicts)""" + + def test_randomization_prevents_synchronized_execution(self): + """Verify randomization prevents synchronized execution across 
workers""" + # Create multiple coordinator instances (simulating multiple workers) + num_workers = 10 + coordinators = [ + DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + for _ in range(num_workers) + ] + + # Collect all update_spend intervals + update_intervals = [c.get_task_interval("update_spend") for c in coordinators] + + # Verify there's at least some variation (not all workers have same interval) + unique_intervals = set(update_intervals) + assert len(unique_intervals) > 1, "All workers have same interval, defeats randomization purpose" + + def test_randomization_distribution(self): + """Test whether randomization distribution is reasonable""" + # Create many instances to verify distribution + num_samples = 100 + coordinators = [ + DatabaseJobsCoordinator( + proxy_budget_rescheduler_min_time=3600, + proxy_budget_rescheduler_max_time=7200, + proxy_batch_write_at=60, + ) + for _ in range(num_samples) + ] + + update_intervals = [c.get_task_interval("update_spend") for c in coordinators] + + # Verify values are in expected range + assert all(57 <= i <= 63 for i in update_intervals) + + # Verify distribution diversity (should have at least 3 different values) + unique_count = len(set(update_intervals)) + assert unique_count >= 3, f"Distribution lacks diversity, only {unique_count} unique values" diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py index 63436899fccf..ef64a16fdd82 100644 --- a/tests/test_litellm/proxy/test_proxy_server.py +++ b/tests/test_litellm/proxy/test_proxy_server.py @@ -1,19 +1,13 @@ import asyncio -import importlib import json import os -import socket -import subprocess import sys from datetime import datetime from unittest import mock -from unittest.mock import AsyncMock, MagicMock, mock_open, patch +from unittest.mock import AsyncMock, MagicMock, patch -import click -import httpx import 
pytest import yaml -from fastapi import FastAPI from fastapi.testclient import TestClient sys.path.insert( @@ -51,13 +45,6 @@ } -def mock_patch_aembedding(): - return mock.patch( - "litellm.proxy.proxy_server.llm_router.aembedding", - return_value=example_embedding_result, - ) - - @pytest.fixture(scope="function") def client_no_auth(): # Assuming litellm.proxy.proxy_server is an object @@ -152,7 +139,6 @@ async def test_aaaproxy_startup_master_key(mock_prisma, monkeypatch, tmp_path): """ Test that master_key is correctly loaded from either config.yaml or environment variables """ - import yaml from fastapi import FastAPI # Import happens here - this is when the module probably reads the config path @@ -173,11 +159,8 @@ async def test_aaaproxy_startup_master_key(mock_prisma, monkeypatch, tmp_path): with open(config_path, "w") as f: yaml.dump(test_config, f) - print(f"SET ENV VARIABLE - CONFIG_FILE_PATH, str(config_path): {str(config_path)}") # Second setting of CONFIG_FILE_PATH to a different value monkeypatch.setenv("CONFIG_FILE_PATH", str(config_path)) - print(f"config_path: {config_path}") - print(f"os.getenv('CONFIG_FILE_PATH'): {os.getenv('CONFIG_FILE_PATH')}") async with proxy_startup_event(app): from litellm.proxy.proxy_server import master_key @@ -192,7 +175,6 @@ async def test_aaaproxy_startup_master_key(mock_prisma, monkeypatch, tmp_path): yaml.dump(empty_config, f) monkeypatch.setenv("LITELLM_MASTER_KEY", test_env_master_key) - print("test_env_master_key: {}".format(test_env_master_key)) async with proxy_startup_event(app): from litellm.proxy.proxy_server import master_key @@ -237,49 +219,79 @@ def test_team_info_masking(): all_teams_config=[team1_info], ) - print("Got exception: {}".format(exc_info.value)) assert "secret-test-key" not in str(exc_info.value) assert "public-test-key" not in str(exc_info.value) -@mock_patch_aembedding() -def test_embedding_input_array_of_tokens(mock_aembedding, client_no_auth): +def 
test_embedding_input_array_of_tokens(client_no_auth): """ Test to bypass decoding input as array of tokens for selected providers Ref: https://github.com/BerriAI/litellm/issues/10113 """ - try: + # Create an AsyncMock for aembedding + async_mock = AsyncMock(return_value=example_embedding_result) + + # Mock proxy_logging_obj to prevent error handling interference + # pre_call_hook should return the data it receives + async def mock_pre_call_hook(*args, **kwargs): + # Return the data parameter - it's passed as kwargs["data"] + return kwargs.get("data", {}) + + # during_call_hook should be a simple async mock that doesn't do anything + async def mock_during_call_hook(*args, **kwargs): + return None + + mock_proxy_logging = MagicMock() + mock_proxy_logging.pre_call_hook = AsyncMock(side_effect=mock_pre_call_hook) + mock_proxy_logging.during_call_hook = AsyncMock(side_effect=mock_during_call_hook) + mock_proxy_logging.post_call_failure_hook = AsyncMock() + mock_proxy_logging.async_post_call_streaming_hook = AsyncMock() + mock_proxy_logging.update_request_status = AsyncMock() + + with mock.patch( + "litellm.proxy.proxy_server.llm_router.aembedding", + new=async_mock, + ) as mock_aembedding, mock.patch( + "litellm.proxy.proxy_server.proxy_logging_obj", + mock_proxy_logging, + ), mock.patch( + "litellm.proxy.proxy_server.premium_user", + True, + ): test_data = { "model": "vllm_embed_model", "input": [[2046, 13269, 158208]], } - + response = client_no_auth.post("/v1/embeddings", json=test_data) - mock_aembedding.assert_called_once_with( - model="vllm_embed_model", - input=[[2046, 13269, 158208]], - metadata=mock.ANY, - proxy_server_request=mock.ANY, - secret_fields=mock.ANY, - ) + # Get the actual call to check parameters + assert mock_aembedding.called, "aembedding should have been called" + call_args = mock_aembedding.call_args + + # Verify the key parameters we care about + assert call_args.kwargs["model"] == "vllm_embed_model" + assert call_args.kwargs["input"] == 
[[2046, 13269, 158208]] + + # Verify metadata, proxy_server_request, and secret_fields are present + # but don't check their exact values since they're complex objects + assert "metadata" in call_args.kwargs + assert "proxy_server_request" in call_args.kwargs + assert "secret_fields" in call_args.kwargs + assert response.status_code == 200 result = response.json() - print(len(result["data"][0]["embedding"])) assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so - except Exception as e: - pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") @pytest.mark.asyncio -async def test_get_all_team_models(): +async def test_get_all_team_models(): # noqa: PLR0915 """ Test get_all_team_models function with both "*" and specific team IDs """ from unittest.mock import AsyncMock, MagicMock - from litellm.proxy._types import LiteLLM_TeamTable from litellm.proxy.proxy_server import get_all_team_models # Mock team data @@ -435,7 +447,6 @@ def mock_get_model_list_with_none(model_name, team_id=None): # Should handle None return gracefully assert isinstance(result, dict) - print("result: ", result) assert result == {"gpt-4-model-1": ["team1"], "gpt-4-model-2": ["team1"]} @@ -626,7 +637,6 @@ async def test_add_proxy_budget_to_db_only_creates_user_no_keys(): """ from unittest.mock import AsyncMock, patch - import litellm from litellm.proxy.proxy_server import ProxyStartupEvent # Set up required litellm settings @@ -847,7 +857,7 @@ async def test_write_config_to_file(monkeypatch): """ Do not write config to file if store_model_in_db is True """ - from unittest.mock import AsyncMock, MagicMock, mock_open, patch + from unittest.mock import AsyncMock, mock_open, patch from litellm.proxy.proxy_server import ProxyConfig @@ -901,7 +911,7 @@ async def test_write_config_to_file_when_store_model_in_db_false(monkeypatch): """ Test that config IS written to file when store_model_in_db is False """ - from unittest.mock import AsyncMock, MagicMock, mock_open, patch 
+ from unittest.mock import mock_open, patch from litellm.proxy.proxy_server import ProxyConfig @@ -1006,7 +1016,7 @@ def mock_streaming_hook(*args, **kwargs): mock_response, mock_user_api_key_dict, mock_request_data ): yielded_data.append(data) - except Exception as e: + except Exception: # If there's an exception, that's also part of what we want to test pass @@ -1085,9 +1095,7 @@ async def test_chat_completion_result_no_nested_none_values(): from unittest.mock import AsyncMock, MagicMock, patch from fastapi import Request, Response - from pydantic import BaseModel - import litellm from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.proxy_server import chat_completion @@ -1398,7 +1406,7 @@ def test_get_model_cost_map_reload_status_admin_access(self, client_with_auth): assert response.status_code == 200 data = response.json() - assert data["scheduled"] == True + assert data["scheduled"] is True assert data["interval_hours"] == 6 assert data["last_run"] == "2024-01-01T06:00:00" assert data["next_run"] == "2024-01-01T12:00:00" @@ -1426,10 +1434,10 @@ def test_get_model_cost_map_reload_status_no_config(self, client_with_auth): assert response.status_code == 200 data = response.json() - assert data["scheduled"] == False - assert data["interval_hours"] == None - assert data["last_run"] == None - assert data["next_run"] == None + assert data["scheduled"] is False + assert data["interval_hours"] is None + assert data["last_run"] is None + assert data["next_run"] is None def test_get_model_cost_map_reload_status_no_interval(self, client_with_auth): """Test that status returns not scheduled when no interval is configured""" @@ -1445,10 +1453,10 @@ def test_get_model_cost_map_reload_status_no_interval(self, client_with_auth): assert response.status_code == 200 data = response.json() - assert data["scheduled"] == False - assert data["interval_hours"] == None - assert data["last_run"] == None - assert data["next_run"] == None + assert data["scheduled"] is False 
+ assert data["interval_hours"] is None + assert data["last_run"] is None + assert data["next_run"] is None class TestPriceDataReloadIntegration: @@ -1554,7 +1562,7 @@ def test_distributed_reload_check_function(self): # The param_value is now a JSON string, so we need to parse it param_value_json = call_args[1]["data"]["update"]["param_value"] param_value_dict = json.loads(param_value_json) - assert param_value_dict["force_reload"] == False + assert param_value_dict["force_reload"] is False def test_config_file_parsing(self): """Test parsing of config file with reload settings""" @@ -1613,7 +1621,7 @@ def test_database_config_storage(self): call_args = mock_prisma.db.litellm_config.upsert.call_args assert call_args[1]["where"]["param_name"] == "model_cost_map_reload_config" assert call_args[1]["data"]["create"]["param_value"]["interval_hours"] == 6 - assert call_args[1]["data"]["create"]["param_value"]["force_reload"] == False + assert call_args[1]["data"]["create"]["param_value"]["force_reload"] is False def test_manual_reload_force_flag(self): """Test that manual reload sets force flag correctly""" @@ -1640,7 +1648,7 @@ def test_manual_reload_force_flag(self): # Verify force_reload flag was set mock_prisma.db.litellm_config.upsert.assert_called_once() call_args = mock_prisma.db.litellm_config.upsert.call_args - assert call_args[1]["data"]["update"]["param_value"]["force_reload"] == True + assert call_args[1]["data"]["update"]["param_value"]["force_reload"] is True @pytest.mark.asyncio @@ -1651,7 +1659,7 @@ async def test_add_router_settings_from_db_config_merge_logic(): This tests how router settings from config file and database are combined, including scenarios where nested dictionaries should be properly merged. 
""" - from unittest.mock import AsyncMock, MagicMock, patch + from unittest.mock import AsyncMock, MagicMock from litellm.proxy.proxy_server import ProxyConfig @@ -1715,7 +1723,7 @@ async def test_add_router_settings_from_db_config_merge_logic(): # Config-only values should be preserved assert combined_settings["model_group_alias"] == {"gpt-4": "openai-gpt-4"} - assert combined_settings["enable_pre_call_checks"] == True + assert combined_settings["enable_pre_call_checks"] is True assert combined_settings["timeout"] == 30 # DB-only values should be added @@ -1978,7 +1986,7 @@ def test_add_callback_from_db_to_in_memory_litellm_callbacks(): Test that _add_callback_from_db_to_in_memory_litellm_callbacks correctly adds callbacks for success, failure, and combined event types. """ - from unittest.mock import MagicMock, patch + from unittest.mock import MagicMock from litellm.proxy.proxy_server import ProxyConfig