diff --git a/CHANGELOG.md b/CHANGELOG.md index 10396ab5c..73522eff2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,24 @@ All notable changes to this project will be documented in this file. -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [3.31.0] - 2026-01-14 + +### Added + +- Modal simulation API support alongside GCP Workflows for economy calculations +- `SimulationAPIModal` class for HTTP-based job submission and polling +- Factory function to select between GCP and Modal backends via `USE_MODAL_SIMULATION_API` env var +- Status constants for both GCP (`ACTIVE`, `SUCCEEDED`, `FAILED`) and Modal (`running`, `complete`, `failed`) +- Unit tests for Modal client, factory, and status handling + +### Changed + +- `EconomyService` now handles both GCP and Modal execution status values +- Added `httpx` dependency for Modal HTTP client + ## [3.30.4] - 2026-01-13 13:30:17 ### Changed diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..77aa4fea8 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,10 @@ +- bump: minor + changes: + added: + - Modal simulation API support alongside GCP Workflows for economy calculations + - SimulationAPIModal class for HTTP-based job submission and polling + - Factory function to select between GCP and Modal backends via USE_MODAL_SIMULATION_API env var + - Status constants for both GCP and Modal execution states + - Unit tests for Modal client, factory, and status handling + changed: + - EconomyService now handles both GCP and Modal execution status values diff --git a/policyengine_api/constants.py b/policyengine_api/constants.py index 0c7d07139..3dc687355 100644 --- a/policyengine_api/constants.py +++ b/policyengine_api/constants.py @@ -56,4 +56,32 @@ ], } +# Simulation execution status constants +# GCP Workflow execution states (from google.cloud.workflows.executions_v1.Execution.State) +GCP_EXECUTION_STATUS_ACTIVE = "ACTIVE" +GCP_EXECUTION_STATUS_SUCCEEDED = "SUCCEEDED" +GCP_EXECUTION_STATUS_FAILED = "FAILED" +GCP_EXECUTION_STATUS_CANCELLED = "CANCELLED" + +# Modal simulation API status values +MODAL_EXECUTION_STATUS_SUBMITTED = "submitted" +MODAL_EXECUTION_STATUS_RUNNING = "running" +MODAL_EXECUTION_STATUS_COMPLETE = "complete" +MODAL_EXECUTION_STATUS_FAILED = "failed" + +# Status groupings for EconomyService._handle_execution_state() +EXECUTION_STATUSES_SUCCESS = ( + GCP_EXECUTION_STATUS_SUCCEEDED, + MODAL_EXECUTION_STATUS_COMPLETE, +) +EXECUTION_STATUSES_FAILURE = ( + GCP_EXECUTION_STATUS_FAILED, + MODAL_EXECUTION_STATUS_FAILED, +) +EXECUTION_STATUSES_PENDING = ( + GCP_EXECUTION_STATUS_ACTIVE, + MODAL_EXECUTION_STATUS_SUBMITTED, + MODAL_EXECUTION_STATUS_RUNNING, +) + __version__ = VERSION diff --git a/policyengine_api/libs/simulation_api_factory.py b/policyengine_api/libs/simulation_api_factory.py new file mode 100644 index 000000000..c94d1f90b --- /dev/null +++ b/policyengine_api/libs/simulation_api_factory.py @@ -0,0 +1,60 @@ +""" +Factory for selecting the appropriate Simulation API implementation. + +This module provides a factory function that returns either the GCP Workflows-based +SimulationAPI or the Modal-based SimulationAPIModal, depending on environment +configuration. + +Environment Variables +--------------------- +USE_MODAL_SIMULATION_API : str + Set to "true" to use the Modal simulation API. Defaults to "false" (GCP). +""" + +import os +from typing import Union + +from policyengine_api.gcp_logging import logger + + +def get_simulation_api() -> ( + Union["SimulationAPI", "SimulationAPIModal"] # noqa: F821 +): + """ + Get the appropriate simulation API client based on environment configuration. + + Returns the Modal-based client if USE_MODAL_SIMULATION_API is set to "true", + otherwise returns the GCP Workflows-based client. + + Returns + ------- + SimulationAPI or SimulationAPIModal + The simulation API client instance. + + Raises + ------ + ValueError + If GCP client is requested but GOOGLE_APPLICATION_CREDENTIALS is not set. + """ + use_modal = ( + os.environ.get("USE_MODAL_SIMULATION_API", "false").lower() == "true" + ) + + if use_modal: + logger.log_struct( + {"message": "Using Modal simulation API"}, + severity="INFO", + ) + from policyengine_api.libs.simulation_api_modal import ( + simulation_api_modal, + ) + + return simulation_api_modal + else: + logger.log_struct( + {"message": "Using GCP Workflows simulation API"}, + severity="INFO", + ) + from policyengine_api.libs.simulation_api import SimulationAPI + + return SimulationAPI() diff --git a/policyengine_api/libs/simulation_api_modal.py b/policyengine_api/libs/simulation_api_modal.py new file mode 100644 index 000000000..fd81781d3 --- /dev/null +++ b/policyengine_api/libs/simulation_api_modal.py @@ -0,0 +1,237 @@ +""" +HTTP client for the Modal Simulation API. + +This module provides a client for submitting simulation jobs to the +Modal-based simulation API and polling for results. It implements +the same interface as SimulationAPI (GCP) to allow for easy switching +between backends. +""" + +import os +from dataclasses import dataclass +from typing import Optional + +import httpx + +from policyengine_api.gcp_logging import logger + + +@dataclass +class ModalSimulationExecution: + """ + Represents a Modal simulation job execution. + + This class mirrors the interface of GCP's executions_v1.Execution + to allow the EconomyService to work with either backend. + """ + + job_id: str + status: str + result: Optional[dict] = None + error: Optional[str] = None + + @property + def name(self) -> str: + """Alias for job_id to match GCP Execution interface.""" + return self.job_id + + +class SimulationAPIModal: + """ + HTTP client for the Modal Simulation API. + + This class provides methods for submitting simulation jobs and + polling for their status/results via HTTP endpoints, replacing + the GCP Workflows SDK calls used in SimulationAPI. + """ + + def __init__(self): + self.base_url = os.environ.get( + "SIMULATION_API_URL", + "https://policyengine--policyengine-simulation-gateway-web-app.modal.run", + ) + self.client = httpx.Client(timeout=30.0) + + def run(self, payload: dict) -> ModalSimulationExecution: + """ + Submit a simulation job to the Modal API. + + Parameters + ---------- + payload : dict + The simulation parameters (country, reform, baseline, etc.) + Expected to match SimulationOptions schema. + + Returns + ------- + ModalSimulationExecution + Execution object with job_id and initial status. + + Raises + ------ + httpx.HTTPStatusError + If the API returns an error response. + """ + try: + # Map field names from SimulationOptions to Modal API format + # SimulationOptions uses 'model_version', Modal expects 'version' + modal_payload = dict(payload) + if "model_version" in modal_payload: + modal_payload["version"] = modal_payload.pop("model_version") + # Remove data_version as Modal doesn't use it + modal_payload.pop("data_version", None) + + response = self.client.post( + f"{self.base_url}/simulate/economy/comparison", + json=modal_payload, + ) + response.raise_for_status() + data = response.json() + + logger.log_struct( + { + "message": "Modal simulation job submitted", + "job_id": data.get("job_id"), + "status": data.get("status"), + }, + severity="INFO", + ) + + return ModalSimulationExecution( + job_id=data["job_id"], + status=data["status"], + ) + + except httpx.HTTPStatusError as e: + logger.log_struct( + { + "message": f"Modal API HTTP error: {e.response.status_code}", + "response_text": e.response.text[:500], + }, + severity="ERROR", + ) + raise + + except httpx.RequestError as e: + logger.log_struct( + { + "message": f"Modal API request error: {str(e)}", + }, + severity="ERROR", + ) + raise + + def get_execution_id(self, execution: ModalSimulationExecution) -> str: + """ + Get the job ID from an execution. + + Parameters + ---------- + execution : ModalSimulationExecution + The execution object returned from run(). + + Returns + ------- + str + The job ID. + """ + return execution.job_id + + def get_execution_by_id(self, job_id: str) -> ModalSimulationExecution: + """ + Poll the Modal API for the current status of a job. + + Parameters + ---------- + job_id : str + The job ID returned from run(). + + Returns + ------- + ModalSimulationExecution + Execution object with current status and result if complete. + """ + try: + response = self.client.get(f"{self.base_url}/jobs/{job_id}") + # Note: Modal returns 202 for running, 200 for complete, 500 for failed + # We handle all cases by checking the status field in the response + data = response.json() + + return ModalSimulationExecution( + job_id=job_id, + status=data["status"], + result=data.get("result"), + error=data.get("error"), + ) + + except httpx.HTTPStatusError as e: + logger.log_struct( + { + "message": f"Modal API HTTP error polling job {job_id}: {e.response.status_code}", + "response_text": e.response.text[:500], + }, + severity="ERROR", + ) + raise + + except httpx.RequestError as e: + logger.log_struct( + { + "message": f"Modal API request error polling job {job_id}: {str(e)}", + }, + severity="ERROR", + ) + raise + + def get_execution_status(self, execution: ModalSimulationExecution) -> str: + """ + Get the status string from an execution. + + Parameters + ---------- + execution : ModalSimulationExecution + The execution object. + + Returns + ------- + str + The status string ("submitted", "running", "complete", "failed"). + """ + return execution.status + + def get_execution_result( + self, execution: ModalSimulationExecution + ) -> Optional[dict]: + """ + Get the result from a completed execution. + + Parameters + ---------- + execution : ModalSimulationExecution + The execution object. + + Returns + ------- + dict or None + The simulation result if complete, None otherwise. + """ + return execution.result + + def health_check(self) -> bool: + """ + Check if the Modal API is healthy. + + Returns + ------- + bool + True if the API is healthy, False otherwise. + """ + try: + response = self.client.get(f"{self.base_url}/health") + return response.status_code == 200 + except Exception: + return False + + +# Global instance for use throughout the application +simulation_api_modal = SimulationAPIModal() diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 68b86194f..9ca08b69d 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -5,9 +5,12 @@ from policyengine_api.constants import ( COUNTRY_PACKAGE_VERSIONS, REGION_PREFIXES, + EXECUTION_STATUSES_SUCCESS, + EXECUTION_STATUSES_FAILURE, + EXECUTION_STATUSES_PENDING, ) from policyengine_api.gcp_logging import logger -from policyengine_api.libs.simulation_api import SimulationAPI +from policyengine_api.libs.simulation_api_factory import get_simulation_api from policyengine_api.data.model_setup import get_dataset_version from policyengine_api.data.congressional_districts import ( get_valid_state_codes, @@ -16,22 +19,19 @@ ) from policyengine.simulation import SimulationOptions from policyengine.utils.data.datasets import get_default_dataset -from google.cloud.workflows import executions_v1 import json import datetime -from typing import Literal, Any, Optional, Annotated +from typing import Literal, Any, Optional, Annotated, Union from dotenv import load_dotenv from pydantic import BaseModel import numpy as np from enum import Enum -ExecutionState = executions_v1.Execution.State - load_dotenv() policy_service = PolicyService() reform_impacts_service = ReformImpactsService() -simulation_api = SimulationAPI() +simulation_api = get_simulation_api() class ImpactAction(Enum): @@ -319,12 +319,15 @@ def _handle_execution_state( setup_options: EconomicImpactSetupOptions, execution_state: str, reform_impact: dict, - execution: Optional[executions_v1.Execution] = None, + execution: Optional[Any] = None, ) -> EconomicImpactResult: """ Handle the state of the execution and return the appropriate status and result. + + Supports both GCP Workflow statuses (SUCCEEDED, FAILED, ACTIVE) and + Modal statuses (complete, failed, running, submitted). """ - if execution_state == "SUCCEEDED": + if execution_state in EXECUTION_STATUSES_SUCCESS: result = simulation_api.get_execution_result(execution) self._set_reform_impact_complete( setup_options=setup_options, @@ -337,21 +340,30 @@ def _handle_execution_state( ) return EconomicImpactResult.completed(data=result) - elif execution_state == "FAILED": + elif execution_state in EXECUTION_STATUSES_FAILURE: + # For Modal, try to get error message from execution + error_message = "Simulation API execution failed" + if ( + execution is not None + and hasattr(execution, "error") + and execution.error + ): + error_message = ( + f"Simulation API execution failed: {execution.error}" + ) + self._set_reform_impact_error( setup_options=setup_options, - message="Simulation API execution failed", + message=error_message, execution_id=reform_impact["execution_id"], ) logger.log_struct( - {"message": "Sim API execution failed"}, + {"message": error_message}, severity="ERROR", ) - return EconomicImpactResult.error( - message="Simulation API execution failed" - ) + return EconomicImpactResult.error(message=error_message) - elif execution_state == "ACTIVE": + elif execution_state in EXECUTION_STATUSES_PENDING: logger.log_struct( {"message": "Sim API execution is still running"}, severity="INFO", diff --git a/setup.py b/setup.py index ace52e8a9..0dbbe1860 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ "flask-cors>=5,<6", "google-cloud-logging", "gunicorn", + "httpx>=0.27.0", "markupsafe>=3,<4", "openai", "policyengine_canada==0.96.3", diff --git a/tests/fixtures/libs/__init__.py b/tests/fixtures/libs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fixtures/libs/simulation_api_factory.py b/tests/fixtures/libs/simulation_api_factory.py new file mode 100644 index 000000000..2d2fe6c42 --- /dev/null +++ b/tests/fixtures/libs/simulation_api_factory.py @@ -0,0 +1,79 @@ +""" +Test fixtures for simulation_api_factory. + +This module provides fixtures for testing the simulation API factory +that switches between GCP and Modal backends. +""" + +import pytest +from unittest.mock import patch, MagicMock + + +@pytest.fixture +def mock_env_use_modal_true(): + """Set USE_MODAL_SIMULATION_API to true.""" + with patch.dict( + "os.environ", + {"USE_MODAL_SIMULATION_API": "true"}, + ): + yield + + +@pytest.fixture +def mock_env_use_modal_false(): + """Set USE_MODAL_SIMULATION_API to false.""" + with patch.dict( + "os.environ", + {"USE_MODAL_SIMULATION_API": "false"}, + ): + yield + + +@pytest.fixture +def mock_env_use_modal_unset(): + """Ensure USE_MODAL_SIMULATION_API is not set.""" + with patch.dict( + "os.environ", + {}, + clear=True, + ): + # Re-patch to only clear the specific key + import os + + env_copy = dict(os.environ) + env_copy.pop("USE_MODAL_SIMULATION_API", None) + with patch.dict("os.environ", env_copy, clear=True): + yield + + +@pytest.fixture +def mock_factory_logger(): + """Mock logger for simulation_api_factory.""" + with patch("policyengine_api.libs.simulation_api_factory.logger") as mock: + yield mock + + +@pytest.fixture +def mock_simulation_api_modal_instance(): + """Mock the Modal simulation API instance.""" + mock_instance = MagicMock() + mock_instance.base_url = "https://mock-modal-api.modal.run" + with patch( + "policyengine_api.libs.simulation_api_factory.simulation_api_modal", + mock_instance, + ): + yield mock_instance + + +@pytest.fixture +def mock_simulation_api_gcp_class(): + """Mock the GCP SimulationAPI class.""" + mock_instance = MagicMock() + mock_instance.project = "mock-project" + mock_instance.location = "us-central1" + mock_instance.workflow = "simulation-workflow" + with patch( + "policyengine_api.libs.simulation_api_factory.SimulationAPI", + return_value=mock_instance, + ) as mock_class: + yield mock_class, mock_instance diff --git a/tests/fixtures/libs/simulation_api_modal.py b/tests/fixtures/libs/simulation_api_modal.py new file mode 100644 index 000000000..216cf5bf1 --- /dev/null +++ b/tests/fixtures/libs/simulation_api_modal.py @@ -0,0 +1,176 @@ +""" +Test fixtures for SimulationAPIModal. + +This module provides mock data, fixtures, and helper functions for testing +the Modal simulation API client. +""" + +import pytest +from unittest.mock import patch, MagicMock +import json + +from policyengine_api.constants import ( + MODAL_EXECUTION_STATUS_SUBMITTED, + MODAL_EXECUTION_STATUS_RUNNING, + MODAL_EXECUTION_STATUS_COMPLETE, + MODAL_EXECUTION_STATUS_FAILED, +) + + +# Mock data constants +MOCK_MODAL_JOB_ID = "fc-abc123xyz" +MOCK_MODAL_BASE_URL = "https://test-modal-api.modal.run" + +MOCK_SIMULATION_PAYLOAD = { + "country": "us", + "scope": "macro", + "reform": {"sample_param": {"2024-01-01.2100-12-31": 15}}, + "baseline": {}, + "time_period": "2025", + "region": "us", + "data": "gs://policyengine-us-data/cps_2023.h5", + "include_cliffs": False, +} + +MOCK_SIMULATION_RESULT = { + "poverty_impact": {"baseline": 0.12, "reform": 0.10}, + "budget_impact": {"baseline": 1000, "reform": 1200}, + "inequality_impact": {"baseline": 0.45, "reform": 0.42}, +} + +MOCK_SUBMIT_RESPONSE_SUCCESS = { + "job_id": MOCK_MODAL_JOB_ID, + "status": MODAL_EXECUTION_STATUS_SUBMITTED, + "poll_url": f"/jobs/{MOCK_MODAL_JOB_ID}", + "country": "us", + "version": "1.459.0", +} + +MOCK_POLL_RESPONSE_RUNNING = { + "status": MODAL_EXECUTION_STATUS_RUNNING, + "result": None, + "error": None, +} + +MOCK_POLL_RESPONSE_COMPLETE = { + "status": MODAL_EXECUTION_STATUS_COMPLETE, + "result": MOCK_SIMULATION_RESULT, + "error": None, +} + +MOCK_POLL_RESPONSE_FAILED = { + "status": MODAL_EXECUTION_STATUS_FAILED, + "result": None, + "error": "Simulation timed out", +} + +MOCK_HEALTH_RESPONSE = {"status": "healthy"} + + +def create_mock_httpx_response( + status_code: int = 200, + json_data: dict = None, +): + """ + Helper function to create a mock httpx response. + + Parameters + ---------- + status_code : int + HTTP status code for the response. + json_data : dict + JSON data to return from response.json(). + + Returns + ------- + MagicMock + A mock httpx response object. + """ + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.json.return_value = json_data or {} + mock_response.text = json.dumps(json_data or {}) + mock_response.raise_for_status = MagicMock() + + if status_code >= 400: + import httpx + + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + message=f"HTTP {status_code}", + request=MagicMock(), + response=mock_response, + ) + + return mock_response + + +@pytest.fixture +def mock_modal_env_url(): + """Mock the SIMULATION_API_URL environment variable.""" + with patch.dict( + "os.environ", + {"SIMULATION_API_URL": MOCK_MODAL_BASE_URL}, + ): + yield MOCK_MODAL_BASE_URL + + +@pytest.fixture +def mock_httpx_client(): + """ + Mock httpx.Client for testing SimulationAPIModal. + + Returns a mock client that can be configured for different responses. + """ + with patch( + "policyengine_api.libs.simulation_api_modal.httpx.Client" + ) as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + yield mock_client + + +@pytest.fixture +def mock_httpx_client_submit_success(mock_httpx_client): + """Mock httpx client configured for successful job submission.""" + mock_httpx_client.post.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_SUBMIT_RESPONSE_SUCCESS, + ) + return mock_httpx_client + + +@pytest.fixture +def mock_httpx_client_poll_running(mock_httpx_client): + """Mock httpx client configured for polling a running job.""" + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_POLL_RESPONSE_RUNNING, + ) + return mock_httpx_client + + +@pytest.fixture +def mock_httpx_client_poll_complete(mock_httpx_client): + """Mock httpx client configured for polling a completed job.""" + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=200, + json_data=MOCK_POLL_RESPONSE_COMPLETE, + ) + return mock_httpx_client + + +@pytest.fixture +def mock_httpx_client_poll_failed(mock_httpx_client): + """Mock httpx client configured for polling a failed job.""" + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=500, + json_data=MOCK_POLL_RESPONSE_FAILED, + ) + return mock_httpx_client + + +@pytest.fixture +def mock_modal_logger(): + """Mock logger for SimulationAPIModal.""" + with patch("policyengine_api.libs.simulation_api_modal.logger") as mock: + yield mock diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index 83068bd0c..293b8909e 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -4,6 +4,16 @@ import datetime from google.cloud.workflows import executions_v1 +from policyengine_api.constants import ( + GCP_EXECUTION_STATUS_ACTIVE, + GCP_EXECUTION_STATUS_SUCCEEDED, + GCP_EXECUTION_STATUS_FAILED, + MODAL_EXECUTION_STATUS_SUBMITTED, + MODAL_EXECUTION_STATUS_RUNNING, + MODAL_EXECUTION_STATUS_COMPLETE, + MODAL_EXECUTION_STATUS_FAILED, +) + # Mock data constants MOCK_COUNTRY_ID = "us" MOCK_POLICY_ID = 123 @@ -15,6 +25,7 @@ MOCK_OPTIONS = {"option1": "value1", "option2": "value2"} MOCK_OPTIONS_HASH = "[option1=value1&option2=value2]" MOCK_EXECUTION_ID = "mock_execution_id_12345" +MOCK_MODAL_JOB_ID = "fc-test123xyz" MOCK_PROCESS_ID = "job_20250626120000_1234" MOCK_MODEL_VERSION = "1.2.3" MOCK_DATA_VERSION = None @@ -184,6 +195,58 @@ def mock_execution_states(): } +def create_mock_modal_execution( + job_id=MOCK_MODAL_JOB_ID, + status=MODAL_EXECUTION_STATUS_SUBMITTED, + result=None, + error=None, +): + """ + Helper function to create mock Modal execution objects. + + Parameters + ---------- + job_id : str + The Modal job ID. + status : str + The execution status (submitted, running, complete, failed). + result : dict or None + The simulation result if complete. + error : str or None + The error message if failed. + + Returns + ------- + MagicMock + A mock ModalSimulationExecution object. + """ + mock_execution = MagicMock() + mock_execution.job_id = job_id + mock_execution.name = job_id # Alias for compatibility + mock_execution.status = status + mock_execution.result = result + mock_execution.error = error + return mock_execution + + +@pytest.fixture +def mock_simulation_api_modal(): + """Mock SimulationAPIModal with all required methods.""" + mock_api = MagicMock() + mock_execution = create_mock_modal_execution() + + mock_api.run.return_value = mock_execution + mock_api.get_execution_id.return_value = MOCK_MODAL_JOB_ID + mock_api.get_execution_by_id.return_value = mock_execution + mock_api.get_execution_status.return_value = MODAL_EXECUTION_STATUS_RUNNING + mock_api.get_execution_result.return_value = MOCK_REFORM_IMPACT_DATA + + with patch( + "policyengine_api.services.economy_service.simulation_api", mock_api + ) as mock: + yield mock + + # Expected GCS paths from get_default_dataset MOCK_US_NATIONWIDE_DATASET = "gs://policyengine-us-data/cps_2023.h5" MOCK_US_STATE_CA_DATASET = "gs://policyengine-us-data/states/CA.h5" diff --git a/tests/unit/libs/__init__.py b/tests/unit/libs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/libs/test_simulation_api_factory.py b/tests/unit/libs/test_simulation_api_factory.py new file mode 100644 index 000000000..6602c47b4 --- /dev/null +++ b/tests/unit/libs/test_simulation_api_factory.py @@ -0,0 +1,208 @@ +""" +Unit tests for simulation_api_factory. + +Tests the factory function that selects between GCP Workflows +and Modal simulation API backends. +""" + +import pytest +from unittest.mock import patch, MagicMock + +from tests.fixtures.libs.simulation_api_factory import ( + mock_env_use_modal_true, + mock_env_use_modal_false, + mock_factory_logger, +) + + +class TestGetSimulationApi: + """Tests for the get_simulation_api factory function.""" + + class TestModalSelection: + + def test__given_use_modal_env_true__then_returns_modal_api( + self, + mock_factory_logger, + ): + # Given + with patch.dict( + "os.environ", + {"USE_MODAL_SIMULATION_API": "true"}, + ): + # Need to reimport to pick up the env change + from policyengine_api.libs.simulation_api_factory import ( + get_simulation_api, + ) + from policyengine_api.libs.simulation_api_modal import ( + SimulationAPIModal, + ) + + # When + api = get_simulation_api() + + # Then + assert isinstance(api, SimulationAPIModal) + + def test__given_use_modal_env_true_uppercase__then_returns_modal_api( + self, + mock_factory_logger, + ): + # Given + with patch.dict( + "os.environ", + {"USE_MODAL_SIMULATION_API": "TRUE"}, + ): + from policyengine_api.libs.simulation_api_factory import ( + get_simulation_api, + ) + from policyengine_api.libs.simulation_api_modal import ( + SimulationAPIModal, + ) + + # When + api = get_simulation_api() + + # Then + assert isinstance(api, SimulationAPIModal) + + def test__given_use_modal_env_true__then_logs_modal_selection( + self, + mock_factory_logger, + ): + # Given + with patch.dict( + "os.environ", + {"USE_MODAL_SIMULATION_API": "true"}, + ): + from policyengine_api.libs.simulation_api_factory import ( + get_simulation_api, + ) + + # When + get_simulation_api() + + # Then + mock_factory_logger.log_struct.assert_called() + call_args = mock_factory_logger.log_struct.call_args[0][0] + assert "Modal" in call_args["message"] + + class TestGCPSelection: + + def test__given_use_modal_env_false__then_returns_gcp_api( + self, + mock_factory_logger, + ): + # Given + with patch.dict( + "os.environ", + { + "USE_MODAL_SIMULATION_API": "false", + "GOOGLE_APPLICATION_CREDENTIALS": "/path/to/creds.json", + }, + ): + # Mock the GCP client to avoid needing real credentials + with patch( + "policyengine_api.libs.simulation_api.executions_v1.ExecutionsClient" + ): + with patch( + "policyengine_api.libs.simulation_api.workflows_v1.WorkflowsClient" + ): + from policyengine_api.libs.simulation_api_factory import ( + get_simulation_api, + ) + from policyengine_api.libs.simulation_api import ( + SimulationAPI, + ) + + # When + api = get_simulation_api() + + # Then + assert isinstance(api, SimulationAPI) + + def test__given_use_modal_env_not_set__then_returns_gcp_api( + self, + mock_factory_logger, + ): + # Given + import os + + env_copy = dict(os.environ) + env_copy.pop("USE_MODAL_SIMULATION_API", None) + env_copy["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/creds.json" + + with patch.dict("os.environ", env_copy, clear=True): + with patch( + "policyengine_api.libs.simulation_api.executions_v1.ExecutionsClient" + ): + with patch( + "policyengine_api.libs.simulation_api.workflows_v1.WorkflowsClient" + ): + from policyengine_api.libs.simulation_api_factory import ( + get_simulation_api, + ) + from policyengine_api.libs.simulation_api import ( + SimulationAPI, + ) + + # When + api = get_simulation_api() + + # Then + assert isinstance(api, SimulationAPI) + + def test__given_use_modal_env_false__then_logs_gcp_selection( + self, + mock_factory_logger, + ): + # Given + with patch.dict( + "os.environ", + { + "USE_MODAL_SIMULATION_API": "false", + "GOOGLE_APPLICATION_CREDENTIALS": "/path/to/creds.json", + }, + ): + with patch( + "policyengine_api.libs.simulation_api.executions_v1.ExecutionsClient" + ): + with patch( + "policyengine_api.libs.simulation_api.workflows_v1.WorkflowsClient" + ): + from policyengine_api.libs.simulation_api_factory import ( + get_simulation_api, + ) + + # When + get_simulation_api() + + # Then + mock_factory_logger.log_struct.assert_called() + call_args = mock_factory_logger.log_struct.call_args[ + 0 + ][0] + assert "GCP" in call_args["message"] + + class TestGCPCredentialsError: + + def test__given_gcp_selected_without_credentials__then_raises_error( + self, + mock_factory_logger, + ): + # Given + import os + + env_copy = dict(os.environ) + env_copy.pop("USE_MODAL_SIMULATION_API", None) + env_copy.pop("GOOGLE_APPLICATION_CREDENTIALS", None) + + with patch.dict("os.environ", env_copy, clear=True): + from policyengine_api.libs.simulation_api_factory import ( + get_simulation_api, + ) + + # When/Then + with pytest.raises(ValueError) as exc_info: + get_simulation_api() + + assert "GOOGLE_APPLICATION_CREDENTIALS" in str(exc_info.value) diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py new file mode 100644 index 000000000..4ba7d0616 --- /dev/null +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -0,0 +1,398 @@ +""" +Unit tests for SimulationAPIModal class. + +Tests the Modal simulation API HTTP client functionality including +job submission, status polling, and error handling. +""" + +import pytest +from unittest.mock import patch, MagicMock +import httpx + +from policyengine_api.libs.simulation_api_modal import ( + SimulationAPIModal, + ModalSimulationExecution, +) +from policyengine_api.constants import ( + MODAL_EXECUTION_STATUS_SUBMITTED, + MODAL_EXECUTION_STATUS_RUNNING, + MODAL_EXECUTION_STATUS_COMPLETE, + MODAL_EXECUTION_STATUS_FAILED, +) +from tests.fixtures.libs.simulation_api_modal import ( + MOCK_MODAL_JOB_ID, + MOCK_MODAL_BASE_URL, + MOCK_SIMULATION_PAYLOAD, + MOCK_SIMULATION_RESULT, + MOCK_SUBMIT_RESPONSE_SUCCESS, + MOCK_POLL_RESPONSE_RUNNING, + MOCK_POLL_RESPONSE_COMPLETE, + MOCK_POLL_RESPONSE_FAILED, + MOCK_HEALTH_RESPONSE, + create_mock_httpx_response, + mock_httpx_client, + mock_modal_logger, +) + + +class TestModalSimulationExecution: + """Tests for the ModalSimulationExecution dataclass.""" + + class TestNameProperty: + + def test__given_job_id__then_name_returns_job_id(self): + # Given + execution = ModalSimulationExecution( + job_id=MOCK_MODAL_JOB_ID, + status=MODAL_EXECUTION_STATUS_SUBMITTED, + ) + + # When + name = execution.name + + # Then + assert name == MOCK_MODAL_JOB_ID + + class TestAttributes: + + def test__given_complete_execution__then_all_attributes_accessible( + self, + ): + # Given + execution = ModalSimulationExecution( + job_id=MOCK_MODAL_JOB_ID, + status=MODAL_EXECUTION_STATUS_COMPLETE, + result=MOCK_SIMULATION_RESULT, + error=None, + ) + + # Then + assert execution.job_id == MOCK_MODAL_JOB_ID + assert execution.status == MODAL_EXECUTION_STATUS_COMPLETE + assert execution.result == MOCK_SIMULATION_RESULT + assert execution.error is None + + def test__given_failed_execution__then_error_attribute_populated(self): + # Given + error_message = "Simulation timed out" + execution = ModalSimulationExecution( + job_id=MOCK_MODAL_JOB_ID, + status=MODAL_EXECUTION_STATUS_FAILED, + result=None, + error=error_message, + ) + + # Then + assert execution.status == MODAL_EXECUTION_STATUS_FAILED + assert execution.error == error_message + assert execution.result is None + + +class TestSimulationAPIModal: + """Tests for the SimulationAPIModal class.""" + + class TestInit: + + def test__given_env_var_set__then_uses_env_url( + self, mock_httpx_client + ): + # Given + with patch.dict( + "os.environ", + {"SIMULATION_API_URL": MOCK_MODAL_BASE_URL}, + ): + # When + api = SimulationAPIModal() + + # Then + assert api.base_url == MOCK_MODAL_BASE_URL + + def test__given_env_var_not_set__then_uses_default_url( + self, mock_httpx_client + ): + # Given + with patch.dict("os.environ", {}, clear=False): + import os + + os.environ.pop("SIMULATION_API_URL", None) + + # When + api = SimulationAPIModal() + + # Then + assert "policyengine-simulation-gateway" in api.base_url + assert "modal.run" in api.base_url + + class TestRun: + + def test__given_valid_payload__then_returns_execution_with_job_id( + self, + mock_httpx_client, + mock_modal_logger, + ): + # Given + mock_httpx_client.post.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_SUBMIT_RESPONSE_SUCCESS, + ) + api = SimulationAPIModal() + + # When + execution = api.run(MOCK_SIMULATION_PAYLOAD) + + # Then + assert execution.job_id == MOCK_MODAL_JOB_ID + assert execution.status == MODAL_EXECUTION_STATUS_SUBMITTED + mock_httpx_client.post.assert_called_once() + + def test__given_valid_payload__then_posts_to_correct_endpoint( + self, + mock_httpx_client, + mock_modal_logger, + ): + # Given + mock_httpx_client.post.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_SUBMIT_RESPONSE_SUCCESS, + ) + api = SimulationAPIModal() + + # When + api.run(MOCK_SIMULATION_PAYLOAD) + + # Then + call_args = mock_httpx_client.post.call_args + assert "/simulate/economy/comparison" in call_args[0][0] + assert call_args[1]["json"] == MOCK_SIMULATION_PAYLOAD + + def test__given_http_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + # Given + mock_response = create_mock_httpx_response( + status_code=400, + json_data={"error": "Invalid request"}, + ) + mock_httpx_client.post.return_value = mock_response + api = SimulationAPIModal() + + # When/Then + with pytest.raises(httpx.HTTPStatusError): + api.run(MOCK_SIMULATION_PAYLOAD) + + def test__given_network_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + # Given + mock_httpx_client.post.side_effect = httpx.RequestError( + "Connection failed" + ) + api = SimulationAPIModal() + + # When/Then + with pytest.raises(httpx.RequestError): + api.run(MOCK_SIMULATION_PAYLOAD) + + class TestGetExecutionById: + + def test__given_running_job__then_returns_running_status( + self, + mock_httpx_client, + mock_modal_logger, + ): + # Given + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_POLL_RESPONSE_RUNNING, + ) + api = SimulationAPIModal() + + # When + execution = api.get_execution_by_id(MOCK_MODAL_JOB_ID) + + # Then + assert execution.job_id == MOCK_MODAL_JOB_ID + assert execution.status == MODAL_EXECUTION_STATUS_RUNNING + assert execution.result is None + + def test__given_complete_job__then_returns_result( + self, + mock_httpx_client, + mock_modal_logger, + ): + # Given + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=200, + json_data=MOCK_POLL_RESPONSE_COMPLETE, + ) + api = SimulationAPIModal() + + # When + execution = api.get_execution_by_id(MOCK_MODAL_JOB_ID) + + # Then + assert execution.status == MODAL_EXECUTION_STATUS_COMPLETE + assert execution.result == MOCK_SIMULATION_RESULT + + def test__given_failed_job__then_returns_error( + self, + mock_httpx_client, + mock_modal_logger, + ): + # Given + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=200, # Failed jobs still return 200 with error in body + json_data=MOCK_POLL_RESPONSE_FAILED, + ) + api = SimulationAPIModal() + + # When + execution = api.get_execution_by_id(MOCK_MODAL_JOB_ID) + + # Then + assert execution.status == MODAL_EXECUTION_STATUS_FAILED + assert execution.error == "Simulation timed out" + + def test__given_job_id__then_polls_correct_endpoint( + self, + mock_httpx_client, + mock_modal_logger, + ): + # Given + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_POLL_RESPONSE_RUNNING, + ) + api = SimulationAPIModal() + + # When + api.get_execution_by_id(MOCK_MODAL_JOB_ID) + + # Then + call_args = mock_httpx_client.get.call_args + assert f"/jobs/{MOCK_MODAL_JOB_ID}" in call_args[0][0] + + class TestGetExecutionId: + + def test__given_execution__then_returns_job_id( + self, mock_httpx_client + ): + # Given + api = SimulationAPIModal() + execution = ModalSimulationExecution( + job_id=MOCK_MODAL_JOB_ID, + status=MODAL_EXECUTION_STATUS_SUBMITTED, + ) + + # When + execution_id = api.get_execution_id(execution) + + # Then + assert execution_id == MOCK_MODAL_JOB_ID + + class TestGetExecutionStatus: + + def test__given_execution__then_returns_status_string( + self, mock_httpx_client + ): + # Given + api = SimulationAPIModal() + execution = ModalSimulationExecution( + job_id=MOCK_MODAL_JOB_ID, + status=MODAL_EXECUTION_STATUS_RUNNING, + ) + + # When + status = api.get_execution_status(execution) + + # Then + assert status == MODAL_EXECUTION_STATUS_RUNNING + + class TestGetExecutionResult: + + def test__given_complete_execution__then_returns_result( + self, mock_httpx_client + ): + # Given + api = SimulationAPIModal() + execution = ModalSimulationExecution( + job_id=MOCK_MODAL_JOB_ID, + status=MODAL_EXECUTION_STATUS_COMPLETE, + result=MOCK_SIMULATION_RESULT, + ) + + # When + result = api.get_execution_result(execution) + + # Then + assert result == MOCK_SIMULATION_RESULT + + def test__given_incomplete_execution__then_returns_none( + self, mock_httpx_client + ): + # Given + api = SimulationAPIModal() + execution = ModalSimulationExecution( + job_id=MOCK_MODAL_JOB_ID, + status=MODAL_EXECUTION_STATUS_RUNNING, + result=None, + ) + + # When + result = api.get_execution_result(execution) + + # Then + assert result is None + + class TestHealthCheck: + + def test__given_healthy_api__then_returns_true( + self, mock_httpx_client, mock_modal_logger + ): + # Given + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=200, + json_data=MOCK_HEALTH_RESPONSE, + ) + api = SimulationAPIModal() + + # When + is_healthy = api.health_check() + + # Then + assert is_healthy is True + + def test__given_unhealthy_api__then_returns_false( + self, mock_httpx_client, mock_modal_logger + ): + # Given + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=503, + json_data={"status": "unhealthy"}, + ) + api = SimulationAPIModal() + + # When + is_healthy = api.health_check() + + # Then + assert is_healthy is False + + def test__given_network_error__then_returns_false( + self, mock_httpx_client, mock_modal_logger + ): + # Given + mock_httpx_client.get.side_effect = httpx.RequestError( + "Connection failed" + ) + api = SimulationAPIModal() + + # When + is_healthy = api.health_check() + + # Then + assert is_healthy is False diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index d870a627d..ba4a4e586 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -422,6 +422,109 @@ def test__given_unknown_state__raises_error( exc_info.value ) + # Modal status tests + def test__given_modal_complete_state__then_returns_completed_result( + self, + economy_service, + setup_options, + mock_simulation_api, + mock_reform_impacts_service, + mock_logger, + ): + # Given + reform_impact = create_mock_reform_impact(status="computing") + mock_execution = MagicMock() + mock_simulation_api.get_execution_result.return_value = ( + MOCK_REFORM_IMPACT_DATA + ) + + # When + result = economy_service._handle_execution_state( + setup_options, "complete", reform_impact, mock_execution + ) + + # Then + assert result.status == ImpactStatus.OK + assert result.data == MOCK_REFORM_IMPACT_DATA + mock_reform_impacts_service.set_complete_reform_impact.assert_called_once() + + def test__given_modal_failed_state__then_returns_error_result( + self, + economy_service, + setup_options, + mock_reform_impacts_service, + mock_logger, + ): + # Given + reform_impact = create_mock_reform_impact(status="computing") + mock_execution = MagicMock() + mock_execution.error = None + + # When + result = economy_service._handle_execution_state( + setup_options, "failed", reform_impact, mock_execution + ) + + # Then + assert result.status == ImpactStatus.ERROR + assert result.data is None + mock_reform_impacts_service.set_error_reform_impact.assert_called_once() + + def test__given_modal_failed_state_with_error_message__then_includes_error_in_message( + self, + economy_service, + setup_options, + mock_reform_impacts_service, + mock_logger, + ): + # Given + reform_impact = create_mock_reform_impact(status="computing") + mock_execution = MagicMock() + mock_execution.error = "Simulation timed out" + + # When + result = economy_service._handle_execution_state( + setup_options, "failed", reform_impact, mock_execution + ) + + # Then + assert result.status == ImpactStatus.ERROR + # Verify the error message was passed to the service + call_args = ( + mock_reform_impacts_service.set_error_reform_impact.call_args + ) + assert "Simulation timed out" in call_args[1]["message"] + + def test__given_modal_running_state__then_returns_computing_result( + self, economy_service, setup_options, mock_logger + ): + # Given + reform_impact = create_mock_reform_impact(status="computing") + + # When + result = economy_service._handle_execution_state( + setup_options, "running", reform_impact + ) + + # Then + assert result.status == ImpactStatus.COMPUTING + assert result.data is None + + def test__given_modal_submitted_state__then_returns_computing_result( + self, economy_service, setup_options, mock_logger + ): + # Given + reform_impact = create_mock_reform_impact(status="computing") + + # When + result = economy_service._handle_execution_state( + setup_options, "submitted", reform_impact + ) + + # Then + assert result.status == ImpactStatus.COMPUTING + assert result.data is None + class TestCreateProcessId: @pytest.fixture