From 58778ba6cd103ad1d809a72b5a56972c1d6d870e Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 14 Jan 2026 23:44:07 +0300 Subject: [PATCH 1/4] feat: Enable Modal simulation API by default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change the default value of USE_MODAL_SIMULATION_API from "false" to "true", making Modal the default backend for economy simulations. The GCP Workflows backend can still be used by setting USE_MODAL_SIMULATION_API=false. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- changelog_entry.yaml | 4 ++++ policyengine_api/libs/simulation_api_factory.py | 8 ++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..3332d2280 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + changed: + - Enable Modal simulation API by default instead of GCP Workflows diff --git a/policyengine_api/libs/simulation_api_factory.py b/policyengine_api/libs/simulation_api_factory.py index c94d1f90b..4ae3b84b2 100644 --- a/policyengine_api/libs/simulation_api_factory.py +++ b/policyengine_api/libs/simulation_api_factory.py @@ -17,9 +17,7 @@ from policyengine_api.gcp_logging import logger -def get_simulation_api() -> ( - Union["SimulationAPI", "SimulationAPIModal"] # noqa: F821 -): +def get_simulation_api() -> Union["SimulationAPI", "SimulationAPIModal"]: # noqa: F821 """ Get the appropriate simulation API client based on environment configuration. @@ -36,9 +34,7 @@ def get_simulation_api() -> ( ValueError If GCP client is requested but GOOGLE_APPLICATION_CREDENTIALS is not set. """ - use_modal = ( - os.environ.get("USE_MODAL_SIMULATION_API", "false").lower() == "true" - ) + use_modal = os.environ.get("USE_MODAL_SIMULATION_API", "true").lower() == "true" if use_modal: logger.log_struct( From ebaa39c59c6b2c17bd8f157a7eafd17be5d04711 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 15 Jan 2026 00:29:56 +0300 Subject: [PATCH 2/4] fix: Update tests for Modal default and run linter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Update test to expect Modal API when env var is not set (default changed) - Update GCP credentials test to explicitly set USE_MODAL_SIMULATION_API=false - Run black formatter on entire package 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../ai_prompts/simulation_analysis_prompt.py | 12 +- policyengine_api/api.py | 16 +-- policyengine_api/country.py | 32 ++--- .../data/congressional_districts.py | 4 +- policyengine_api/data/data.py | 8 +- policyengine_api/data/model_setup.py | 8 +- policyengine_api/endpoints/economy/compare.py | 68 +++------- policyengine_api/endpoints/household.py | 10 +- policyengine_api/endpoints/policy.py | 32 ++--- policyengine_api/libs/simulation_api.py | 12 +- policyengine_api/routes/economy_routes.py | 36 +++--- policyengine_api/routes/household_routes.py | 20 +-- policyengine_api/routes/metadata_routes.py | 4 +- policyengine_api/routes/policy_routes.py | 4 +- .../routes/report_output_routes.py | 12 +- .../routes/simulation_analysis_routes.py | 4 +- policyengine_api/routes/simulation_routes.py | 4 +- .../services/ai_analysis_service.py | 4 +- policyengine_api/services/economy_service.py | 70 +++++------ .../services/household_service.py | 12 +- .../services/report_output_service.py | 12 +- .../services/simulation_analysis_service.py | 12 +- .../services/simulation_service.py | 16 +-- .../services/tracer_analysis_service.py | 16 +-- .../validate_household_payload.py | 4 +- .../validate_set_policy_payload.py | 4 +- policyengine_api/utils/singleton.py | 4 +- .../test_environment_variables.py | 4 +- tests/fixtures/integration/simulations.py | 12 +- .../fixtures/services/ai_analysis_service.py | 12 +- tests/fixtures/services/economy_service.py | 19 +-- tests/fixtures/services/household_fixtures.py | 4 +- tests/fixtures/services/policy_service.py | 4 +- tests/integration/test_simulations.py | 8 +- tests/to_refactor/api/test_api.py | 8 +- .../to_refactor_household_fixtures.py | 8 +- .../python/test_ai_analysis_service_old.py | 4 +- .../python/test_household_routes.py | 12 +- .../python/test_policy_service_old.py | 32 ++--- .../python/test_simulation_analysis_routes.py | 12 +- .../python/test_tracer_analysis_routes.py | 12 +- .../python/test_us_policy_macro.py | 8 +- .../python/test_user_profile_routes.py | 16 +-- .../python/test_validate_household_payload.py | 8 +- .../python/test_yearly_var_removal.py | 34 ++--- .../test_simulation_analysis_prompt.py | 12 +- .../unit/data/test_congressional_districts.py | 89 ++++--------- tests/unit/endpoints/economy/test_compare.py | 116 +++++------------ .../unit/libs/test_simulation_api_factory.py | 39 +++--- tests/unit/libs/test_simulation_api_modal.py | 24 +--- .../unit/services/test_ai_analysis_service.py | 3 +- tests/unit/services/test_economy_service.py | 118 +++++------------- tests/unit/services/test_household_service.py | 4 +- tests/unit/services/test_metadata_service.py | 8 +- tests/unit/services/test_policy_service.py | 32 ++--- .../services/test_report_output_service.py | 38 ++---- .../unit/services/test_simulation_service.py | 21 +--- .../services/test_tracer_analysis_service.py | 8 +- tests/unit/services/test_tracer_service.py | 4 +- .../services/test_update_profile_service.py | 12 +- tests/unit/services/test_user_service.py | 4 +- tests/unit/test_country.py | 16 +-- 62 files changed, 326 insertions(+), 879 deletions(-) diff --git a/policyengine_api/ai_prompts/simulation_analysis_prompt.py b/policyengine_api/ai_prompts/simulation_analysis_prompt.py index e7605771f..dc809312e 100644 --- a/policyengine_api/ai_prompts/simulation_analysis_prompt.py +++ b/policyengine_api/ai_prompts/simulation_analysis_prompt.py @@ -95,18 +95,12 @@ def generate_simulation_analysis_prompt(params: InboundParameters) -> str: ) impact_budget: str = json.dumps(parameters.impact["budget"]) - impact_intra_decile: dict[str, Any] = json.dumps( - parameters.impact["intra_decile"] - ) + impact_intra_decile: dict[str, Any] = json.dumps(parameters.impact["intra_decile"]) impact_decile: str = json.dumps(parameters.impact["decile"]) impact_inequality: str = json.dumps(parameters.impact["inequality"]) impact_poverty: str = json.dumps(parameters.impact["poverty"]["poverty"]) - impact_deep_poverty: str = json.dumps( - parameters.impact["poverty"]["deep_poverty"] - ) - impact_poverty_by_gender: str = json.dumps( - parameters.impact["poverty_by_gender"] - ) + impact_deep_poverty: str = json.dumps(parameters.impact["poverty"]["deep_poverty"]) + impact_poverty_by_gender: str = json.dumps(parameters.impact["poverty_by_gender"]) all_parameters: AllParameters = AllParameters.model_validate( { diff --git a/policyengine_api/api.py b/policyengine_api/api.py index b22529b31..112cce9ac 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -132,9 +132,7 @@ def log_timing(message): app.route("//calculate-full", methods=["POST"])( cache.cached(make_cache_key=make_cache_key)( - lambda *args, **kwargs: get_calculate( - *args, **kwargs, add_missing=True - ) + lambda *args, **kwargs: get_calculate(*args, **kwargs, add_missing=True) ) ) log_timing("Calculate-full endpoint registered") @@ -153,9 +151,7 @@ def log_timing(message): app.route("//user-policy", methods=["PUT"])(update_user_policy) log_timing("User policy update endpoint registered") -app.route("//user-policy/", methods=["GET"])( - get_user_policy -) +app.route("//user-policy/", methods=["GET"])(get_user_policy) log_timing("User policy get endpoint registered") app.register_blueprint(user_profile_bp) @@ -177,9 +173,7 @@ def log_timing(message): @app.route("/liveness-check", methods=["GET"]) def liveness_check(): - return flask.Response( - "OK", status=200, headers={"Content-Type": "text/plain"} - ) + return flask.Response("OK", status=200, headers={"Content-Type": "text/plain"}) log_timing("Liveness check endpoint registered") @@ -187,9 +181,7 @@ def liveness_check(): @app.route("/readiness-check", methods=["GET"]) def readiness_check(): - return flask.Response( - "OK", status=200, headers={"Content-Type": "text/plain"} - ) + return flask.Response("OK", status=200, headers={"Content-Type": "text/plain"}) log_timing("Readiness check endpoint registered") diff --git a/policyengine_api/country.py b/policyengine_api/country.py index 29f64fbbe..a9b4695ec 100644 --- a/policyengine_api/country.py +++ b/policyengine_api/country.py @@ -60,9 +60,7 @@ def build_metadata(self): }[self.country_id], basicInputs=self.tax_benefit_system.basic_inputs, modelled_policies=self.tax_benefit_system.modelled_policies, - version=pkg_resources.get_distribution( - self.country_package_name - ).version, + version=pkg_resources.get_distribution(self.country_package_name).version, ) def build_microsimulation_options(self) -> dict: @@ -77,13 +75,9 @@ def build_microsimulation_options(self) -> dict: region = [ dict(name="uk", label="the UK", type="national"), dict(name="country/england", label="England", type="country"), - dict( - name="country/scotland", label="Scotland", type="country" - ), + dict(name="country/scotland", label="Scotland", type="country"), dict(name="country/wales", label="Wales", type="country"), - dict( - name="country/ni", label="Northern Ireland", type="country" - ), + dict(name="country/ni", label="Northern Ireland", type="country"), ] for i in range(len(constituency_names)): region.append( @@ -130,9 +124,7 @@ def build_microsimulation_options(self) -> dict: dict(name="state/co", label="Colorado", type="state"), dict(name="state/ct", label="Connecticut", type="state"), dict(name="state/de", label="Delaware", type="state"), - dict( - name="state/dc", label="District of Columbia", type="state" - ), + dict(name="state/dc", label="District of Columbia", type="state"), dict(name="state/fl", label="Florida", type="state"), dict(name="state/ga", label="Georgia", type="state"), dict(name="state/hi", label="Hawaii", type="state"), @@ -300,9 +292,7 @@ def build_parameters(self) -> dict: ), } elif isinstance(parameter, ParameterScaleBracket): - bracket_index = int( - parameter.name[parameter.name.index("[") + 1 : -1] - ) + bracket_index = int(parameter.name[parameter.name.index("[") + 1 : -1]) # Set the label to 'first bracket' for the first bracket, 'second bracket' for the second, etc. bracket_label = f"bracket {bracket_index + 1}" parameter_data[parameter.name] = { @@ -379,9 +369,7 @@ def calculate( for parameter_name in reform: for time_period, value in reform[parameter_name].items(): start_instant, end_instant = time_period.split(".") - parameter = get_parameter( - system.parameters, parameter_name - ) + parameter = get_parameter(system.parameters, parameter_name) node_type = type(parameter.values_list[-1].value) if node_type == int: node_type = float @@ -461,12 +449,8 @@ def calculate( if "axes" in household: pass else: - household[entity_plural][entity_id][variable_name][ - period - ] = None - print( - f"Error computing {variable_name} for {entity_id}: {e}" - ) + household[entity_plural][entity_id][variable_name][period] = None + print(f"Error computing {variable_name} for {entity_id}: {e}") tracer_output = simulation.tracer.computation_log log_lines = tracer_output.lines(aggregate=False, max_depth=10) diff --git a/policyengine_api/data/congressional_districts.py b/policyengine_api/data/congressional_districts.py index 7aa54ab8c..b085a0fa5 100644 --- a/policyengine_api/data/congressional_districts.py +++ b/policyengine_api/data/congressional_districts.py @@ -684,9 +684,7 @@ def build_congressional_district_metadata() -> list[dict]: return [ { "name": _build_district_name(district.state_code, district.number), - "label": _build_district_label( - district.state_code, district.number - ), + "label": _build_district_label(district.state_code, district.number), "type": "congressional_district", "state_abbreviation": district.state_code, "state_name": STATE_CODE_TO_NAME[district.state_code], diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index c64ffd065..a1f479227 100644 --- a/policyengine_api/data/data.py +++ b/policyengine_api/data/data.py @@ -30,9 +30,7 @@ def __init__( self.local = local if local: # Local development uses a sqlite database. - self.db_url = ( - REPO / "policyengine_api" / "data" / "policyengine.db" - ) + self.db_url = REPO / "policyengine_api" / "data" / "policyengine.db" if initialize or not Path(self.db_url).exists(): self.initialize() else: @@ -41,9 +39,7 @@ def __init__( self.initialize() def _create_pool(self): - instance_connection_name = ( - "policyengine-api:us-central1:policyengine-api-data" - ) + instance_connection_name = "policyengine-api:us-central1:policyengine-api-data" self.connector = Connector() db_user = "policyengine" db_pass = os.environ["POLICYENGINE_DB_PASSWORD"] diff --git a/policyengine_api/data/model_setup.py b/policyengine_api/data/model_setup.py index a2a6a3ee7..739f7bbcc 100644 --- a/policyengine_api/data/model_setup.py +++ b/policyengine_api/data/model_setup.py @@ -37,11 +37,7 @@ def get_dataset_version(country_id: str) -> str | None: for dataset in datasets["uk"]: - datasets["uk"][ - dataset - ] = f"{datasets['uk'][dataset]}@{get_dataset_version('uk')}" + datasets["uk"][dataset] = f"{datasets['uk'][dataset]}@{get_dataset_version('uk')}" for dataset in datasets["us"]: - datasets["us"][ - dataset - ] = f"{datasets['us'][dataset]}@{get_dataset_version('us')}" + datasets["us"][dataset] = f"{datasets['us'][dataset]}@{get_dataset_version('us')}" diff --git a/policyengine_api/endpoints/economy/compare.py b/policyengine_api/endpoints/economy/compare.py index c97a03f6f..117decb39 100644 --- a/policyengine_api/endpoints/economy/compare.py +++ b/policyengine_api/endpoints/economy/compare.py @@ -10,12 +10,8 @@ def budgetary_impact(baseline: dict, reform: dict) -> dict: tax_revenue_impact = reform["total_tax"] - baseline["total_tax"] - state_tax_revenue_impact = ( - reform["total_state_tax"] - baseline["total_state_tax"] - ) - benefit_spending_impact = ( - reform["total_benefits"] - baseline["total_benefits"] - ) + state_tax_revenue_impact = reform["total_state_tax"] - baseline["total_state_tax"] + benefit_spending_impact = reform["total_benefits"] - baseline["total_benefits"] budgetary_impact = tax_revenue_impact - benefit_spending_impact return dict( budgetary_impact=budgetary_impact, @@ -28,14 +24,10 @@ def budgetary_impact(baseline: dict, reform: dict) -> dict: def labor_supply_response(baseline: dict, reform: dict) -> dict: - substitution_lsr = ( - reform["substitution_lsr"] - baseline["substitution_lsr"] - ) + substitution_lsr = reform["substitution_lsr"] - baseline["substitution_lsr"] income_lsr = reform["income_lsr"] - baseline["income_lsr"] total_change = substitution_lsr + income_lsr - revenue_change = ( - reform["budgetary_impact_lsr"] - baseline["budgetary_impact_lsr"] - ) + revenue_change = reform["budgetary_impact_lsr"] - baseline["budgetary_impact_lsr"] substitution_lsr_hh = np.array(reform["substitution_lsr_hh"]) - np.array( baseline["substitution_lsr_hh"] @@ -48,17 +40,13 @@ def labor_supply_response(baseline: dict, reform: dict) -> dict: total_lsr_hh = substitution_lsr_hh + income_lsr_hh - emp_income = MicroSeries( - baseline["employment_income_hh"], weights=household_weight - ) + emp_income = MicroSeries(baseline["employment_income_hh"], weights=household_weight) self_emp_income = MicroSeries( baseline["self_employment_income_hh"], weights=household_weight ) earnings = emp_income + self_emp_income original_earnings = earnings - total_lsr_hh - substitution_lsr_hh = MicroSeries( - substitution_lsr_hh, weights=household_weight - ) + substitution_lsr_hh = MicroSeries(substitution_lsr_hh, weights=household_weight) income_lsr_hh = MicroSeries(income_lsr_hh, weights=household_weight) decile_avg = dict( @@ -81,9 +69,7 @@ def labor_supply_response(baseline: dict, reform: dict) -> dict: substitution=(substitution_lsr_hh.sum() / original_earnings.sum()), ) - decile_rel["income"] = { - int(k): v for k, v in decile_rel["income"].items() if k > 0 - } + decile_rel["income"] = {int(k): v for k, v in decile_rel["income"].items() if k > 0} decile_rel["substitution"] = { int(k): v for k, v in decile_rel["substitution"].items() if k > 0 } @@ -112,9 +98,7 @@ def labor_supply_response(baseline: dict, reform: dict) -> dict: ) -def detailed_budgetary_impact( - baseline: dict, reform: dict, country_id: str -) -> dict: +def detailed_budgetary_impact(baseline: dict, reform: dict, country_id: str) -> dict: result = {} if country_id == "uk": for program in baseline["programs"]: @@ -122,8 +106,7 @@ def detailed_budgetary_impact( result[program] = dict( baseline=baseline["programs"][program], reform=reform["programs"][program], - difference=reform["programs"][program] - - baseline["programs"][program], + difference=reform["programs"][program] - baseline["programs"][program], ) return result @@ -289,9 +272,7 @@ def poverty_impact(baseline: dict, reform: dict) -> dict: reform=float(reform_deep_poverty[age < 18].mean()), ), adult=dict( - baseline=float( - baseline_deep_poverty[(age >= 18) & (age < 65)].mean() - ), + baseline=float(baseline_deep_poverty[(age >= 18) & (age < 65)].mean()), reform=float(reform_deep_poverty[(age >= 18) & (age < 65)].mean()), ), senior=dict( @@ -323,9 +304,7 @@ def intra_decile_impact(baseline: dict, reform: dict) -> dict: decile = MicroSeries(baseline["household_income_decile"]).values absolute_change = (reform_income - baseline_income).values capped_baseline_income = np.maximum(baseline_income.values, 1) - capped_reform_income = ( - np.maximum(reform_income.values, 1) + absolute_change - ) + capped_reform_income = np.maximum(reform_income.values, 1) + absolute_change income_change = ( capped_reform_income - capped_baseline_income ) / capped_baseline_income @@ -362,9 +341,7 @@ def intra_decile_impact(baseline: dict, reform: dict) -> dict: if people_in_decile == 0 and people_in_both == 0: people_in_proportion: float = 0.0 else: - people_in_proportion: float = float( - people_in_both / people_in_decile - ) + people_in_proportion: float = float(people_in_both / people_in_decile) outcome_groups[label].append(people_in_proportion) @@ -385,9 +362,7 @@ def intra_wealth_decile_impact(baseline: dict, reform: dict) -> dict: decile = MicroSeries(baseline["household_wealth_decile"]).values absolute_change = (reform_income - baseline_income).values capped_baseline_income = np.maximum(baseline_income.values, 1) - capped_reform_income = ( - np.maximum(reform_income.values, 1) + absolute_change - ) + capped_reform_income = np.maximum(reform_income.values, 1) + absolute_change income_change = ( capped_reform_income - capped_baseline_income ) / capped_baseline_income @@ -424,9 +399,7 @@ def intra_wealth_decile_impact(baseline: dict, reform: dict) -> dict: if people_in_decile == 0 and people_in_both == 0: people_in_proportion = 0 else: - people_in_proportion: float = float( - people_in_both / people_in_decile - ) + people_in_proportion: float = float(people_in_both / people_in_decile) outcome_groups[label].append(people_in_proportion) @@ -508,9 +481,7 @@ def poverty_racial_breakdown(baseline: dict, reform: dict) -> dict: reform_poverty = MicroSeries( reform["person_in_poverty"], weights=baseline_poverty.weights ) - race = MicroSeries( - baseline["race"] - ) # Can be WHITE, BLACK, HISPANIC, or OTHER. + race = MicroSeries(baseline["race"]) # Can be WHITE, BLACK, HISPANIC, or OTHER. poverty = dict( white=dict( @@ -752,10 +723,7 @@ def uk_local_authority_breakdown( continue elif selected_country == "wales" and not code.startswith("W"): continue - elif ( - selected_country == "northern_ireland" - and not code.startswith("N") - ): + elif selected_country == "northern_ireland" and not code.startswith("N"): continue weight: np.ndarray = weights[i] @@ -841,9 +809,7 @@ def compare_economic_outputs( uk_local_authority_breakdown(baseline, reform, country_id, region) ) if local_authority_impact_data is not None: - local_authority_impact_data = ( - local_authority_impact_data.model_dump() - ) + local_authority_impact_data = local_authority_impact_data.model_dump() try: wealth_decile_impact_data = wealth_decile_impact(baseline, reform) intra_wealth_decile_impact_data = intra_wealth_decile_impact( diff --git a/policyengine_api/endpoints/household.py b/policyengine_api/endpoints/household.py index b841c5e10..edd647906 100644 --- a/policyengine_api/endpoints/household.py +++ b/policyengine_api/endpoints/household.py @@ -41,11 +41,7 @@ def add_yearly_variables(household, country_id): if variables[variable]["isInputVariable"]: household[entity_plural][entity][ variables[variable]["name"] - ] = { - household_year: variables[variable][ - "defaultValue" - ] - } + ] = {household_year: variables[variable]["defaultValue"]} else: household[entity_plural][entity][ variables[variable]["name"] @@ -75,9 +71,7 @@ def get_household_year(household): @validate_country -def get_household_under_policy( - country_id: str, household_id: str, policy_id: str -): +def get_household_under_policy(country_id: str, household_id: str, policy_id: str): """Get a household's output data under a given policy. Args: diff --git a/policyengine_api/endpoints/policy.py b/policyengine_api/endpoints/policy.py index 90cfa9bd7..daf428b32 100644 --- a/policyengine_api/endpoints/policy.py +++ b/policyengine_api/endpoints/policy.py @@ -30,9 +30,7 @@ def get_policy_search(country_id: str) -> dict: query = request.args.get("query", "") # The "json.loads" default type is added to convert lowercase # "true" and "false" to Python-friendly bool values - unique_only = request.args.get( - "unique_only", default=False, type=json.loads - ) + unique_only = request.args.get("unique_only", default=False, type=json.loads) try: results = database.query( @@ -47,9 +45,7 @@ def get_policy_search(country_id: str) -> dict: status="error", message=f"No policies found for country {country_id} for query '{query}", ) - return Response( - json.dumps(body), status=404, mimetype="application/json" - ) + return Response(json.dumps(body), status=404, mimetype="application/json") # If unique_only is true, filter results to only include # items where everything except ID is unique @@ -70,22 +66,16 @@ def get_policy_search(country_id: str) -> dict: results = new_results # Format into: [{ id: 1, label: "My policy" }, ...] - policies = [ - dict(id=result["id"], label=result["label"]) for result in results - ] + policies = [dict(id=result["id"], label=result["label"]) for result in results] body = dict( status="ok", message="Policies found", result=policies, ) - return Response( - json.dumps(body), status=200, mimetype="application/json" - ) + return Response(json.dumps(body), status=200, mimetype="application/json") except Exception as e: body = dict(status="error", message=f"Internal server error: {e}") - return Response( - json.dumps(body), status=500, mimetype="application/json" - ) + return Response(json.dumps(body), status=500, mimetype="application/json") @validate_country @@ -177,9 +167,7 @@ def set_user_policy(country_id: str) -> dict: except Exception as e: return Response( json.dumps( - { - "message": f"Internal database error: {e}; please try again later." - } + {"message": f"Internal database error: {e}; please try again later."} ), status=500, mimetype="application/json", @@ -236,9 +224,7 @@ def set_user_policy(country_id: str) -> dict: except Exception as e: return Response( json.dumps( - { - "message": f"Internal database error: {e}; please try again later." - } + {"message": f"Internal database error: {e}; please try again later."} ), status=500, mimetype="application/json", @@ -350,9 +336,7 @@ def update_user_policy(country_id: str) -> dict: except Exception as e: return Response( json.dumps( - { - "message": f"Internal database error: {e}; please try again later." - } + {"message": f"Internal database error: {e}; please try again later."} ), status=500, mimetype="application/json", diff --git a/policyengine_api/libs/simulation_api.py b/policyengine_api/libs/simulation_api.py index 1fbd12b48..0b271e7f1 100644 --- a/policyengine_api/libs/simulation_api.py +++ b/policyengine_api/libs/simulation_api.py @@ -75,13 +75,9 @@ def get_execution_status(self, execution: executions_v1.Execution) -> str: status : str The status of the execution """ - return self.execution_client.get_execution( - name=execution.name - ).state.name + return self.execution_client.get_execution(name=execution.name).state.name - def get_execution_result( - self, execution: executions_v1.Execution - ) -> dict | None: + def get_execution_result(self, execution: executions_v1.Execution) -> dict | None: """ Get the result of an execution @@ -95,9 +91,7 @@ def get_execution_result( result : str The result of the execution """ - result = self.execution_client.get_execution( - name=execution.name - ).result + result = self.execution_client.get_execution(name=execution.name).result try: return json.loads(result) except: diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index c0de06730..84850b17d 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -18,9 +18,7 @@ "//economy//over/", methods=["GET"], ) -def get_economic_impact( - country_id: str, policy_id: int, baseline_policy_id: int -): +def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int): policy_id = int(policy_id or get_current_law_policy_id(country_id)) baseline_policy_id = int( @@ -35,27 +33,21 @@ def get_economic_impact( dataset = options.pop("dataset", "default") time_period = options.pop("time_period") target: Literal["general", "cliff"] = options.pop("target", "general") - api_version = options.pop( - "version", COUNTRY_PACKAGE_VERSIONS.get(country_id) + api_version = options.pop("version", COUNTRY_PACKAGE_VERSIONS.get(country_id)) + + economic_impact_result: EconomicImpactResult = economy_service.get_economic_impact( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=time_period, + options=options, + api_version=api_version, + target=target, ) - economic_impact_result: EconomicImpactResult = ( - economy_service.get_economic_impact( - country_id=country_id, - policy_id=policy_id, - baseline_policy_id=baseline_policy_id, - region=region, - dataset=dataset, - time_period=time_period, - options=options, - api_version=api_version, - target=target, - ) - ) - - result_dict: dict[str, str | dict | None] = ( - economic_impact_result.to_dict() - ) + result_dict: dict[str, str | dict | None] = economic_impact_result.to_dict() return Response( json.dumps( diff --git a/policyengine_api/routes/household_routes.py b/policyengine_api/routes/household_routes.py index 893d6defd..59eff51ee 100644 --- a/policyengine_api/routes/household_routes.py +++ b/policyengine_api/routes/household_routes.py @@ -13,9 +13,7 @@ household_service = HouseholdService() -@household_bp.route( - "//household/", methods=["GET"] -) +@household_bp.route("//household/", methods=["GET"]) @validate_country def get_household(country_id: str, household_id: int) -> Response: """ @@ -27,9 +25,7 @@ def get_household(country_id: str, household_id: int) -> Response: """ print(f"Got request for household {household_id} in country {country_id}") - household: dict | None = household_service.get_household( - country_id, household_id - ) + household: dict | None = household_service.get_household(country_id, household_id) if household is None: raise NotFound(f"Household #{household_id} not found.") else: @@ -67,9 +63,7 @@ def post_household(country_id: str) -> Response: label: str | None = payload.get("label") household_json: dict = payload.get("data") - household_id = household_service.create_household( - country_id, household_json, label - ) + household_id = household_service.create_household(country_id, household_json, label) return Response( json.dumps( @@ -86,9 +80,7 @@ def post_household(country_id: str) -> Response: ) -@household_bp.route( - "//household/", methods=["PUT"] -) +@household_bp.route("//household/", methods=["PUT"]) @validate_country def update_household(country_id: str, household_id: int) -> Response: """ @@ -111,9 +103,7 @@ def update_household(country_id: str, household_id: int) -> Response: label: str | None = payload.get("label") household_json: dict = payload.get("data") - household: dict | None = household_service.get_household( - country_id, household_id - ) + household: dict | None = household_service.get_household(country_id, household_id) if household is None: raise NotFound(f"Household #{household_id} not found.") diff --git a/policyengine_api/routes/metadata_routes.py b/policyengine_api/routes/metadata_routes.py index 496d9556d..8dd5465e4 100644 --- a/policyengine_api/routes/metadata_routes.py +++ b/policyengine_api/routes/metadata_routes.py @@ -20,9 +20,7 @@ def get_metadata(country_id: str) -> Response: # Retrieve country metadata and add status and message to the response country_metadata = metadata_service.get_metadata(country_id) return Response( - json.dumps( - {"status": "ok", "message": None, "result": country_metadata} - ), + json.dumps({"status": "ok", "message": None, "result": country_metadata}), status=200, mimetype="application/json", ) diff --git a/policyengine_api/routes/policy_routes.py b/policyengine_api/routes/policy_routes.py index 913eb105c..3fc88fbf4 100644 --- a/policyengine_api/routes/policy_routes.py +++ b/policyengine_api/routes/policy_routes.py @@ -76,6 +76,4 @@ def set_policy(country_id: str) -> Response: ) code = 200 if is_existing_policy else 201 - return Response( - json.dumps(response_body), status=code, mimetype="application/json" - ) + return Response(json.dumps(response_body), status=code, mimetype="application/json") diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index 4dfb9218a..a95630c33 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -33,9 +33,7 @@ def create_report_output(country_id: str) -> Response: # Extract required fields simulation_1_id = payload.get("simulation_1_id") simulation_2_id = payload.get("simulation_2_id") # Optional - year = payload.get( - "year", CURRENT_YEAR - ) # Default to current year as string + year = payload.get("year", CURRENT_YEAR) # Default to current year as string # Validate required fields if simulation_1_id is None: @@ -95,9 +93,7 @@ def create_report_output(country_id: str) -> Response: raise BadRequest(f"Failed to create report output: {str(e)}") -@report_output_bp.route( - "//report/", methods=["GET"] -) +@report_output_bp.route("//report/", methods=["GET"]) @validate_country def get_report_output(country_id: str, report_id: int) -> Response: """ @@ -109,9 +105,7 @@ def get_report_output(country_id: str, report_id: int) -> Response: """ print(f"Getting report output {report_id} for country {country_id}") - report_output: dict | None = report_output_service.get_report_output( - report_id - ) + report_output: dict | None = report_output_service.get_report_output(report_id) if report_output is None: raise NotFound(f"Report #{report_id} not found.") diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index 893d7cae4..5157b807d 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -16,9 +16,7 @@ simulation_analysis_service = SimulationAnalysisService() -@simulation_analysis_bp.route( - "//simulation-analysis", methods=["POST"] -) +@simulation_analysis_bp.route("//simulation-analysis", methods=["POST"]) @validate_country def execute_simulation_analysis(country_id): print("Got POST request for simulation analysis") diff --git a/policyengine_api/routes/simulation_routes.py b/policyengine_api/routes/simulation_routes.py index c1210d97d..151c4f942 100644 --- a/policyengine_api/routes/simulation_routes.py +++ b/policyengine_api/routes/simulation_routes.py @@ -96,9 +96,7 @@ def create_simulation(country_id: str) -> Response: raise BadRequest(f"Failed to create simulation: {str(e)}") -@simulation_bp.route( - "//simulation/", methods=["GET"] -) +@simulation_bp.route("//simulation/", methods=["GET"]) @validate_country def get_simulation(country_id: str, simulation_id: int) -> Response: """ diff --git a/policyengine_api/services/ai_analysis_service.py b/policyengine_api/services/ai_analysis_service.py index fa6c56db4..f2fc3c710 100644 --- a/policyengine_api/services/ai_analysis_service.py +++ b/policyengine_api/services/ai_analysis_service.py @@ -45,9 +45,7 @@ def get_existing_analysis(self, prompt: str) -> Optional[str]: def trigger_ai_analysis(self, prompt: str) -> Generator[str, None, None]: # Configure a Claude client - claude_client = anthropic.Anthropic( - api_key=os.getenv("ANTHROPIC_API_KEY") - ) + claude_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) def generate(): response_text = "" diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 9ca08b69d..3dd23f447 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -164,24 +164,22 @@ def get_economic_impact( if country_id == "uk": country_package_version = None - economic_impact_setup_options = ( - EconomicImpactSetupOptions.model_validate( - { - "process_id": process_id, - "country_id": country_id, - "reform_policy_id": policy_id, - "baseline_policy_id": baseline_policy_id, - "region": region, - "dataset": dataset, - "time_period": time_period, - "options": options, - "api_version": api_version, - "target": target, - "model_version": country_package_version, - "data_version": get_dataset_version(country_id), - "options_hash": options_hash, - } - ) + economic_impact_setup_options = EconomicImpactSetupOptions.model_validate( + { + "process_id": process_id, + "country_id": country_id, + "reform_policy_id": policy_id, + "baseline_policy_id": baseline_policy_id, + "region": region, + "dataset": dataset, + "time_period": time_period, + "options": options, + "api_version": api_version, + "target": target, + "model_version": country_package_version, + "data_version": get_dataset_version(country_id), + "options_hash": options_hash, + } ) # Logging that we've received a request @@ -259,17 +257,15 @@ def _get_previous_impacts( Fetch any previous simulation runs for the given policy reform. """ - previous_impacts: list[Any] = ( - reform_impacts_service.get_all_reform_impacts( - country_id, - policy_id, - baseline_policy_id, - region, - dataset, - time_period, - options_hash, - api_version, - ) + previous_impacts: list[Any] = reform_impacts_service.get_all_reform_impacts( + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, ) return previous_impacts @@ -348,9 +344,7 @@ def _handle_execution_state( and hasattr(execution, "error") and execution.error ): - error_message = ( - f"Simulation API execution failed: {execution.error}" - ) + error_message = f"Simulation API execution failed: {execution.error}" self._set_reform_impact_error( setup_options=setup_options, @@ -371,9 +365,7 @@ def _handle_execution_state( return EconomicImpactResult.computing() else: - raise ValueError( - f"Unexpected sim API execution state: {execution_state}" - ) + raise ValueError(f"Unexpected sim API execution state: {execution_state}") def _handle_completed_impact( self, @@ -473,9 +465,7 @@ def _setup_sim_options( "baseline": json.loads(baseline_policy), "time_period": time_period, "include_cliffs": include_cliffs, - "region": self._setup_region( - country_id=country_id, region=region - ), + "region": self._setup_region(country_id=country_id, region=region), "data": self._setup_data(country_id=country_id, region=region), "model_version": model_version, "data_version": data_version, @@ -514,9 +504,7 @@ def _validate_us_region(self, region: str) -> None: elif region.startswith("congressional_district/"): district_id = region[len("congressional_district/") :] if district_id.lower() not in get_valid_congressional_districts(): - raise ValueError( - f"Invalid congressional district: '{district_id}'" - ) + raise ValueError(f"Invalid congressional district: '{district_id}'") else: raise ValueError(f"Invalid US region: '{region}'") diff --git a/policyengine_api/services/household_service.py b/policyengine_api/services/household_service.py index 4091f71d9..dafc8bc6f 100644 --- a/policyengine_api/services/household_service.py +++ b/policyengine_api/services/household_service.py @@ -40,9 +40,7 @@ def get_household(self, country_id: str, household_id: int) -> dict | None: return household except Exception as e: - print( - f"Error fetching household #{household_id}. Details: {str(e)}" - ) + print(f"Error fetching household #{household_id}. Details: {str(e)}") raise e def create_household( @@ -123,12 +121,8 @@ def update_household( ) # Fetch the updated JSON back from the table - updated_household: dict = self.get_household( - country_id, household_id - ) + updated_household: dict = self.get_household(country_id, household_id) return updated_household except Exception as e: - print( - f"Error updating household #{household_id}. Details: {str(e)}" - ) + print(f"Error updating household #{household_id}. Details: {str(e)}") raise e diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index 4793ae018..c0dba45f1 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -43,17 +43,13 @@ def find_existing_report_output( existing_report = None if row is not None: existing_report = dict(row) - print( - f"Found existing report output with ID: {existing_report['id']}" - ) + print(f"Found existing report output with ID: {existing_report['id']}") # Keep output as JSON string - frontend expects string format return existing_report except Exception as e: - print( - f"Error checking for existing report output. Details: {str(e)}" - ) + print(f"Error checking for existing report output. Details: {str(e)}") raise e def create_report_output( @@ -217,7 +213,5 @@ def update_report_output( return True except Exception as e: - print( - f"Error updating report output #{report_id}. Details: {str(e)}" - ) + print(f"Error updating report output #{report_id}. Details: {str(e)}") raise e diff --git a/policyengine_api/services/simulation_analysis_service.py b/policyengine_api/services/simulation_analysis_service.py index 8949bf2ae..140fe4987 100644 --- a/policyengine_api/services/simulation_analysis_service.py +++ b/policyengine_api/services/simulation_analysis_service.py @@ -29,9 +29,7 @@ def execute_analysis( relevant_parameters: list[dict], relevant_parameter_baseline_values: list[dict], audience: str | None, - ) -> tuple[ - Generator[str, None, None] | str, Literal["streaming", "static"] - ]: + ) -> tuple[Generator[str, None, None] | str, Literal["streaming", "static"]]: """ Execute AI analysis for economy-wide simulation @@ -67,9 +65,7 @@ def execute_analysis( if existing_analysis is not None: return existing_analysis, "static" - print( - "Found no existing AI analysis; triggering new analysis with Claude" - ) + print("Found no existing AI analysis; triggering new analysis with Claude") # Otherwise, pass prompt to Claude, then return streaming function try: analysis = self.trigger_ai_analysis(prompt) @@ -109,9 +105,7 @@ def _generate_simulation_analysis_prompt( } try: - prompt = ai_prompt_service.get_prompt( - "simulation_analysis", prompt_data - ) + prompt = ai_prompt_service.get_prompt("simulation_analysis", prompt_data) return prompt except Exception as e: diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index 88f359ae7..a7985cb9b 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -38,9 +38,7 @@ def find_existing_simulation( existing_simulation = None if row is not None: existing_simulation = dict(row) - print( - f"Found existing simulation with ID: {existing_simulation['id']}" - ) + print(f"Found existing simulation with ID: {existing_simulation['id']}") return existing_simulation @@ -98,9 +96,7 @@ def create_simulation( print(f"Error creating simulation. Details: {str(e)}") raise e - def get_simulation( - self, country_id: str, simulation_id: int - ) -> dict | None: + def get_simulation(self, country_id: str, simulation_id: int) -> dict | None: """ Get a simulation record by ID. @@ -131,9 +127,7 @@ def get_simulation( return simulation except Exception as e: - print( - f"Error fetching simulation #{simulation_id}. Details: {str(e)}" - ) + print(f"Error fetching simulation #{simulation_id}. Details: {str(e)}") raise e def update_simulation( @@ -198,7 +192,5 @@ def update_simulation( return True except Exception as e: - print( - f"Error updating simulation #{simulation_id}. Details: {str(e)}" - ) + print(f"Error updating simulation #{simulation_id}. Details: {str(e)}") raise e diff --git a/policyengine_api/services/tracer_analysis_service.py b/policyengine_api/services/tracer_analysis_service.py index 5857fcef6..2fd072f83 100644 --- a/policyengine_api/services/tracer_analysis_service.py +++ b/policyengine_api/services/tracer_analysis_service.py @@ -18,9 +18,7 @@ def execute_analysis( household_id: str, policy_id: str, variable: str, - ) -> tuple[ - Generator[str, None, None] | str, Literal["static", "streaming"] - ]: + ) -> tuple[Generator[str, None, None] | str, Literal["static", "streaming"]]: """ Executes tracer analysis for a variable in a household @@ -44,9 +42,7 @@ def execute_analysis( # Parse the tracer output for our given variable try: - tracer_segment: list[str] = self._parse_tracer_output( - tracer, variable - ) + tracer_segment: list[str] = self._parse_tracer_output(tracer, variable) except Exception as e: print(f"Error parsing tracer output: {str(e)}") raise e @@ -107,17 +103,13 @@ def _parse_tracer_output(self, tracer_output, target_variable): capturing = False # Input validation - if not isinstance(target_variable, str) or not isinstance( - tracer_output, list - ): + if not isinstance(target_variable, str) or not isinstance(tracer_output, list): return result # Create a regex pattern to match the exact variable name # This will match the variable name followed by optional whitespace, # then optional angle brackets with any content, then optional whitespace - pattern = ( - rf"^(\s*)({re.escape(target_variable)})(?!\w)\s*(?:<[^>]*>)?\s*" - ) + pattern = rf"^(\s*)({re.escape(target_variable)})(?!\w)\s*(?:<[^>]*>)?\s*" for line in tracer_output: # Count leading spaces to determine indentation level diff --git a/policyengine_api/utils/payload_validators/validate_household_payload.py b/policyengine_api/utils/payload_validators/validate_household_payload.py index 7b4f7d951..c66f15e26 100644 --- a/policyengine_api/utils/payload_validators/validate_household_payload.py +++ b/policyengine_api/utils/payload_validators/validate_household_payload.py @@ -19,9 +19,7 @@ def validate_household_payload(payload): # Check that label is either string or None, if present if "label" in payload: - if payload["label"] is not None and not isinstance( - payload["label"], str - ): + if payload["label"] is not None and not isinstance(payload["label"], str): return False, "Label must be a string or None" # Check that data is a dictionary diff --git a/policyengine_api/utils/payload_validators/validate_set_policy_payload.py b/policyengine_api/utils/payload_validators/validate_set_policy_payload.py index a48c75bda..f90f80d17 100644 --- a/policyengine_api/utils/payload_validators/validate_set_policy_payload.py +++ b/policyengine_api/utils/payload_validators/validate_set_policy_payload.py @@ -8,9 +8,7 @@ def validate_set_policy_payload(payload: dict) -> tuple[bool, str | None]: # Check that label is either string or None if "label" in payload: - if payload["label"] is not None and not isinstance( - payload["label"], str - ): + if payload["label"] is not None and not isinstance(payload["label"], str): return False, "Label must be a string or None" # Check that data is a dictionary diff --git a/policyengine_api/utils/singleton.py b/policyengine_api/utils/singleton.py index 28e8a0984..3776cb92d 100644 --- a/policyengine_api/utils/singleton.py +++ b/policyengine_api/utils/singleton.py @@ -3,7 +3,5 @@ class Singleton(type): def __call__(cls, *args, **kwargs): if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__( - *args, **kwargs - ) + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) return cls._instances[cls] diff --git a/tests/env_variables/test_environment_variables.py b/tests/env_variables/test_environment_variables.py index 23a21ea1d..9bcaa2bc3 100644 --- a/tests/env_variables/test_environment_variables.py +++ b/tests/env_variables/test_environment_variables.py @@ -39,9 +39,7 @@ def test_github_microdata_auth_token(self): """Test if POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN is valid by querying GitHub user API.""" token = os.getenv("POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN") - assert ( - token is not None - ), "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN is not set" + assert token is not None, "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN is not set" headers = { "Authorization": f"Bearer {token}", diff --git a/tests/fixtures/integration/simulations.py b/tests/fixtures/integration/simulations.py index aefddc9fe..741676047 100644 --- a/tests/fixtures/integration/simulations.py +++ b/tests/fixtures/integration/simulations.py @@ -6,7 +6,9 @@ from unittest.mock import Mock, MagicMock, patch from policyengine_api.endpoints.household import add_yearly_variables -STANDARD_AXES_COUNT = 401 # Not formally defined anywhere, but this value is used throughout the API +STANDARD_AXES_COUNT = ( + 401 # Not formally defined anywhere, but this value is used throughout the API +) SMALL_AXES_COUNT = 5 TEST_YEAR = "2025" TEST_STATE = "NY" @@ -67,10 +69,6 @@ def create_household_with_axes(base_household, axes_config): def setup_small_axes_household(base_household, small_axes_config): """Fixture to setup a household with small axes for testing""" - household_with_axes = create_household_with_axes( - base_household, small_axes_config - ) - household_with_axes = add_yearly_variables( - household_with_axes, TEST_COUNTRY_ID - ) + household_with_axes = create_household_with_axes(base_household, small_axes_config) + household_with_axes = add_yearly_variables(household_with_axes, TEST_COUNTRY_ID) return household_with_axes diff --git a/tests/fixtures/services/ai_analysis_service.py b/tests/fixtures/services/ai_analysis_service.py index a2f4d21c4..95bba3039 100644 --- a/tests/fixtures/services/ai_analysis_service.py +++ b/tests/fixtures/services/ai_analysis_service.py @@ -39,14 +39,10 @@ def _configure(text_chunks: list[str]): # Set up mock stream mock_stream = MagicMock() - mock_client.messages.stream.return_value.__enter__.return_value = ( - mock_stream - ) + mock_client.messages.stream.return_value.__enter__.return_value = mock_stream # Configure stream to yield text events - events = [ - MockEvent(event_type="text", text=chunk) for chunk in text_chunks - ] + events = [MockEvent(event_type="text", text=chunk) for chunk in text_chunks] mock_stream.__iter__.return_value = events return mock_client @@ -67,9 +63,7 @@ def _configure(error_type: str): # Set up mock stream mock_stream = MagicMock() - mock_client.messages.stream.return_value.__enter__.return_value = ( - mock_stream - ) + mock_client.messages.stream.return_value.__enter__.return_value = mock_stream # Configure stream to yield an error event error_event = MockEvent(event_type="error", error={"type": error_type}) diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index 293b8909e..d94ffe9b4 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -30,9 +30,7 @@ MOCK_MODEL_VERSION = "1.2.3" MOCK_DATA_VERSION = None -MOCK_REFORM_POLICY_JSON = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} -) +MOCK_REFORM_POLICY_JSON = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) MOCK_BASELINE_POLICY_JSON = json.dumps({}) @@ -140,9 +138,7 @@ def mock_logger(): def mock_datetime(): """Mock datetime.datetime.now().""" mock_now = datetime.datetime(2025, 6, 26, 12, 0, 0) - with patch( - "policyengine_api.services.economy_service.datetime.datetime" - ) as mock: + with patch("policyengine_api.services.economy_service.datetime.datetime") as mock: mock.now.return_value = mock_now yield mock @@ -172,14 +168,11 @@ def create_mock_reform_impact( "options_hash": MOCK_OPTIONS_HASH, "status": status, "api_version": MOCK_API_VERSION, - "reform_impact_json": reform_impact_json - or json.dumps(MOCK_REFORM_IMPACT_DATA), + "reform_impact_json": reform_impact_json or json.dumps(MOCK_REFORM_IMPACT_DATA), "execution_id": execution_id, "start_time": datetime.datetime(2025, 6, 26, 12, 0, 0), "end_time": ( - datetime.datetime(2025, 6, 26, 12, 5, 0) - if status == "ok" - else None + datetime.datetime(2025, 6, 26, 12, 5, 0) if status == "ok" else None ), } @@ -251,9 +244,7 @@ def mock_simulation_api_modal(): MOCK_US_NATIONWIDE_DATASET = "gs://policyengine-us-data/cps_2023.h5" MOCK_US_STATE_CA_DATASET = "gs://policyengine-us-data/states/CA.h5" MOCK_US_STATE_UT_DATASET = "gs://policyengine-us-data/states/UT.h5" -MOCK_US_CITY_NYC_DATASET = ( - "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" -) +MOCK_US_CITY_NYC_DATASET = "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" MOCK_US_DISTRICT_CA37_DATASET = "gs://policyengine-us-data/districts/CA-37.h5" MOCK_UK_DATASET = "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" diff --git a/tests/fixtures/services/household_fixtures.py b/tests/fixtures/services/household_fixtures.py index f84d99c95..d68cad86a 100644 --- a/tests/fixtures/services/household_fixtures.py +++ b/tests/fixtures/services/household_fixtures.py @@ -23,9 +23,7 @@ @pytest.fixture def mock_hash_object(): """Mock the hash_object function.""" - with patch( - "policyengine_api.services.household_service.hash_object" - ) as mock: + with patch("policyengine_api.services.household_service.hash_object") as mock: mock.return_value = valid_hash_value yield mock diff --git a/tests/fixtures/services/policy_service.py b/tests/fixtures/services/policy_service.py index 18ee9071e..6c4a27f66 100644 --- a/tests/fixtures/services/policy_service.py +++ b/tests/fixtures/services/policy_service.py @@ -3,9 +3,7 @@ from unittest.mock import patch valid_policy_json = { - "data": { - "gov.irs.income.bracket.rates.2": {"2024-01-01.2024-12-31": 0.2433} - }, + "data": {"gov.irs.income.bracket.rates.2": {"2024-01-01.2024-12-31": 0.2433}}, } valid_hash_value = "NgJhpeuRVnIAwgYWuJsd2fI/N88rIE6Kcj8q4TPD/i4=" diff --git a/tests/integration/test_simulations.py b/tests/integration/test_simulations.py index 36056f239..37f8da106 100644 --- a/tests/integration/test_simulations.py +++ b/tests/integration/test_simulations.py @@ -40,13 +40,9 @@ def test__given_any_number_of_axes__sim_returns_valid_arrays( print("Variable name: ", variable_name) if variable_name in FORBIDDEN_VARIABLES: continue - for period in result[entity_type][entity_id][ - variable_name - ]: + for period in result[entity_type][entity_id][variable_name]: print("Period: ", period) - value = result[entity_type][entity_id][variable_name][ - period - ] + value = result[entity_type][entity_id][variable_name][period] print(f"Value: {value}") if isinstance(value, list): # Assert no Nones diff --git a/tests/to_refactor/api/test_api.py b/tests/to_refactor/api/test_api.py index 74f3e2bd6..f0855a6ec 100644 --- a/tests/to_refactor/api/test_api.py +++ b/tests/to_refactor/api/test_api.py @@ -23,9 +23,7 @@ def client(): # - expected_result: the expected result of the endpoint test_paths = [ - path - for path in (Path(__file__).parent).rglob("*") - if path.suffix == ".yaml" + path for path in (Path(__file__).parent).rglob("*") if path.suffix == ".yaml" ] test_data = [yaml.safe_load(path.read_text()) for path in test_paths] test_names = [test["name"] for test in test_data] @@ -70,6 +68,4 @@ def test_response(client, test: dict): json.loads(response.data), test.get("response", {}).get("data", {}) ) elif "html" in test.get("response", {}): - assert response.data.decode("utf-8") == test.get("response", {}).get( - "html", "" - ) + assert response.data.decode("utf-8") == test.get("response", {}).get("html", "") diff --git a/tests/to_refactor/fixtures/to_refactor_household_fixtures.py b/tests/to_refactor/fixtures/to_refactor_household_fixtures.py index 89b854f19..5fa6af91c 100644 --- a/tests/to_refactor/fixtures/to_refactor_household_fixtures.py +++ b/tests/to_refactor/fixtures/to_refactor_household_fixtures.py @@ -22,9 +22,7 @@ @pytest.fixture def mock_hash_object(): """Mock the hash_object function.""" - with patch( - "policyengine_api.services.household_service.hash_object" - ) as mock: + with patch("policyengine_api.services.household_service.hash_object") as mock: mock.return_value = valid_hash_value yield mock @@ -32,7 +30,5 @@ def mock_hash_object(): @pytest.fixture def mock_database(): """Mock the database module.""" - with patch( - "policyengine_api.services.household_service.database" - ) as mock_db: + with patch("policyengine_api.services.household_service.database") as mock_db: yield mock_db diff --git a/tests/to_refactor/python/test_ai_analysis_service_old.py b/tests/to_refactor/python/test_ai_analysis_service_old.py index aa8c825e3..0df3928ca 100644 --- a/tests/to_refactor/python/test_ai_analysis_service_old.py +++ b/tests/to_refactor/python/test_ai_analysis_service_old.py @@ -9,9 +9,7 @@ @patch("policyengine_api.services.ai_analysis_service.local_database") def test_get_existing_analysis_found(mock_db): - mock_db.query.return_value.fetchone.return_value = { - "analysis": "Existing analysis" - } + mock_db.query.return_value.fetchone.return_value = {"analysis": "Existing analysis"} prompt = "Test prompt" output = test_ai_service.get_existing_analysis(prompt) diff --git a/tests/to_refactor/python/test_household_routes.py b/tests/to_refactor/python/test_household_routes.py index e4ea05a1c..5b3ccb812 100644 --- a/tests/to_refactor/python/test_household_routes.py +++ b/tests/to_refactor/python/test_household_routes.py @@ -46,9 +46,7 @@ def test_get_household_invalid_id(self, rest_client): response = rest_client.get("/us/household/invalid") assert response.status_code == 404 - assert ( - b"The requested URL was not found on the server" in response.data - ) + assert b"The requested URL was not found on the server" in response.data class TestCreateHousehold: @@ -116,9 +114,7 @@ def test_update_household_success( mock_row.keys.return_value = valid_db_row.keys() mock_database.query().fetchone.return_value = mock_row - updated_household = { - "people": {"person1": {"age": 31, "income": 55000}} - } + updated_household = {"people": {"person1": {"age": 31, "income": 55000}}} updated_data = { "data": updated_household, @@ -182,9 +178,7 @@ def test_update_household_invalid_payload(self, rest_client): class TestHouseholdRouteServiceErrors: """Test handling of service-level errors in routes.""" - @patch( - "policyengine_api.services.household_service.HouseholdService.get_household" - ) + @patch("policyengine_api.services.household_service.HouseholdService.get_household") def test_get_household_service_error(self, mock_get, rest_client): """Test GET endpoint when service raises an error.""" mock_get.side_effect = Exception("Database connection failed") diff --git a/tests/to_refactor/python/test_policy_service_old.py b/tests/to_refactor/python/test_policy_service_old.py index a90680d80..832816f83 100644 --- a/tests/to_refactor/python/test_policy_service_old.py +++ b/tests/to_refactor/python/test_policy_service_old.py @@ -30,17 +30,13 @@ def policy_service(): class TestPolicyService: - a_test_policy_id = ( - 8 # Pre-seeded current law policies occupy IDs 1 through 5 - ) + a_test_policy_id = 8 # Pre-seeded current law policies occupy IDs 1 through 5 def test_get_policy_success( self, policy_service, mock_database, sample_policy_data ): # Setup mock - mock_database.query.return_value.fetchone.return_value = ( - sample_policy_data - ) + mock_database.query.return_value.fetchone.return_value = sample_policy_data # Test result = policy_service.get_policy("us", self.a_test_policy_id) @@ -64,9 +60,7 @@ def test_get_policy_not_found(self, policy_service, mock_database): assert result is None mock_database.query.assert_called_once() - def test_get_policy_json( - self, policy_service, mock_database, sample_policy_data - ): + def test_get_policy_json(self, policy_service, mock_database, sample_policy_data): # Setup mock mock_database.query.return_value.fetchone.return_value = { "policy_json": sample_policy_data["policy_json"] @@ -131,9 +125,7 @@ def test_set_policy_existing( self, policy_service, mock_database, sample_policy_data ): # Setup mock - mock_database.query.return_value.fetchone.return_value = ( - sample_policy_data - ) + mock_database.query.return_value.fetchone.return_value = sample_policy_data # Test policy_id, message, exists = policy_service.set_policy( @@ -152,9 +144,7 @@ def test_get_unique_policy_with_label( self, policy_service, mock_database, sample_policy_data ): # Setup mock - mock_database.query.return_value.fetchone.return_value = ( - sample_policy_data - ) + mock_database.query.return_value.fetchone.return_value = sample_policy_data # Test result = policy_service._get_unique_policy_with_label( @@ -167,16 +157,12 @@ def test_get_unique_policy_with_label( assert result == sample_policy_data mock_database.query.assert_called_once() - def test_get_unique_policy_with_null_label( - self, policy_service, mock_database - ): + def test_get_unique_policy_with_null_label(self, policy_service, mock_database): # Setup mock mock_database.query.return_value.fetchone.return_value = None # Test - result = policy_service._get_unique_policy_with_label( - "us", "hash123", None - ) + result = policy_service._get_unique_policy_with_label("us", "hash123", None) # Verify assert result is None @@ -207,8 +193,6 @@ def test_error_handling(self, policy_service, mock_database, error_method): elif error_method == "set_policy": policy_service.set_policy("us", "label", {}) else: - policy_service._get_unique_policy_with_label( - "us", "hash", "label" - ) + policy_service._get_unique_policy_with_label("us", "hash", "label") assert str(exc_info.value) == "Database error" diff --git a/tests/to_refactor/python/test_simulation_analysis_routes.py b/tests/to_refactor/python/test_simulation_analysis_routes.py index 0a4812e31..f1f2ab6f1 100644 --- a/tests/to_refactor/python/test_simulation_analysis_routes.py +++ b/tests/to_refactor/python/test_simulation_analysis_routes.py @@ -40,9 +40,7 @@ def test_execute_simulation_analysis_new_analysis(rest_client): ) as mock_trigger: mock_trigger.return_value = (s for s in ["New analysis"]) - response = rest_client.post( - "/us/simulation-analysis", json=test_json - ) + response = rest_client.post("/us/simulation-analysis", json=test_json) assert response.status_code == 200 assert b"New analysis" in response.data @@ -58,9 +56,7 @@ def test_execute_simulation_analysis_error(rest_client): ) as mock_trigger: mock_trigger.side_effect = Exception("Test error") - response = rest_client.post( - "/us/simulation-analysis", json=test_json - ) + response = rest_client.post("/us/simulation-analysis", json=test_json) assert response.status_code == 500 assert "Test error" in response.json.get("message") @@ -95,9 +91,7 @@ def test_execute_simulation_analysis_enhanced_cps(rest_client): with patch( "policyengine_api.services.ai_analysis_service.AIAnalysisService.trigger_ai_analysis" ) as mock_trigger: - mock_trigger.return_value = ( - s for s in ["Enhanced CPS analysis"] - ) + mock_trigger.return_value = (s for s in ["Enhanced CPS analysis"]) response = rest_client.post( "/us/simulation-analysis", json=test_json_enhanced_cps diff --git a/tests/to_refactor/python/test_tracer_analysis_routes.py b/tests/to_refactor/python/test_tracer_analysis_routes.py index f88805f8d..83f7bde23 100644 --- a/tests/to_refactor/python/test_tracer_analysis_routes.py +++ b/tests/to_refactor/python/test_tracer_analysis_routes.py @@ -58,8 +58,7 @@ def test_execute_tracer_analysis_no_tracer(mock_db, rest_client): assert response.status_code == 404 assert ( - "No household simulation tracer found" - in json.loads(response.data)["message"] + "No household simulation tracer found" in json.loads(response.data)["message"] ) @@ -115,9 +114,7 @@ def test_invalid_variable_types(mock_db, rest_client): }, ) assert response.status_code == 400 - assert ( - "variable must be a string" in json.loads(response.data)["message"] - ) + assert "variable must be a string" in json.loads(response.data)["message"] # Test invalid country @@ -218,7 +215,4 @@ def test_validate_tracer_analysis_payload_failure(rest_client): }, ) assert response.status_code == 400 - assert ( - "Missing required key: variable" - in json.loads(response.data)["message"] - ) + assert "Missing required key: variable" in json.loads(response.data)["message"] diff --git a/tests/to_refactor/python/test_us_policy_macro.py b/tests/to_refactor/python/test_us_policy_macro.py index 03cb620d2..9d6c20d82 100644 --- a/tests/to_refactor/python/test_us_policy_macro.py +++ b/tests/to_refactor/python/test_us_policy_macro.py @@ -72,13 +72,9 @@ def utah_reform_runner(rest_client, region: str = "us"): cost = round(result["budget"]["budgetary_impact"] / 1e6, 1) assert ( cost / 95.4 - 1 - ) < 0.01, ( - f"Expected budgetary impact to be 95.4 million, got {cost} million" - ) + ) < 0.01, f"Expected budgetary impact to be 95.4 million, got {cost} million" - assert ( - result["intra_decile"]["all"]["Lose less than 5%"] / 0.637 - 1 - ) < 0.01, ( + assert (result["intra_decile"]["all"]["Lose less than 5%"] / 0.637 - 1) < 0.01, ( f"Expected 63.7% of people to lose less than 5%, got " f"{result['intra_decile']['all']['Lose less than 5%']}" ) diff --git a/tests/to_refactor/python/test_user_profile_routes.py b/tests/to_refactor/python/test_user_profile_routes.py index a3d873dbb..ec30f9eef 100644 --- a/tests/to_refactor/python/test_user_profile_routes.py +++ b/tests/to_refactor/python/test_user_profile_routes.py @@ -42,9 +42,7 @@ def test_set_and_get_record(self, rest_client): assert res.status_code == 200 assert return_object["status"] == "ok" assert return_object["result"]["auth0_id"] == self.auth0_id - assert ( - return_object["result"]["primary_country"] == self.primary_country - ) + assert return_object["result"]["primary_country"] == self.primary_country assert return_object["result"]["username"] == None user_id = return_object["result"]["user_id"] @@ -54,9 +52,7 @@ def test_set_and_get_record(self, rest_client): assert res.status_code == 200 assert return_object["status"] == "ok" - assert ( - return_object["result"]["primary_country"] == self.primary_country - ) + assert return_object["result"]["primary_country"] == self.primary_country assert return_object["result"].get("auth0_id") is None assert return_object["result"]["username"] == None @@ -77,9 +73,7 @@ def test_set_and_get_record(self, rest_client): malicious_updated_profile = {**updated_profile, "auth0_id": "BOGUS"} - res = rest_client.put( - "/us/user-profile", json=malicious_updated_profile - ) + res = rest_client.put("/us/user-profile", json=malicious_updated_profile) return_object = json.loads(res.text) assert res.status_code == 200 @@ -99,9 +93,7 @@ def test_set_and_get_record(self, rest_client): def test_non_existent_record(self, rest_client): non_existent_auth0_id = "non-existent-auth0-id" - res = rest_client.get( - f"/us/user-profile?auth0_id={non_existent_auth0_id}" - ) + res = rest_client.get(f"/us/user-profile?auth0_id={non_existent_auth0_id}") return_object = json.loads(res.text) assert res.status_code == 404 diff --git a/tests/to_refactor/python/test_validate_household_payload.py b/tests/to_refactor/python/test_validate_household_payload.py index 42e6a0708..d45363d0d 100644 --- a/tests/to_refactor/python/test_validate_household_payload.py +++ b/tests/to_refactor/python/test_validate_household_payload.py @@ -14,9 +14,7 @@ class TestHouseholdRouteValidation: {"data": {}, "label": 123}, # Invalid label type ], ) - def test_post_household_invalid_payload( - self, rest_client, invalid_payload - ): + def test_post_household_invalid_payload(self, rest_client, invalid_payload): """Test POST endpoint with various invalid payloads.""" response = rest_client.post( "/us/household", @@ -40,9 +38,7 @@ def test_get_household_invalid_id(self, rest_client, invalid_id): # Default Werkzeug validation returns 404, not 400 assert response.status_code == 404 - assert ( - b"The requested URL was not found on the server" in response.data - ) + assert b"The requested URL was not found on the server" in response.data @pytest.mark.parametrize( "country_id", diff --git a/tests/to_refactor/python/test_yearly_var_removal.py b/tests/to_refactor/python/test_yearly_var_removal.py index e4f463e19..9e8294479 100644 --- a/tests/to_refactor/python/test_yearly_var_removal.py +++ b/tests/to_refactor/python/test_yearly_var_removal.py @@ -154,17 +154,14 @@ def interface_test_household_under_policy( # Skip ignored variables if ( variable in excluded_vars - or metadata["variables"][variable]["definitionPeriod"] - != "year" + or metadata["variables"][variable]["definitionPeriod"] != "year" ): continue # Ensure that the variable exists in both # result_object and test_object if variable not in metadata["variables"]: - print( - f"Failing due to variable {variable} not in metadata" - ) + print(f"Failing due to variable {variable} not in metadata") is_test_passing = False break @@ -188,14 +185,10 @@ def interface_test_household_under_policy( results_diff = result_var_set.difference(metadata_var_set) metadata_diff = metadata_var_set.difference(result_var_set) if len(results_diff) > 0: - print( - "Error: The following values are only present in the result object:" - ) + print("Error: The following values are only present in the result object:") print(results_diff) if len(metadata_diff) > 0: - print( - "Error: The following values are only present in the metadata:" - ) + print("Error: The following values are only present in the metadata:") print(metadata_diff) is_test_passing = False @@ -207,9 +200,7 @@ def test_us_household_under_policy(): Test that a US household under current law is created correctly """ - is_test_passing = interface_test_household_under_policy( - "us", "2", ["members"] - ) + is_test_passing = interface_test_household_under_policy("us", "2", ["members"]) assert is_test_passing == True @@ -285,17 +276,14 @@ def test_get_calculate(client): # Skip ignored variables if ( variable in excluded_vars - or metadata["variables"][variable]["definitionPeriod"] - != "year" + or metadata["variables"][variable]["definitionPeriod"] != "year" ): continue # Ensure that the variable exists in both # result_object and test_object if variable not in metadata["variables"]: - print( - f"Failing due to variable {variable} not in metadata" - ) + print(f"Failing due to variable {variable} not in metadata") is_test_passing = False break @@ -319,14 +307,10 @@ def test_get_calculate(client): results_diff = result_var_set.difference(metadata_var_set) metadata_diff = metadata_var_set.difference(result_var_set) if len(results_diff) > 0: - print( - "Error: The following values are only present in the result object:" - ) + print("Error: The following values are only present in the result object:") print(results_diff) if len(metadata_diff) > 0: - print( - "Error: The following values are only present in the metadata:" - ) + print("Error: The following values are only present in the metadata:") print(metadata_diff) is_test_passing = False diff --git a/tests/unit/ai_prompts/test_simulation_analysis_prompt.py b/tests/unit/ai_prompts/test_simulation_analysis_prompt.py index 05f1931e7..429a9ed10 100644 --- a/tests/unit/ai_prompts/test_simulation_analysis_prompt.py +++ b/tests/unit/ai_prompts/test_simulation_analysis_prompt.py @@ -29,13 +29,11 @@ def test_given_valid_uk_input(self, snapshot): def test_given_dataset_is_enhanced_cps(self, snapshot): snapshot.snapshot_dir = "tests/snapshots" - valid_enhanced_cps_input_data = ( - given_valid_data_and_dataset_is_enhanced_cps(valid_input_us) + valid_enhanced_cps_input_data = given_valid_data_and_dataset_is_enhanced_cps( + valid_input_us ) - prompt = generate_simulation_analysis_prompt( - valid_enhanced_cps_input_data - ) + prompt = generate_simulation_analysis_prompt(valid_enhanced_cps_input_data) snapshot.assert_match( prompt, "simulation_analysis_prompt_dataset_enhanced_cps.txt" ) @@ -46,6 +44,4 @@ def test_given_missing_input_field(self): Exception, match="1 validation error for InboundParameters\ntime_period\n Field required", ): - generate_simulation_analysis_prompt( - invalid_data_missing_input_field - ) + generate_simulation_analysis_prompt(invalid_data_missing_input_field) diff --git a/tests/unit/data/test_congressional_districts.py b/tests/unit/data/test_congressional_districts.py index 05819916a..255cfd4dc 100644 --- a/tests/unit/data/test_congressional_districts.py +++ b/tests/unit/data/test_congressional_districts.py @@ -78,15 +78,11 @@ def test__all_state_codes_are_in_state_code_to_name(self): assert district.state_code in STATE_CODE_TO_NAME def test__california_has_52_districts(self): - ca_districts = [ - d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "CA" - ] + ca_districts = [d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "CA"] assert len(ca_districts) == 52 def test__texas_has_38_districts(self): - tx_districts = [ - d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "TX" - ] + tx_districts = [d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "TX"] assert len(tx_districts) == 38 def test__at_large_states_have_1_district(self): @@ -94,31 +90,23 @@ def test__at_large_states_have_1_district(self): at_large_states = [s for s in AT_LARGE_STATES if s != "DC"] for state_code in at_large_states: state_districts = [ - d - for d in CONGRESSIONAL_DISTRICTS - if d.state_code == state_code + d for d in CONGRESSIONAL_DISTRICTS if d.state_code == state_code ] assert len(state_districts) == 1 assert state_districts[0].number == 1 def test__dc_has_1_district(self): - dc_districts = [ - d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "DC" - ] + dc_districts = [d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "DC"] assert len(dc_districts) == 1 assert dc_districts[0].number == 1 def test__dc_comes_after_delaware(self): # Find indices de_indices = [ - i - for i, d in enumerate(CONGRESSIONAL_DISTRICTS) - if d.state_code == "DE" + i for i, d in enumerate(CONGRESSIONAL_DISTRICTS) if d.state_code == "DE" ] dc_indices = [ - i - for i, d in enumerate(CONGRESSIONAL_DISTRICTS) - if d.state_code == "DC" + i for i, d in enumerate(CONGRESSIONAL_DISTRICTS) if d.state_code == "DC" ] # DC should come after all DE districts assert min(dc_indices) > max(de_indices) @@ -144,36 +132,27 @@ def test__name_has_correct_format(self): metadata = build_congressional_district_metadata() # Check first California district ca_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-01" + item for item in metadata if item["name"] == "congressional_district/CA-01" ) assert ca_01 is not None def test__label_has_correct_format(self): metadata = build_congressional_district_metadata() ca_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-01" + item for item in metadata if item["name"] == "congressional_district/CA-01" ) assert ca_01["label"] == "California's 1st congressional district" def test__state_abbreviation_is_uppercase(self): metadata = build_congressional_district_metadata() for item in metadata: - assert ( - item["state_abbreviation"] - == item["state_abbreviation"].upper() - ) + assert item["state_abbreviation"] == item["state_abbreviation"].upper() assert len(item["state_abbreviation"]) == 2 def test__state_name_matches_abbreviation(self): metadata = build_congressional_district_metadata() ca_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-01" + item for item in metadata if item["name"] == "congressional_district/CA-01" ) assert ca_01["state_abbreviation"] == "CA" assert ca_01["state_name"] == "California" @@ -181,9 +160,7 @@ def test__state_name_matches_abbreviation(self): def test__dc_state_fields(self): metadata = build_congressional_district_metadata() dc_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/DC-01" + item for item in metadata if item["name"] == "congressional_district/DC-01" ) assert dc_01["state_abbreviation"] == "DC" assert dc_01["state_name"] == "District of Columbia" @@ -198,39 +175,25 @@ def test__ordinal_suffixes_are_correct(self): # Find specific districts to test ordinal suffixes ca_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-01" + item for item in metadata if item["name"] == "congressional_district/CA-01" ) ca_02 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-02" + item for item in metadata if item["name"] == "congressional_district/CA-02" ) ca_03 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-03" + item for item in metadata if item["name"] == "congressional_district/CA-03" ) ca_11 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-11" + item for item in metadata if item["name"] == "congressional_district/CA-11" ) ca_12 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-12" + item for item in metadata if item["name"] == "congressional_district/CA-12" ) ca_21 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-21" + item for item in metadata if item["name"] == "congressional_district/CA-21" ) ca_22 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-22" + item for item in metadata if item["name"] == "congressional_district/CA-22" ) assert "1st" in ca_01["label"] @@ -245,17 +208,13 @@ def test__district_numbers_have_leading_zeros(self): metadata = build_congressional_district_metadata() # Single digit districts should have leading zero ca_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-01" + item for item in metadata if item["name"] == "congressional_district/CA-01" ) assert ca_01["name"] == "congressional_district/CA-01" # Double digit districts should not have leading zero ca_37 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-37" + item for item in metadata if item["name"] == "congressional_district/CA-37" ) assert ca_37["name"] == "congressional_district/CA-37" @@ -275,18 +234,14 @@ def test__at_large_states_have_at_large_label(self): def test__alaska_at_large_label(self): metadata = build_congressional_district_metadata() ak_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/AK-01" + item for item in metadata if item["name"] == "congressional_district/AK-01" ) assert ak_01["label"] == "Alaska's at-large congressional district" def test__wyoming_at_large_label(self): metadata = build_congressional_district_metadata() wy_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/WY-01" + item for item in metadata if item["name"] == "congressional_district/WY-01" ) assert wy_01["label"] == "Wyoming's at-large congressional district" diff --git a/tests/unit/endpoints/economy/test_compare.py b/tests/unit/endpoints/economy/test_compare.py index 17ff66275..8ef1eaec2 100644 --- a/tests/unit/endpoints/economy/test_compare.py +++ b/tests/unit/endpoints/economy/test_compare.py @@ -118,9 +118,7 @@ def test__given_non_uk_country_canada__returns_none(self): result = uk_local_authority_breakdown({}, {}, "ca") assert result is None - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_uk_country__returns_breakdown( @@ -135,9 +133,7 @@ def test__given_uk_country__returns_breakdown( # Create mock weights - 3 local authorities, 10 households mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -229,9 +225,7 @@ def test__outcome_bucket_categorization_logic(self): bucket == expected_bucket ), f"Failed for {percent_change}: expected {expected_bucket}, got {bucket}" - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__outcome_buckets_are_correct( @@ -244,9 +238,7 @@ def test__outcome_buckets_are_correct( mock_weights = np.ones((1, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -269,9 +261,7 @@ def test__outcome_buckets_are_correct( assert result.outcomes_by_region["uk"]["Gain more than 5%"] == 1 assert result.outcomes_by_region["uk"]["Gain less than 5%"] == 0 - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__downloads_from_correct_repos( @@ -284,9 +274,7 @@ def test__downloads_from_correct_repos( mock_weights = np.ones((1, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -307,32 +295,22 @@ def test__downloads_from_correct_repos( # Verify correct repos are used calls = mock_download.call_args_list - assert ( - calls[0][1]["repo"] == "policyengine/policyengine-uk-data-private" - ) + assert calls[0][1]["repo"] == "policyengine/policyengine-uk-data-private" assert calls[0][1]["repo_filename"] == "local_authority_weights.h5" - assert ( - calls[1][1]["repo"] == "policyengine/policyengine-uk-data-public" - ) + assert calls[1][1]["repo"] == "policyengine/policyengine-uk-data-public" assert calls[1][1]["repo_filename"] == "local_authorities_2021.csv" def test__given_constituency_region__returns_none(self): """When simulating a constituency, local authority breakdown should not be computed.""" - result = uk_local_authority_breakdown( - {}, {}, "uk", "constituency/Aldershot" - ) + result = uk_local_authority_breakdown({}, {}, "uk", "constituency/Aldershot") assert result is None def test__given_constituency_region_with_code__returns_none(self): """When simulating a constituency by code, local authority breakdown should not be computed.""" - result = uk_local_authority_breakdown( - {}, {}, "uk", "constituency/E12345678" - ) + result = uk_local_authority_breakdown({}, {}, "uk", "constituency/E12345678") assert result is None - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_specific_la_region__returns_only_that_la( @@ -346,9 +324,7 @@ def test__given_specific_la_region__returns_only_that_la( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -375,9 +351,7 @@ def test__given_specific_la_region__returns_only_that_la( assert "Aberdeen City" not in result.by_local_authority assert "Isle of Anglesey" not in result.by_local_authority - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_country_scotland_region__returns_only_scottish_las( @@ -391,9 +365,7 @@ def test__given_country_scotland_region__returns_only_scottish_las( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -420,9 +392,7 @@ def test__given_country_scotland_region__returns_only_scottish_las( assert "Hartlepool" not in result.by_local_authority assert "Isle of Anglesey" not in result.by_local_authority - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_uk_region__returns_all_las( @@ -436,9 +406,7 @@ def test__given_uk_region__returns_all_las( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -463,9 +431,7 @@ def test__given_uk_region__returns_all_las( assert "Aberdeen City" in result.by_local_authority assert "Isle of Anglesey" in result.by_local_authority - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_no_region__returns_all_las( @@ -479,9 +445,7 @@ def test__given_no_region__returns_all_las( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -545,21 +509,15 @@ def test__given_non_uk_country_nigeria__returns_none(self): def test__given_local_authority_region__returns_none(self): """When simulating a local authority, constituency breakdown should not be computed.""" - result = uk_constituency_breakdown( - {}, {}, "uk", "local_authority/Leicester" - ) + result = uk_constituency_breakdown({}, {}, "uk", "local_authority/Leicester") assert result is None def test__given_local_authority_region_with_code__returns_none(self): """When simulating a local authority by code, constituency breakdown should not be computed.""" - result = uk_constituency_breakdown( - {}, {}, "uk", "local_authority/E06000016" - ) + result = uk_constituency_breakdown({}, {}, "uk", "local_authority/E06000016") assert result is None - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_specific_constituency_region__returns_only_that_constituency( @@ -574,9 +532,7 @@ def test__given_specific_constituency_region__returns_only_that_constituency( # Create mock weights - 3 constituencies, 10 households mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -604,9 +560,7 @@ def test__given_specific_constituency_region__returns_only_that_constituency( assert "Edinburgh East" not in result.by_constituency assert "Cardiff South" not in result.by_constituency - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_country_scotland_region__returns_only_scottish_constituencies( @@ -620,9 +574,7 @@ def test__given_country_scotland_region__returns_only_scottish_constituencies( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -639,9 +591,7 @@ def test__given_country_scotland_region__returns_only_scottish_constituencies( baseline = {"household_net_income": np.array([1000.0] * 10)} reform = {"household_net_income": np.array([1050.0] * 10)} - result = uk_constituency_breakdown( - baseline, reform, "uk", "country/scotland" - ) + result = uk_constituency_breakdown(baseline, reform, "uk", "country/scotland") assert result is not None assert len(result.by_constituency) == 1 @@ -649,9 +599,7 @@ def test__given_country_scotland_region__returns_only_scottish_constituencies( assert "Aldershot" not in result.by_constituency assert "Cardiff South" not in result.by_constituency - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_uk_region__returns_all_constituencies( @@ -665,9 +613,7 @@ def test__given_uk_region__returns_all_constituencies( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -692,9 +638,7 @@ def test__given_uk_region__returns_all_constituencies( assert "Edinburgh East" in result.by_constituency assert "Cardiff South" in result.by_constituency - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_no_region__returns_all_constituencies( @@ -708,9 +652,7 @@ def test__given_no_region__returns_all_constituencies( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context diff --git a/tests/unit/libs/test_simulation_api_factory.py b/tests/unit/libs/test_simulation_api_factory.py index 6602c47b4..43d5ea339 100644 --- a/tests/unit/libs/test_simulation_api_factory.py +++ b/tests/unit/libs/test_simulation_api_factory.py @@ -120,36 +120,29 @@ def test__given_use_modal_env_false__then_returns_gcp_api( # Then assert isinstance(api, SimulationAPI) - def test__given_use_modal_env_not_set__then_returns_gcp_api( + def test__given_use_modal_env_not_set__then_returns_modal_api( self, mock_factory_logger, ): - # Given + # Given - default is now Modal when env var is not set import os env_copy = dict(os.environ) env_copy.pop("USE_MODAL_SIMULATION_API", None) - env_copy["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/creds.json" with patch.dict("os.environ", env_copy, clear=True): - with patch( - "policyengine_api.libs.simulation_api.executions_v1.ExecutionsClient" - ): - with patch( - "policyengine_api.libs.simulation_api.workflows_v1.WorkflowsClient" - ): - from policyengine_api.libs.simulation_api_factory import ( - get_simulation_api, - ) - from policyengine_api.libs.simulation_api import ( - SimulationAPI, - ) + from policyengine_api.libs.simulation_api_factory import ( + get_simulation_api, + ) + from policyengine_api.libs.simulation_api_modal import ( + SimulationAPIModal, + ) - # When - api = get_simulation_api() + # When + api = get_simulation_api() - # Then - assert isinstance(api, SimulationAPI) + # Then + assert isinstance(api, SimulationAPIModal) def test__given_use_modal_env_false__then_logs_gcp_selection( self, @@ -178,9 +171,7 @@ def test__given_use_modal_env_false__then_logs_gcp_selection( # Then mock_factory_logger.log_struct.assert_called() - call_args = mock_factory_logger.log_struct.call_args[ - 0 - ][0] + call_args = mock_factory_logger.log_struct.call_args[0][0] assert "GCP" in call_args["message"] class TestGCPCredentialsError: @@ -189,11 +180,11 @@ def test__given_gcp_selected_without_credentials__then_raises_error( self, mock_factory_logger, ): - # Given + # Given - explicitly select GCP without credentials import os env_copy = dict(os.environ) - env_copy.pop("USE_MODAL_SIMULATION_API", None) + env_copy["USE_MODAL_SIMULATION_API"] = "false" env_copy.pop("GOOGLE_APPLICATION_CREDENTIALS", None) with patch.dict("os.environ", env_copy, clear=True): diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index 4ba7d0616..25704e63a 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -93,9 +93,7 @@ class TestSimulationAPIModal: class TestInit: - def test__given_env_var_set__then_uses_env_url( - self, mock_httpx_client - ): + def test__given_env_var_set__then_uses_env_url(self, mock_httpx_client): # Given with patch.dict( "os.environ", @@ -107,9 +105,7 @@ def test__given_env_var_set__then_uses_env_url( # Then assert api.base_url == MOCK_MODAL_BASE_URL - def test__given_env_var_not_set__then_uses_default_url( - self, mock_httpx_client - ): + def test__given_env_var_not_set__then_uses_default_url(self, mock_httpx_client): # Given with patch.dict("os.environ", {}, clear=False): import os @@ -188,9 +184,7 @@ def test__given_network_error__then_raises_exception( mock_modal_logger, ): # Given - mock_httpx_client.post.side_effect = httpx.RequestError( - "Connection failed" - ) + mock_httpx_client.post.side_effect = httpx.RequestError("Connection failed") api = SimulationAPIModal() # When/Then @@ -278,9 +272,7 @@ def test__given_job_id__then_polls_correct_endpoint( class TestGetExecutionId: - def test__given_execution__then_returns_job_id( - self, mock_httpx_client - ): + def test__given_execution__then_returns_job_id(self, mock_httpx_client): # Given api = SimulationAPIModal() execution = ModalSimulationExecution( @@ -296,9 +288,7 @@ def test__given_execution__then_returns_job_id( class TestGetExecutionStatus: - def test__given_execution__then_returns_status_string( - self, mock_httpx_client - ): + def test__given_execution__then_returns_status_string(self, mock_httpx_client): # Given api = SimulationAPIModal() execution = ModalSimulationExecution( @@ -386,9 +376,7 @@ def test__given_network_error__then_returns_false( self, mock_httpx_client, mock_modal_logger ): # Given - mock_httpx_client.get.side_effect = httpx.RequestError( - "Connection failed" - ) + mock_httpx_client.get.side_effect = httpx.RequestError("Connection failed") api = SimulationAPIModal() # When diff --git a/tests/unit/services/test_ai_analysis_service.py b/tests/unit/services/test_ai_analysis_service.py index 34810cc2b..2ff182b5c 100644 --- a/tests/unit/services/test_ai_analysis_service.py +++ b/tests/unit/services/test_ai_analysis_service.py @@ -33,8 +33,7 @@ def test_trigger_ai_analysis_given_successful_streaming( for i, chunk in enumerate(results): if i < len(text_chunks): expected_chunk = ( - json.dumps({"type": "text", "stream": text_chunks[i][:5]}) - + "\n" + json.dumps({"type": "text", "stream": text_chunks[i][:5]}) + "\n" ) assert chunk == expected_chunk diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index ba4a4e586..30be14333 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -175,9 +175,7 @@ def test__given_no_previous_impact__creates_new_simulation( mock_datetime, mock_numpy_random, ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = ( - [] - ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] result = economy_service.get_economic_impact(**base_params) @@ -199,8 +197,8 @@ def test__given_exception__raises_error( mock_datetime, mock_numpy_random, ): - mock_reform_impacts_service.get_all_reform_impacts.side_effect = ( - Exception("Database error") + mock_reform_impacts_service.get_all_reform_impacts.side_effect = Exception( + "Database error" ) with pytest.raises(Exception) as exc_info: @@ -273,9 +271,7 @@ def test__given_existing_impacts__returns_first_impact( create_mock_reform_impact(), create_mock_reform_impact(), ] - mock_reform_impacts_service.get_all_reform_impacts.return_value = ( - impacts - ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = impacts result = economy_service._get_most_recent_impact(setup_options) @@ -285,9 +281,7 @@ def test__given_no_impacts__returns_none( self, economy_service, setup_options, mock_reform_impacts_service ): # Arrange - mock_reform_impacts_service.get_all_reform_impacts.return_value = ( - [] - ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] # Act result = economy_service._get_most_recent_impact(setup_options) @@ -320,9 +314,7 @@ def test__given_error_status__returns_completed(self, economy_service): assert result == ImpactAction.COMPLETED - def test__given_computing_status__returns_computing( - self, economy_service - ): + def test__given_computing_status__returns_computing(self, economy_service): impact = create_mock_reform_impact(status="computing") result = economy_service._determine_impact_action(impact) @@ -418,9 +410,7 @@ def test__given_unknown_state__raises_error( economy_service._handle_execution_state( setup_options, "UNKNOWN", reform_impact ) - assert "Unexpected sim API execution state: UNKNOWN" in str( - exc_info.value - ) + assert "Unexpected sim API execution state: UNKNOWN" in str(exc_info.value) # Modal status tests def test__given_modal_complete_state__then_returns_completed_result( @@ -490,9 +480,7 @@ def test__given_modal_failed_state_with_error_message__then_includes_error_in_me # Then assert result.status == ImpactStatus.ERROR # Verify the error message was passed to the service - call_args = ( - mock_reform_impacts_service.set_error_reform_impact.call_args - ) + call_args = mock_reform_impacts_service.set_error_reform_impact.call_args assert "Simulation timed out" in call_args[1]["message"] def test__given_modal_running_state__then_returns_computing_result( @@ -632,9 +620,7 @@ class TestSetupSimOptions: """ test_country_id = "us" - test_reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + test_reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) test_current_law_baseline_policy = json.dumps({}) test_region = "us" test_time_period = 2025 @@ -662,16 +648,12 @@ def test__given_us_nationwide__returns_correct_sim_options(self): ) assert sim_options["time_period"] == self.test_time_period assert sim_options["region"] == "us" - assert ( - sim_options["data"] == "gs://policyengine-us-data/cps_2023.h5" - ) + assert sim_options["data"] == "gs://policyengine-us-data/cps_2023.h5" def test__given_us_state_ca__returns_correct_sim_options(self): # Test with a normalized US state (prefixed format) country_id = "us" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "state/ca" # Pre-normalized time_period = 2025 @@ -691,21 +673,15 @@ def test__given_us_state_ca__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads( - current_law_baseline_policy - ) + assert sim_options["baseline"] == json.loads(current_law_baseline_policy) assert sim_options["time_period"] == time_period assert sim_options["region"] == "state/ca" - assert ( - sim_options["data"] == "gs://policyengine-us-data/states/CA.h5" - ) + assert sim_options["data"] == "gs://policyengine-us-data/states/CA.h5" def test__given_us_state_utah__returns_correct_sim_options(self): # Test with normalized Utah state country_id = "us" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "state/ut" # Pre-normalized time_period = 2025 @@ -725,20 +701,14 @@ def test__given_us_state_utah__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads( - current_law_baseline_policy - ) + assert sim_options["baseline"] == json.loads(current_law_baseline_policy) assert sim_options["time_period"] == time_period assert sim_options["region"] == "state/ut" - assert ( - sim_options["data"] == "gs://policyengine-us-data/states/UT.h5" - ) + assert sim_options["data"] == "gs://policyengine-us-data/states/UT.h5" def test__given_cliff_target__returns_correct_sim_options(self): country_id = "us" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "us" time_period = 2025 @@ -760,21 +730,15 @@ def test__given_cliff_target__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads( - current_law_baseline_policy - ) + assert sim_options["baseline"] == json.loads(current_law_baseline_policy) assert sim_options["time_period"] == time_period assert sim_options["region"] == region - assert ( - sim_options["data"] == "gs://policyengine-us-data/cps_2023.h5" - ) + assert sim_options["data"] == "gs://policyengine-us-data/cps_2023.h5" assert sim_options["include_cliffs"] is True def test__given_uk__returns_correct_sim_options(self): country_id = "uk" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "uk" time_period = 2025 @@ -803,9 +767,7 @@ def test__given_congressional_district__returns_correct_sim_options( self, ): country_id = "us" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "congressional_district/CA-37" # Pre-normalized time_period = 2025 @@ -824,10 +786,7 @@ def test__given_congressional_district__returns_correct_sim_options( sim_options = sim_options_model.model_dump() assert sim_options["region"] == "congressional_district/CA-37" - assert ( - sim_options["data"] - == "gs://policyengine-us-data/districts/CA-37.h5" - ) + assert sim_options["data"] == "gs://policyengine-us-data/districts/CA-37.h5" class TestSetupRegion: """Tests for _setup_region method. @@ -860,18 +819,14 @@ def test__given_prefixed_state_tx__returns_unchanged(self): def test__given_congressional_district__returns_unchanged(self): service = EconomyService() - result = service._setup_region( - "us", "congressional_district/CA-37" - ) + result = service._setup_region("us", "congressional_district/CA-37") assert result == "congressional_district/CA-37" def test__given_lowercase_congressional_district__returns_unchanged( self, ): service = EconomyService() - result = service._setup_region( - "us", "congressional_district/ca-37" - ) + result = service._setup_region("us", "congressional_district/ca-37") assert result == "congressional_district/ca-37" def test__given_invalid_prefixed_state__raises_value_error(self): @@ -886,17 +841,13 @@ def test__given_invalid_congressional_district__raises_value_error( service = EconomyService() with pytest.raises(ValueError) as exc_info: service._setup_region("us", "congressional_district/cruft") - assert "Invalid congressional district: 'cruft'" in str( - exc_info.value - ) + assert "Invalid congressional district: 'cruft'" in str(exc_info.value) def test__given_invalid_prefix__raises_value_error(self): service = EconomyService() with pytest.raises(ValueError) as exc_info: service._setup_region("us", "invalid_prefix/tx") - assert "Invalid US region: 'invalid_prefix/tx'" in str( - exc_info.value - ) + assert "Invalid US region: 'invalid_prefix/tx'" in str(exc_info.value) def test__given_invalid_bare_value__raises_value_error(self): # Bare values without prefix are now invalid (should be normalized first) @@ -922,9 +873,7 @@ def test__given_us_city_nyc__returns_pooled_cps(self): # Test with normalized city/nyc format service = EconomyService() result = service._setup_data("us", "city/nyc") - assert ( - result == "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" - ) + assert result == "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" def test__given_us_state_ca__returns_state_dataset(self): # Test with US state - returns state-specific dataset @@ -954,10 +903,7 @@ def test__given_uk__returns_efrs_dataset(self): # Test with UK - returns enhanced FRS dataset service = EconomyService() result = service._setup_data("uk", "uk") - assert ( - result - == "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" - ) + assert result == "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" def test__given_invalid_country__raises_value_error(self, mock_logger): # Test with invalid country @@ -1001,14 +947,10 @@ def test__given_invalid_congressional_district__raises_value_error( service = EconomyService() with pytest.raises(ValueError) as exc_info: service._validate_us_region("congressional_district/CA-99") - assert "Invalid congressional district: 'CA-99'" in str( - exc_info.value - ) + assert "Invalid congressional district: 'CA-99'" in str(exc_info.value) def test__given_nonexistent_district__raises_value_error(self): service = EconomyService() with pytest.raises(ValueError) as exc_info: service._validate_us_region("congressional_district/cruft") - assert "Invalid congressional district: 'cruft'" in str( - exc_info.value - ) + assert "Invalid congressional district: 'cruft'" in str(exc_info.value) diff --git a/tests/unit/services/test_household_service.py b/tests/unit/services/test_household_service.py index 9a3ccad6d..a67abfdb2 100644 --- a/tests/unit/services/test_household_service.py +++ b/tests/unit/services/test_household_service.py @@ -27,9 +27,7 @@ def test_get_household_given_existing_record( # GIVEN an existing record... (included as fixture) # WHEN we call get_household for this record... - result = service.get_household( - valid_db_row["country_id"], valid_db_row["id"] - ) + result = service.get_household(valid_db_row["country_id"], valid_db_row["id"]) valid_household_json = valid_request_body["data"] diff --git a/tests/unit/services/test_metadata_service.py b/tests/unit/services/test_metadata_service.py index 70ea9262e..42c266399 100644 --- a/tests/unit/services/test_metadata_service.py +++ b/tests/unit/services/test_metadata_service.py @@ -127,9 +127,7 @@ def test_verify_metadata_for_given_country( ("us", ["national", "state", "city", "congressional_district"]), ], ) - def test_verify_region_types_for_given_country( - self, country_id, expected_types - ): + def test_verify_region_types_for_given_country(self, country_id, expected_types): """ Verifies that all regions for UK and US have a 'type' field with valid values. @@ -139,9 +137,7 @@ def test_verify_region_types_for_given_country( regions = metadata["economy_options"]["region"] for region in regions: - assert ( - "type" in region - ), f"Region '{region['name']}' missing 'type' field" + assert "type" in region, f"Region '{region['name']}' missing 'type' field" assert ( region["type"] in expected_types ), f"Region '{region['name']}' has invalid type '{region['type']}'" diff --git a/tests/unit/services/test_policy_service.py b/tests/unit/services/test_policy_service.py index 4530dd9d5..b93814fca 100644 --- a/tests/unit/services/test_policy_service.py +++ b/tests/unit/services/test_policy_service.py @@ -16,9 +16,7 @@ class TestGetPolicy: - def test_get_policy_given_existing_record( - self, test_db, existing_policy_record - ): + def test_get_policy_given_existing_record(self, test_db, existing_policy_record): # GIVEN an existing record... (included as fixture) # WHEN we call get_policy for this record... @@ -43,9 +41,7 @@ def test_get_policy_given_nonexistent_record(self, test_db): # WHEN we call get_policy for a nonexistent record NO_SUCH_RECORD_ID = 999 - result = service.get_policy( - valid_policy_data["country_id"], NO_SUCH_RECORD_ID - ) + result = service.get_policy(valid_policy_data["country_id"], NO_SUCH_RECORD_ID) # THEN the result should be None assert result is None @@ -60,9 +56,7 @@ def test_get_policy_given_str_id(self): ): # WHEN we call get_policy with the invalid ID # THEN an exception should be raised - service.get_policy( - valid_policy_data["country_id"], INVALID_RECORD_ID - ) + service.get_policy(valid_policy_data["country_id"], INVALID_RECORD_ID) def test_get_policy_given_negative_int_id(self): # GIVEN an invalid ID @@ -74,18 +68,14 @@ def test_get_policy_given_negative_int_id(self): ): # WHEN we call get_policy with the invalid ID # THEN an exception should be raised - service.get_policy( - valid_policy_data["country_id"], INVALID_RECORD_ID - ) + service.get_policy(valid_policy_data["country_id"], INVALID_RECORD_ID) def test_get_policy_given_invalid_country_id(self): # GIVEN an invalid country_id INVALID_COUNTRY_ID = "xx" # Unsupported country code # WHEN we call get_policy with the invalid country_id - result = service.get_policy( - INVALID_COUNTRY_ID, valid_policy_data["id"] - ) + result = service.get_policy(INVALID_COUNTRY_ID, valid_policy_data["id"]) # THEN the result should be None or raise an exception assert result is None @@ -236,9 +226,7 @@ def test_set_policy_existing( existing_policy = existing_policy_record # Setup mock - mock_database.query.return_value.fetchone.return_value = ( - existing_policy - ) + mock_database.query.return_value.fetchone.return_value = existing_policy # Define expected database calls - matches actual implementation expected_calls = [ @@ -277,9 +265,7 @@ def test_set_policy_given_database_insert_failure( # Setup mock to raise exception on insert mock_database.query.return_value.fetchone.side_effect = [ None, # First call: policy does not exist - Exception( - "Database insertion failed" - ), # Second call: insertion fails + Exception("Database insertion failed"), # Second call: insertion fails ] # WHEN we call set_policy @@ -300,9 +286,7 @@ def test_set_policy_given_invalid_country_id(self, mock_hash_object): # THEN an exception should be raised service.set_policy(INVALID_COUNTRY_ID, test_label, test_policy) - def test_set_policy_given_empty_label( - self, mock_database, mock_hash_object - ): + def test_set_policy_given_empty_label(self, mock_database, mock_hash_object): # GIVEN an empty label EMPTY_LABEL = "" test_policy = {"param": "value"} diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index 15f6b8576..c1f6b3e55 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -13,9 +13,7 @@ class TestFindExistingReportOutput: """Test finding existing report outputs in the database.""" - def test_find_existing_report_output_found( - self, test_db, existing_report_record - ): + def test_find_existing_report_output_found(self, test_db, existing_report_record): """Test finding an existing report output.""" # GIVEN an existing report record (from fixture) @@ -29,10 +27,7 @@ def test_find_existing_report_output_found( # THEN the result should contain the existing report assert result is not None assert result["id"] == existing_report_record["id"] - assert ( - result["simulation_1_id"] - == existing_report_record["simulation_1_id"] - ) + assert result["simulation_1_id"] == existing_report_record["simulation_1_id"] assert result["status"] == existing_report_record["status"] def test_find_existing_report_output_not_found(self, test_db): @@ -248,10 +243,7 @@ def test_get_report_output_existing(self, test_db, existing_report_record): # THEN the correct report should be returned assert result is not None assert result["id"] == existing_report_record["id"] - assert ( - result["simulation_1_id"] - == existing_report_record["simulation_1_id"] - ) + assert result["simulation_1_id"] == existing_report_record["simulation_1_id"] assert result["status"] == existing_report_record["status"] def test_get_report_output_nonexistent(self, test_db): @@ -335,21 +327,15 @@ def test_duplicate_report_returns_existing(self, test_db): # THEN the same report should be returned (no duplicate created) assert first_report["id"] == second_report["id"] assert first_report["country_id"] == second_report["country_id"] - assert ( - first_report["simulation_1_id"] == second_report["simulation_1_id"] - ) - assert ( - first_report["simulation_2_id"] == second_report["simulation_2_id"] - ) + assert first_report["simulation_1_id"] == second_report["simulation_1_id"] + assert first_report["simulation_2_id"] == second_report["simulation_2_id"] assert first_report["year"] == second_report["year"] class TestUpdateReportOutput: """Test updating report outputs in the database.""" - def test_update_report_output_to_complete( - self, test_db, existing_report_record - ): + def test_update_report_output_to_complete(self, test_db, existing_report_record): """Test updating a report to complete status with output.""" # GIVEN an existing pending report report_id = existing_report_record["id"] @@ -374,9 +360,7 @@ def test_update_report_output_to_complete( assert result["status"] == "complete" assert result["output"] == test_output_json - def test_update_report_output_to_error( - self, test_db, existing_report_record - ): + def test_update_report_output_to_error(self, test_db, existing_report_record): """Test updating a report to error status with message.""" # GIVEN an existing pending report report_id = existing_report_record["id"] @@ -400,9 +384,7 @@ def test_update_report_output_to_error( assert result["status"] == "error" assert result["error_message"] == error_msg - def test_update_report_output_partial_update( - self, test_db, existing_report_record - ): + def test_update_report_output_partial_update(self, test_db, existing_report_record): """Test that partial updates work correctly.""" # GIVEN an existing report report_id = existing_report_record["id"] @@ -424,9 +406,7 @@ def test_update_report_output_partial_update( assert result["status"] == "complete" assert result["output"] is None # Should remain unchanged - def test_update_report_output_no_fields( - self, test_db, existing_report_record - ): + def test_update_report_output_no_fields(self, test_db, existing_report_record): """Test that update with no optional fields still updates API version.""" # GIVEN an existing report diff --git a/tests/unit/services/test_simulation_service.py b/tests/unit/services/test_simulation_service.py index 49c8654a3..ac1fbccf6 100644 --- a/tests/unit/services/test_simulation_service.py +++ b/tests/unit/services/test_simulation_service.py @@ -31,9 +31,7 @@ def test_find_existing_simulation_given_existing_record( assert result is not None assert result["id"] == existing_simulation_record["id"] assert result["country_id"] == valid_simulation_data["country_id"] - assert ( - result["population_id"] == valid_simulation_data["population_id"] - ) + assert result["population_id"] == valid_simulation_data["population_id"] assert result["policy_id"] == valid_simulation_data["policy_id"] def test_find_existing_simulation_given_no_match(self, test_db): @@ -154,9 +152,7 @@ def test_create_simulation_retrieves_correct_id(self, test_db): class TestGetSimulation: """Test retrieving simulations from the database.""" - def test_get_simulation_existing( - self, test_db, existing_simulation_record - ): + def test_get_simulation_existing(self, test_db, existing_simulation_record): """Test retrieving an existing simulation.""" # GIVEN an existing simulation record @@ -181,9 +177,7 @@ def test_get_simulation_nonexistent(self, test_db): # THEN None should be returned assert result is None - def test_get_simulation_wrong_country( - self, test_db, existing_simulation_record - ): + def test_get_simulation_wrong_country(self, test_db, existing_simulation_record): """Test that simulations are country-specific.""" # GIVEN an existing simulation for 'us' @@ -234,11 +228,6 @@ def test_duplicate_simulation_returns_existing(self, test_db): # THEN the same simulation should be returned (no duplicate created) assert first_simulation["id"] == second_simulation["id"] - assert ( - first_simulation["country_id"] == second_simulation["country_id"] - ) - assert ( - first_simulation["population_id"] - == second_simulation["population_id"] - ) + assert first_simulation["country_id"] == second_simulation["country_id"] + assert first_simulation["population_id"] == second_simulation["population_id"] assert first_simulation["policy_id"] == second_simulation["policy_id"] diff --git a/tests/unit/services/test_tracer_analysis_service.py b/tests/unit/services/test_tracer_analysis_service.py index 1e87c41a6..fd1ba8364 100644 --- a/tests/unit/services/test_tracer_analysis_service.py +++ b/tests/unit/services/test_tracer_analysis_service.py @@ -78,9 +78,7 @@ def test_tracer_output_for_empty_tracer(): valid_target_variable = "snap" # When: Extracting from an empty output - result = test_service._parse_tracer_output( - empty_tracer, valid_target_variable - ) + result = test_service._parse_tracer_output(empty_tracer, valid_target_variable) # Then: It should return an empty list since there is no data to parse expected_output = empty_tracer @@ -138,9 +136,7 @@ def test_tracer_output_for_variable_that_is_substring_of_another(): target_variable = "snap_net_income" # When: Extracting the segment for this variable - result = test_service._parse_tracer_output( - valid_tracer_output, target_variable - ) + result = test_service._parse_tracer_output(valid_tracer_output, target_variable) # Then: It should return only the exact match for "snap_net_income", not "snap_net_income_fpg_ratio" diff --git a/tests/unit/services/test_tracer_service.py b/tests/unit/services/test_tracer_service.py index e5436d476..84ece8df3 100644 --- a/tests/unit/services/test_tracer_service.py +++ b/tests/unit/services/test_tracer_service.py @@ -58,6 +58,4 @@ def test_get_tracer_database_error(test_db): valid_api_version, ] with pytest.raises(Exception): - tracer_service.get_tracer( - *missing_parameter_causing_database_exception - ) + tracer_service.get_tracer(*missing_parameter_causing_database_exception) diff --git a/tests/unit/services/test_update_profile_service.py b/tests/unit/services/test_update_profile_service.py index f9fd607b7..f240ccac0 100644 --- a/tests/unit/services/test_update_profile_service.py +++ b/tests/unit/services/test_update_profile_service.py @@ -11,9 +11,7 @@ class TestUpdateProfile: - def test_update_profile_given_existing_record( - self, test_db, existing_user_profile - ): + def test_update_profile_given_existing_record(self, test_db, existing_user_profile): # GIVEN an existing profile record (from fixture) # WHEN we call update_profile with new data @@ -54,9 +52,7 @@ def test_update_profile_given_nonexistent_record(self, test_db): # THEN the result should be False assert result is False - def test_update_profile_with_partial_fields( - self, test_db, existing_user_profile - ): + def test_update_profile_with_partial_fields(self, test_db, existing_user_profile): # GIVEN an existing profile record (from fixture) # WHEN we call update_profile with only some fields provided @@ -93,9 +89,7 @@ def test_update_profile_with_database_error( def mock_db_query_error(*args, **kwargs): raise Exception("Database error") - monkeypatch.setattr( - "policyengine_api.data.database.query", mock_db_query_error - ) + monkeypatch.setattr("policyengine_api.data.database.query", mock_db_query_error) # WHEN we call update_profile # THEN an exception should be raised diff --git a/tests/unit/services/test_user_service.py b/tests/unit/services/test_user_service.py index 75fe4c834..49072a34a 100644 --- a/tests/unit/services/test_user_service.py +++ b/tests/unit/services/test_user_service.py @@ -33,9 +33,7 @@ def test_get_profile_nonexistent_record(self): def test_get_profile_auth0_id(self, existing_user_profile): # WHEN we call get_profile with auth0_id - result = service.get_profile( - auth0_id=existing_user_profile["auth0_id"] - ) + result = service.get_profile(auth0_id=existing_user_profile["auth0_id"]) # THEN returns record assert result == existing_user_profile diff --git a/tests/unit/test_country.py b/tests/unit/test_country.py index b57e8ceee..55a1f7c70 100644 --- a/tests/unit/test_country.py +++ b/tests/unit/test_country.py @@ -30,9 +30,7 @@ def test__uk_has_360_local_authorities(self, uk_regions): ] assert len(local_authority_regions) == 360 - def test__local_authority_regions_have_correct_name_format( - self, uk_regions - ): + def test__local_authority_regions_have_correct_name_format(self, uk_regions): """Verify local authority region names have the correct prefix.""" local_authority_regions = [ r for r in uk_regions if r.get("type") == "local_authority" @@ -121,9 +119,7 @@ def test__coordinates_are_numeric(self, local_authorities_df): assert local_authorities_df["x"].dtype in ["float64", "int64"] assert local_authorities_df["y"].dtype in ["float64", "int64"] - def test__english_local_authorities_have_e_prefix( - self, local_authorities_df - ): + def test__english_local_authorities_have_e_prefix(self, local_authorities_df): """Verify English local authorities have E prefix codes.""" english_las = local_authorities_df[ local_authorities_df["code"].str.startswith("E") @@ -131,9 +127,7 @@ def test__english_local_authorities_have_e_prefix( # England has 296 local authorities (majority of the 360 total) assert len(english_las) == 296 - def test__scottish_local_authorities_have_s_prefix( - self, local_authorities_df - ): + def test__scottish_local_authorities_have_s_prefix(self, local_authorities_df): """Verify Scottish local authorities have S prefix codes.""" scottish_las = local_authorities_df[ local_authorities_df["code"].str.startswith("S") @@ -141,9 +135,7 @@ def test__scottish_local_authorities_have_s_prefix( # Scotland has 32 council areas assert len(scottish_las) == 32 - def test__welsh_local_authorities_have_w_prefix( - self, local_authorities_df - ): + def test__welsh_local_authorities_have_w_prefix(self, local_authorities_df): """Verify Welsh local authorities have W prefix codes.""" welsh_las = local_authorities_df[ local_authorities_df["code"].str.startswith("W") From d4d7b396fd79c91b928b1c2d6c56a797ae0cd1a4 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 15 Jan 2026 00:34:31 +0300 Subject: [PATCH 3/4] chore: Run linter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- tests/unit/services/test_economy_service.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 4fbf4a5e0..025f490c9 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -649,8 +649,7 @@ def test__given_us_nationwide__returns_correct_sim_options(self): assert sim_options["time_period"] == self.test_time_period assert sim_options["region"] == "us" assert ( - sim_options["data"] - == "gs://policyengine-us-data/enhanced_cps_2024.h5" + sim_options["data"] == "gs://policyengine-us-data/enhanced_cps_2024.h5" ) def test__given_us_state_ca__returns_correct_sim_options(self): @@ -737,8 +736,7 @@ def test__given_cliff_target__returns_correct_sim_options(self): assert sim_options["time_period"] == time_period assert sim_options["region"] == region assert ( - sim_options["data"] - == "gs://policyengine-us-data/enhanced_cps_2024.h5" + sim_options["data"] == "gs://policyengine-us-data/enhanced_cps_2024.h5" ) assert sim_options["include_cliffs"] is True From 3f04db8510315d2c58858705f7e304d0376b6236 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 15 Jan 2026 00:36:10 +0300 Subject: [PATCH 4/4] chore: Lint --- .../ai_prompts/simulation_analysis_prompt.py | 12 +- policyengine_api/api.py | 16 ++- policyengine_api/country.py | 32 +++-- .../data/congressional_districts.py | 4 +- policyengine_api/data/data.py | 8 +- policyengine_api/data/model_setup.py | 8 +- policyengine_api/endpoints/economy/compare.py | 68 +++++++--- policyengine_api/endpoints/household.py | 10 +- policyengine_api/endpoints/policy.py | 32 +++-- policyengine_api/libs/simulation_api.py | 12 +- .../libs/simulation_api_factory.py | 8 +- policyengine_api/routes/economy_routes.py | 36 +++--- policyengine_api/routes/household_routes.py | 20 ++- policyengine_api/routes/metadata_routes.py | 4 +- policyengine_api/routes/policy_routes.py | 4 +- .../routes/report_output_routes.py | 12 +- .../routes/simulation_analysis_routes.py | 4 +- policyengine_api/routes/simulation_routes.py | 4 +- .../services/ai_analysis_service.py | 4 +- policyengine_api/services/economy_service.py | 70 ++++++----- .../services/household_service.py | 12 +- .../services/report_output_service.py | 12 +- .../services/simulation_analysis_service.py | 12 +- .../services/simulation_service.py | 16 ++- .../services/tracer_analysis_service.py | 16 ++- .../validate_household_payload.py | 4 +- .../validate_set_policy_payload.py | 4 +- policyengine_api/utils/singleton.py | 4 +- .../test_environment_variables.py | 4 +- tests/fixtures/integration/simulations.py | 12 +- .../fixtures/services/ai_analysis_service.py | 12 +- tests/fixtures/services/economy_service.py | 19 ++- tests/fixtures/services/household_fixtures.py | 4 +- tests/fixtures/services/policy_service.py | 4 +- tests/integration/test_simulations.py | 8 +- tests/to_refactor/api/test_api.py | 8 +- .../to_refactor_household_fixtures.py | 8 +- .../python/test_ai_analysis_service_old.py | 4 +- .../python/test_household_routes.py | 12 +- .../python/test_policy_service_old.py | 32 +++-- .../python/test_simulation_analysis_routes.py | 12 +- .../python/test_tracer_analysis_routes.py | 12 +- .../python/test_us_policy_macro.py | 8 +- .../python/test_user_profile_routes.py | 16 ++- .../python/test_validate_household_payload.py | 8 +- .../python/test_yearly_var_removal.py | 34 +++-- .../test_simulation_analysis_prompt.py | 12 +- .../unit/data/test_congressional_districts.py | 89 ++++++++++---- tests/unit/endpoints/economy/test_compare.py | 116 +++++++++++++----- .../unit/libs/test_simulation_api_factory.py | 4 +- tests/unit/libs/test_simulation_api_modal.py | 24 +++- .../unit/services/test_ai_analysis_service.py | 3 +- tests/unit/services/test_economy_service.py | 116 +++++++++++++----- tests/unit/services/test_household_service.py | 4 +- tests/unit/services/test_metadata_service.py | 8 +- tests/unit/services/test_policy_service.py | 32 +++-- .../services/test_report_output_service.py | 38 ++++-- .../unit/services/test_simulation_service.py | 21 +++- .../services/test_tracer_analysis_service.py | 8 +- tests/unit/services/test_tracer_service.py | 4 +- .../services/test_update_profile_service.py | 12 +- tests/unit/services/test_user_service.py | 4 +- tests/unit/test_country.py | 16 ++- 63 files changed, 862 insertions(+), 314 deletions(-) diff --git a/policyengine_api/ai_prompts/simulation_analysis_prompt.py b/policyengine_api/ai_prompts/simulation_analysis_prompt.py index dc809312e..e7605771f 100644 --- a/policyengine_api/ai_prompts/simulation_analysis_prompt.py +++ b/policyengine_api/ai_prompts/simulation_analysis_prompt.py @@ -95,12 +95,18 @@ def generate_simulation_analysis_prompt(params: InboundParameters) -> str: ) impact_budget: str = json.dumps(parameters.impact["budget"]) - impact_intra_decile: dict[str, Any] = json.dumps(parameters.impact["intra_decile"]) + impact_intra_decile: dict[str, Any] = json.dumps( + parameters.impact["intra_decile"] + ) impact_decile: str = json.dumps(parameters.impact["decile"]) impact_inequality: str = json.dumps(parameters.impact["inequality"]) impact_poverty: str = json.dumps(parameters.impact["poverty"]["poverty"]) - impact_deep_poverty: str = json.dumps(parameters.impact["poverty"]["deep_poverty"]) - impact_poverty_by_gender: str = json.dumps(parameters.impact["poverty_by_gender"]) + impact_deep_poverty: str = json.dumps( + parameters.impact["poverty"]["deep_poverty"] + ) + impact_poverty_by_gender: str = json.dumps( + parameters.impact["poverty_by_gender"] + ) all_parameters: AllParameters = AllParameters.model_validate( { diff --git a/policyengine_api/api.py b/policyengine_api/api.py index 112cce9ac..b22529b31 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -132,7 +132,9 @@ def log_timing(message): app.route("//calculate-full", methods=["POST"])( cache.cached(make_cache_key=make_cache_key)( - lambda *args, **kwargs: get_calculate(*args, **kwargs, add_missing=True) + lambda *args, **kwargs: get_calculate( + *args, **kwargs, add_missing=True + ) ) ) log_timing("Calculate-full endpoint registered") @@ -151,7 +153,9 @@ def log_timing(message): app.route("//user-policy", methods=["PUT"])(update_user_policy) log_timing("User policy update endpoint registered") -app.route("//user-policy/", methods=["GET"])(get_user_policy) +app.route("//user-policy/", methods=["GET"])( + get_user_policy +) log_timing("User policy get endpoint registered") app.register_blueprint(user_profile_bp) @@ -173,7 +177,9 @@ def log_timing(message): @app.route("/liveness-check", methods=["GET"]) def liveness_check(): - return flask.Response("OK", status=200, headers={"Content-Type": "text/plain"}) + return flask.Response( + "OK", status=200, headers={"Content-Type": "text/plain"} + ) log_timing("Liveness check endpoint registered") @@ -181,7 +187,9 @@ def liveness_check(): @app.route("/readiness-check", methods=["GET"]) def readiness_check(): - return flask.Response("OK", status=200, headers={"Content-Type": "text/plain"}) + return flask.Response( + "OK", status=200, headers={"Content-Type": "text/plain"} + ) log_timing("Readiness check endpoint registered") diff --git a/policyengine_api/country.py b/policyengine_api/country.py index a9b4695ec..29f64fbbe 100644 --- a/policyengine_api/country.py +++ b/policyengine_api/country.py @@ -60,7 +60,9 @@ def build_metadata(self): }[self.country_id], basicInputs=self.tax_benefit_system.basic_inputs, modelled_policies=self.tax_benefit_system.modelled_policies, - version=pkg_resources.get_distribution(self.country_package_name).version, + version=pkg_resources.get_distribution( + self.country_package_name + ).version, ) def build_microsimulation_options(self) -> dict: @@ -75,9 +77,13 @@ def build_microsimulation_options(self) -> dict: region = [ dict(name="uk", label="the UK", type="national"), dict(name="country/england", label="England", type="country"), - dict(name="country/scotland", label="Scotland", type="country"), + dict( + name="country/scotland", label="Scotland", type="country" + ), dict(name="country/wales", label="Wales", type="country"), - dict(name="country/ni", label="Northern Ireland", type="country"), + dict( + name="country/ni", label="Northern Ireland", type="country" + ), ] for i in range(len(constituency_names)): region.append( @@ -124,7 +130,9 @@ def build_microsimulation_options(self) -> dict: dict(name="state/co", label="Colorado", type="state"), dict(name="state/ct", label="Connecticut", type="state"), dict(name="state/de", label="Delaware", type="state"), - dict(name="state/dc", label="District of Columbia", type="state"), + dict( + name="state/dc", label="District of Columbia", type="state" + ), dict(name="state/fl", label="Florida", type="state"), dict(name="state/ga", label="Georgia", type="state"), dict(name="state/hi", label="Hawaii", type="state"), @@ -292,7 +300,9 @@ def build_parameters(self) -> dict: ), } elif isinstance(parameter, ParameterScaleBracket): - bracket_index = int(parameter.name[parameter.name.index("[") + 1 : -1]) + bracket_index = int( + parameter.name[parameter.name.index("[") + 1 : -1] + ) # Set the label to 'first bracket' for the first bracket, 'second bracket' for the second, etc. bracket_label = f"bracket {bracket_index + 1}" parameter_data[parameter.name] = { @@ -369,7 +379,9 @@ def calculate( for parameter_name in reform: for time_period, value in reform[parameter_name].items(): start_instant, end_instant = time_period.split(".") - parameter = get_parameter(system.parameters, parameter_name) + parameter = get_parameter( + system.parameters, parameter_name + ) node_type = type(parameter.values_list[-1].value) if node_type == int: node_type = float @@ -449,8 +461,12 @@ def calculate( if "axes" in household: pass else: - household[entity_plural][entity_id][variable_name][period] = None - print(f"Error computing {variable_name} for {entity_id}: {e}") + household[entity_plural][entity_id][variable_name][ + period + ] = None + print( + f"Error computing {variable_name} for {entity_id}: {e}" + ) tracer_output = simulation.tracer.computation_log log_lines = tracer_output.lines(aggregate=False, max_depth=10) diff --git a/policyengine_api/data/congressional_districts.py b/policyengine_api/data/congressional_districts.py index b085a0fa5..7aa54ab8c 100644 --- a/policyengine_api/data/congressional_districts.py +++ b/policyengine_api/data/congressional_districts.py @@ -684,7 +684,9 @@ def build_congressional_district_metadata() -> list[dict]: return [ { "name": _build_district_name(district.state_code, district.number), - "label": _build_district_label(district.state_code, district.number), + "label": _build_district_label( + district.state_code, district.number + ), "type": "congressional_district", "state_abbreviation": district.state_code, "state_name": STATE_CODE_TO_NAME[district.state_code], diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index a1f479227..c64ffd065 100644 --- a/policyengine_api/data/data.py +++ b/policyengine_api/data/data.py @@ -30,7 +30,9 @@ def __init__( self.local = local if local: # Local development uses a sqlite database. - self.db_url = REPO / "policyengine_api" / "data" / "policyengine.db" + self.db_url = ( + REPO / "policyengine_api" / "data" / "policyengine.db" + ) if initialize or not Path(self.db_url).exists(): self.initialize() else: @@ -39,7 +41,9 @@ def __init__( self.initialize() def _create_pool(self): - instance_connection_name = "policyengine-api:us-central1:policyengine-api-data" + instance_connection_name = ( + "policyengine-api:us-central1:policyengine-api-data" + ) self.connector = Connector() db_user = "policyengine" db_pass = os.environ["POLICYENGINE_DB_PASSWORD"] diff --git a/policyengine_api/data/model_setup.py b/policyengine_api/data/model_setup.py index 739f7bbcc..a2a6a3ee7 100644 --- a/policyengine_api/data/model_setup.py +++ b/policyengine_api/data/model_setup.py @@ -37,7 +37,11 @@ def get_dataset_version(country_id: str) -> str | None: for dataset in datasets["uk"]: - datasets["uk"][dataset] = f"{datasets['uk'][dataset]}@{get_dataset_version('uk')}" + datasets["uk"][ + dataset + ] = f"{datasets['uk'][dataset]}@{get_dataset_version('uk')}" for dataset in datasets["us"]: - datasets["us"][dataset] = f"{datasets['us'][dataset]}@{get_dataset_version('us')}" + datasets["us"][ + dataset + ] = f"{datasets['us'][dataset]}@{get_dataset_version('us')}" diff --git a/policyengine_api/endpoints/economy/compare.py b/policyengine_api/endpoints/economy/compare.py index 117decb39..c97a03f6f 100644 --- a/policyengine_api/endpoints/economy/compare.py +++ b/policyengine_api/endpoints/economy/compare.py @@ -10,8 +10,12 @@ def budgetary_impact(baseline: dict, reform: dict) -> dict: tax_revenue_impact = reform["total_tax"] - baseline["total_tax"] - state_tax_revenue_impact = reform["total_state_tax"] - baseline["total_state_tax"] - benefit_spending_impact = reform["total_benefits"] - baseline["total_benefits"] + state_tax_revenue_impact = ( + reform["total_state_tax"] - baseline["total_state_tax"] + ) + benefit_spending_impact = ( + reform["total_benefits"] - baseline["total_benefits"] + ) budgetary_impact = tax_revenue_impact - benefit_spending_impact return dict( budgetary_impact=budgetary_impact, @@ -24,10 +28,14 @@ def budgetary_impact(baseline: dict, reform: dict) -> dict: def labor_supply_response(baseline: dict, reform: dict) -> dict: - substitution_lsr = reform["substitution_lsr"] - baseline["substitution_lsr"] + substitution_lsr = ( + reform["substitution_lsr"] - baseline["substitution_lsr"] + ) income_lsr = reform["income_lsr"] - baseline["income_lsr"] total_change = substitution_lsr + income_lsr - revenue_change = reform["budgetary_impact_lsr"] - baseline["budgetary_impact_lsr"] + revenue_change = ( + reform["budgetary_impact_lsr"] - baseline["budgetary_impact_lsr"] + ) substitution_lsr_hh = np.array(reform["substitution_lsr_hh"]) - np.array( baseline["substitution_lsr_hh"] @@ -40,13 +48,17 @@ def labor_supply_response(baseline: dict, reform: dict) -> dict: total_lsr_hh = substitution_lsr_hh + income_lsr_hh - emp_income = MicroSeries(baseline["employment_income_hh"], weights=household_weight) + emp_income = MicroSeries( + baseline["employment_income_hh"], weights=household_weight + ) self_emp_income = MicroSeries( baseline["self_employment_income_hh"], weights=household_weight ) earnings = emp_income + self_emp_income original_earnings = earnings - total_lsr_hh - substitution_lsr_hh = MicroSeries(substitution_lsr_hh, weights=household_weight) + substitution_lsr_hh = MicroSeries( + substitution_lsr_hh, weights=household_weight + ) income_lsr_hh = MicroSeries(income_lsr_hh, weights=household_weight) decile_avg = dict( @@ -69,7 +81,9 @@ def labor_supply_response(baseline: dict, reform: dict) -> dict: substitution=(substitution_lsr_hh.sum() / original_earnings.sum()), ) - decile_rel["income"] = {int(k): v for k, v in decile_rel["income"].items() if k > 0} + decile_rel["income"] = { + int(k): v for k, v in decile_rel["income"].items() if k > 0 + } decile_rel["substitution"] = { int(k): v for k, v in decile_rel["substitution"].items() if k > 0 } @@ -98,7 +112,9 @@ def labor_supply_response(baseline: dict, reform: dict) -> dict: ) -def detailed_budgetary_impact(baseline: dict, reform: dict, country_id: str) -> dict: +def detailed_budgetary_impact( + baseline: dict, reform: dict, country_id: str +) -> dict: result = {} if country_id == "uk": for program in baseline["programs"]: @@ -106,7 +122,8 @@ def detailed_budgetary_impact(baseline: dict, reform: dict, country_id: str) -> result[program] = dict( baseline=baseline["programs"][program], reform=reform["programs"][program], - difference=reform["programs"][program] - baseline["programs"][program], + difference=reform["programs"][program] + - baseline["programs"][program], ) return result @@ -272,7 +289,9 @@ def poverty_impact(baseline: dict, reform: dict) -> dict: reform=float(reform_deep_poverty[age < 18].mean()), ), adult=dict( - baseline=float(baseline_deep_poverty[(age >= 18) & (age < 65)].mean()), + baseline=float( + baseline_deep_poverty[(age >= 18) & (age < 65)].mean() + ), reform=float(reform_deep_poverty[(age >= 18) & (age < 65)].mean()), ), senior=dict( @@ -304,7 +323,9 @@ def intra_decile_impact(baseline: dict, reform: dict) -> dict: decile = MicroSeries(baseline["household_income_decile"]).values absolute_change = (reform_income - baseline_income).values capped_baseline_income = np.maximum(baseline_income.values, 1) - capped_reform_income = np.maximum(reform_income.values, 1) + absolute_change + capped_reform_income = ( + np.maximum(reform_income.values, 1) + absolute_change + ) income_change = ( capped_reform_income - capped_baseline_income ) / capped_baseline_income @@ -341,7 +362,9 @@ def intra_decile_impact(baseline: dict, reform: dict) -> dict: if people_in_decile == 0 and people_in_both == 0: people_in_proportion: float = 0.0 else: - people_in_proportion: float = float(people_in_both / people_in_decile) + people_in_proportion: float = float( + people_in_both / people_in_decile + ) outcome_groups[label].append(people_in_proportion) @@ -362,7 +385,9 @@ def intra_wealth_decile_impact(baseline: dict, reform: dict) -> dict: decile = MicroSeries(baseline["household_wealth_decile"]).values absolute_change = (reform_income - baseline_income).values capped_baseline_income = np.maximum(baseline_income.values, 1) - capped_reform_income = np.maximum(reform_income.values, 1) + absolute_change + capped_reform_income = ( + np.maximum(reform_income.values, 1) + absolute_change + ) income_change = ( capped_reform_income - capped_baseline_income ) / capped_baseline_income @@ -399,7 +424,9 @@ def intra_wealth_decile_impact(baseline: dict, reform: dict) -> dict: if people_in_decile == 0 and people_in_both == 0: people_in_proportion = 0 else: - people_in_proportion: float = float(people_in_both / people_in_decile) + people_in_proportion: float = float( + people_in_both / people_in_decile + ) outcome_groups[label].append(people_in_proportion) @@ -481,7 +508,9 @@ def poverty_racial_breakdown(baseline: dict, reform: dict) -> dict: reform_poverty = MicroSeries( reform["person_in_poverty"], weights=baseline_poverty.weights ) - race = MicroSeries(baseline["race"]) # Can be WHITE, BLACK, HISPANIC, or OTHER. + race = MicroSeries( + baseline["race"] + ) # Can be WHITE, BLACK, HISPANIC, or OTHER. poverty = dict( white=dict( @@ -723,7 +752,10 @@ def uk_local_authority_breakdown( continue elif selected_country == "wales" and not code.startswith("W"): continue - elif selected_country == "northern_ireland" and not code.startswith("N"): + elif ( + selected_country == "northern_ireland" + and not code.startswith("N") + ): continue weight: np.ndarray = weights[i] @@ -809,7 +841,9 @@ def compare_economic_outputs( uk_local_authority_breakdown(baseline, reform, country_id, region) ) if local_authority_impact_data is not None: - local_authority_impact_data = local_authority_impact_data.model_dump() + local_authority_impact_data = ( + local_authority_impact_data.model_dump() + ) try: wealth_decile_impact_data = wealth_decile_impact(baseline, reform) intra_wealth_decile_impact_data = intra_wealth_decile_impact( diff --git a/policyengine_api/endpoints/household.py b/policyengine_api/endpoints/household.py index edd647906..b841c5e10 100644 --- a/policyengine_api/endpoints/household.py +++ b/policyengine_api/endpoints/household.py @@ -41,7 +41,11 @@ def add_yearly_variables(household, country_id): if variables[variable]["isInputVariable"]: household[entity_plural][entity][ variables[variable]["name"] - ] = {household_year: variables[variable]["defaultValue"]} + ] = { + household_year: variables[variable][ + "defaultValue" + ] + } else: household[entity_plural][entity][ variables[variable]["name"] @@ -71,7 +75,9 @@ def get_household_year(household): @validate_country -def get_household_under_policy(country_id: str, household_id: str, policy_id: str): +def get_household_under_policy( + country_id: str, household_id: str, policy_id: str +): """Get a household's output data under a given policy. Args: diff --git a/policyengine_api/endpoints/policy.py b/policyengine_api/endpoints/policy.py index daf428b32..90cfa9bd7 100644 --- a/policyengine_api/endpoints/policy.py +++ b/policyengine_api/endpoints/policy.py @@ -30,7 +30,9 @@ def get_policy_search(country_id: str) -> dict: query = request.args.get("query", "") # The "json.loads" default type is added to convert lowercase # "true" and "false" to Python-friendly bool values - unique_only = request.args.get("unique_only", default=False, type=json.loads) + unique_only = request.args.get( + "unique_only", default=False, type=json.loads + ) try: results = database.query( @@ -45,7 +47,9 @@ def get_policy_search(country_id: str) -> dict: status="error", message=f"No policies found for country {country_id} for query '{query}", ) - return Response(json.dumps(body), status=404, mimetype="application/json") + return Response( + json.dumps(body), status=404, mimetype="application/json" + ) # If unique_only is true, filter results to only include # items where everything except ID is unique @@ -66,16 +70,22 @@ def get_policy_search(country_id: str) -> dict: results = new_results # Format into: [{ id: 1, label: "My policy" }, ...] - policies = [dict(id=result["id"], label=result["label"]) for result in results] + policies = [ + dict(id=result["id"], label=result["label"]) for result in results + ] body = dict( status="ok", message="Policies found", result=policies, ) - return Response(json.dumps(body), status=200, mimetype="application/json") + return Response( + json.dumps(body), status=200, mimetype="application/json" + ) except Exception as e: body = dict(status="error", message=f"Internal server error: {e}") - return Response(json.dumps(body), status=500, mimetype="application/json") + return Response( + json.dumps(body), status=500, mimetype="application/json" + ) @validate_country @@ -167,7 +177,9 @@ def set_user_policy(country_id: str) -> dict: except Exception as e: return Response( json.dumps( - {"message": f"Internal database error: {e}; please try again later."} + { + "message": f"Internal database error: {e}; please try again later." + } ), status=500, mimetype="application/json", @@ -224,7 +236,9 @@ def set_user_policy(country_id: str) -> dict: except Exception as e: return Response( json.dumps( - {"message": f"Internal database error: {e}; please try again later."} + { + "message": f"Internal database error: {e}; please try again later." + } ), status=500, mimetype="application/json", @@ -336,7 +350,9 @@ def update_user_policy(country_id: str) -> dict: except Exception as e: return Response( json.dumps( - {"message": f"Internal database error: {e}; please try again later."} + { + "message": f"Internal database error: {e}; please try again later." + } ), status=500, mimetype="application/json", diff --git a/policyengine_api/libs/simulation_api.py b/policyengine_api/libs/simulation_api.py index 0b271e7f1..1fbd12b48 100644 --- a/policyengine_api/libs/simulation_api.py +++ b/policyengine_api/libs/simulation_api.py @@ -75,9 +75,13 @@ def get_execution_status(self, execution: executions_v1.Execution) -> str: status : str The status of the execution """ - return self.execution_client.get_execution(name=execution.name).state.name + return self.execution_client.get_execution( + name=execution.name + ).state.name - def get_execution_result(self, execution: executions_v1.Execution) -> dict | None: + def get_execution_result( + self, execution: executions_v1.Execution + ) -> dict | None: """ Get the result of an execution @@ -91,7 +95,9 @@ def get_execution_result(self, execution: executions_v1.Execution) -> dict | Non result : str The result of the execution """ - result = self.execution_client.get_execution(name=execution.name).result + result = self.execution_client.get_execution( + name=execution.name + ).result try: return json.loads(result) except: diff --git a/policyengine_api/libs/simulation_api_factory.py b/policyengine_api/libs/simulation_api_factory.py index 4ae3b84b2..38af346d1 100644 --- a/policyengine_api/libs/simulation_api_factory.py +++ b/policyengine_api/libs/simulation_api_factory.py @@ -17,7 +17,9 @@ from policyengine_api.gcp_logging import logger -def get_simulation_api() -> Union["SimulationAPI", "SimulationAPIModal"]: # noqa: F821 +def get_simulation_api() -> ( + Union["SimulationAPI", "SimulationAPIModal"] +): # noqa: F821 """ Get the appropriate simulation API client based on environment configuration. @@ -34,7 +36,9 @@ def get_simulation_api() -> Union["SimulationAPI", "SimulationAPIModal"]: # noq ValueError If GCP client is requested but GOOGLE_APPLICATION_CREDENTIALS is not set. """ - use_modal = os.environ.get("USE_MODAL_SIMULATION_API", "true").lower() == "true" + use_modal = ( + os.environ.get("USE_MODAL_SIMULATION_API", "true").lower() == "true" + ) if use_modal: logger.log_struct( diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index 84850b17d..c0de06730 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -18,7 +18,9 @@ "//economy//over/", methods=["GET"], ) -def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int): +def get_economic_impact( + country_id: str, policy_id: int, baseline_policy_id: int +): policy_id = int(policy_id or get_current_law_policy_id(country_id)) baseline_policy_id = int( @@ -33,21 +35,27 @@ def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int dataset = options.pop("dataset", "default") time_period = options.pop("time_period") target: Literal["general", "cliff"] = options.pop("target", "general") - api_version = options.pop("version", COUNTRY_PACKAGE_VERSIONS.get(country_id)) - - economic_impact_result: EconomicImpactResult = economy_service.get_economic_impact( - country_id=country_id, - policy_id=policy_id, - baseline_policy_id=baseline_policy_id, - region=region, - dataset=dataset, - time_period=time_period, - options=options, - api_version=api_version, - target=target, + api_version = options.pop( + "version", COUNTRY_PACKAGE_VERSIONS.get(country_id) ) - result_dict: dict[str, str | dict | None] = economic_impact_result.to_dict() + economic_impact_result: EconomicImpactResult = ( + economy_service.get_economic_impact( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=time_period, + options=options, + api_version=api_version, + target=target, + ) + ) + + result_dict: dict[str, str | dict | None] = ( + economic_impact_result.to_dict() + ) return Response( json.dumps( diff --git a/policyengine_api/routes/household_routes.py b/policyengine_api/routes/household_routes.py index 59eff51ee..893d6defd 100644 --- a/policyengine_api/routes/household_routes.py +++ b/policyengine_api/routes/household_routes.py @@ -13,7 +13,9 @@ household_service = HouseholdService() -@household_bp.route("//household/", methods=["GET"]) +@household_bp.route( + "//household/", methods=["GET"] +) @validate_country def get_household(country_id: str, household_id: int) -> Response: """ @@ -25,7 +27,9 @@ def get_household(country_id: str, household_id: int) -> Response: """ print(f"Got request for household {household_id} in country {country_id}") - household: dict | None = household_service.get_household(country_id, household_id) + household: dict | None = household_service.get_household( + country_id, household_id + ) if household is None: raise NotFound(f"Household #{household_id} not found.") else: @@ -63,7 +67,9 @@ def post_household(country_id: str) -> Response: label: str | None = payload.get("label") household_json: dict = payload.get("data") - household_id = household_service.create_household(country_id, household_json, label) + household_id = household_service.create_household( + country_id, household_json, label + ) return Response( json.dumps( @@ -80,7 +86,9 @@ def post_household(country_id: str) -> Response: ) -@household_bp.route("//household/", methods=["PUT"]) +@household_bp.route( + "//household/", methods=["PUT"] +) @validate_country def update_household(country_id: str, household_id: int) -> Response: """ @@ -103,7 +111,9 @@ def update_household(country_id: str, household_id: int) -> Response: label: str | None = payload.get("label") household_json: dict = payload.get("data") - household: dict | None = household_service.get_household(country_id, household_id) + household: dict | None = household_service.get_household( + country_id, household_id + ) if household is None: raise NotFound(f"Household #{household_id} not found.") diff --git a/policyengine_api/routes/metadata_routes.py b/policyengine_api/routes/metadata_routes.py index 8dd5465e4..496d9556d 100644 --- a/policyengine_api/routes/metadata_routes.py +++ b/policyengine_api/routes/metadata_routes.py @@ -20,7 +20,9 @@ def get_metadata(country_id: str) -> Response: # Retrieve country metadata and add status and message to the response country_metadata = metadata_service.get_metadata(country_id) return Response( - json.dumps({"status": "ok", "message": None, "result": country_metadata}), + json.dumps( + {"status": "ok", "message": None, "result": country_metadata} + ), status=200, mimetype="application/json", ) diff --git a/policyengine_api/routes/policy_routes.py b/policyengine_api/routes/policy_routes.py index 3fc88fbf4..913eb105c 100644 --- a/policyengine_api/routes/policy_routes.py +++ b/policyengine_api/routes/policy_routes.py @@ -76,4 +76,6 @@ def set_policy(country_id: str) -> Response: ) code = 200 if is_existing_policy else 201 - return Response(json.dumps(response_body), status=code, mimetype="application/json") + return Response( + json.dumps(response_body), status=code, mimetype="application/json" + ) diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index a95630c33..4dfb9218a 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -33,7 +33,9 @@ def create_report_output(country_id: str) -> Response: # Extract required fields simulation_1_id = payload.get("simulation_1_id") simulation_2_id = payload.get("simulation_2_id") # Optional - year = payload.get("year", CURRENT_YEAR) # Default to current year as string + year = payload.get( + "year", CURRENT_YEAR + ) # Default to current year as string # Validate required fields if simulation_1_id is None: @@ -93,7 +95,9 @@ def create_report_output(country_id: str) -> Response: raise BadRequest(f"Failed to create report output: {str(e)}") -@report_output_bp.route("//report/", methods=["GET"]) +@report_output_bp.route( + "//report/", methods=["GET"] +) @validate_country def get_report_output(country_id: str, report_id: int) -> Response: """ @@ -105,7 +109,9 @@ def get_report_output(country_id: str, report_id: int) -> Response: """ print(f"Getting report output {report_id} for country {country_id}") - report_output: dict | None = report_output_service.get_report_output(report_id) + report_output: dict | None = report_output_service.get_report_output( + report_id + ) if report_output is None: raise NotFound(f"Report #{report_id} not found.") diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index 5157b807d..893d7cae4 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -16,7 +16,9 @@ simulation_analysis_service = SimulationAnalysisService() -@simulation_analysis_bp.route("//simulation-analysis", methods=["POST"]) +@simulation_analysis_bp.route( + "//simulation-analysis", methods=["POST"] +) @validate_country def execute_simulation_analysis(country_id): print("Got POST request for simulation analysis") diff --git a/policyengine_api/routes/simulation_routes.py b/policyengine_api/routes/simulation_routes.py index 151c4f942..c1210d97d 100644 --- a/policyengine_api/routes/simulation_routes.py +++ b/policyengine_api/routes/simulation_routes.py @@ -96,7 +96,9 @@ def create_simulation(country_id: str) -> Response: raise BadRequest(f"Failed to create simulation: {str(e)}") -@simulation_bp.route("//simulation/", methods=["GET"]) +@simulation_bp.route( + "//simulation/", methods=["GET"] +) @validate_country def get_simulation(country_id: str, simulation_id: int) -> Response: """ diff --git a/policyengine_api/services/ai_analysis_service.py b/policyengine_api/services/ai_analysis_service.py index f2fc3c710..fa6c56db4 100644 --- a/policyengine_api/services/ai_analysis_service.py +++ b/policyengine_api/services/ai_analysis_service.py @@ -45,7 +45,9 @@ def get_existing_analysis(self, prompt: str) -> Optional[str]: def trigger_ai_analysis(self, prompt: str) -> Generator[str, None, None]: # Configure a Claude client - claude_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) + claude_client = anthropic.Anthropic( + api_key=os.getenv("ANTHROPIC_API_KEY") + ) def generate(): response_text = "" diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 3dd23f447..9ca08b69d 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -164,22 +164,24 @@ def get_economic_impact( if country_id == "uk": country_package_version = None - economic_impact_setup_options = EconomicImpactSetupOptions.model_validate( - { - "process_id": process_id, - "country_id": country_id, - "reform_policy_id": policy_id, - "baseline_policy_id": baseline_policy_id, - "region": region, - "dataset": dataset, - "time_period": time_period, - "options": options, - "api_version": api_version, - "target": target, - "model_version": country_package_version, - "data_version": get_dataset_version(country_id), - "options_hash": options_hash, - } + economic_impact_setup_options = ( + EconomicImpactSetupOptions.model_validate( + { + "process_id": process_id, + "country_id": country_id, + "reform_policy_id": policy_id, + "baseline_policy_id": baseline_policy_id, + "region": region, + "dataset": dataset, + "time_period": time_period, + "options": options, + "api_version": api_version, + "target": target, + "model_version": country_package_version, + "data_version": get_dataset_version(country_id), + "options_hash": options_hash, + } + ) ) # Logging that we've received a request @@ -257,15 +259,17 @@ def _get_previous_impacts( Fetch any previous simulation runs for the given policy reform. """ - previous_impacts: list[Any] = reform_impacts_service.get_all_reform_impacts( - country_id, - policy_id, - baseline_policy_id, - region, - dataset, - time_period, - options_hash, - api_version, + previous_impacts: list[Any] = ( + reform_impacts_service.get_all_reform_impacts( + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, + ) ) return previous_impacts @@ -344,7 +348,9 @@ def _handle_execution_state( and hasattr(execution, "error") and execution.error ): - error_message = f"Simulation API execution failed: {execution.error}" + error_message = ( + f"Simulation API execution failed: {execution.error}" + ) self._set_reform_impact_error( setup_options=setup_options, @@ -365,7 +371,9 @@ def _handle_execution_state( return EconomicImpactResult.computing() else: - raise ValueError(f"Unexpected sim API execution state: {execution_state}") + raise ValueError( + f"Unexpected sim API execution state: {execution_state}" + ) def _handle_completed_impact( self, @@ -465,7 +473,9 @@ def _setup_sim_options( "baseline": json.loads(baseline_policy), "time_period": time_period, "include_cliffs": include_cliffs, - "region": self._setup_region(country_id=country_id, region=region), + "region": self._setup_region( + country_id=country_id, region=region + ), "data": self._setup_data(country_id=country_id, region=region), "model_version": model_version, "data_version": data_version, @@ -504,7 +514,9 @@ def _validate_us_region(self, region: str) -> None: elif region.startswith("congressional_district/"): district_id = region[len("congressional_district/") :] if district_id.lower() not in get_valid_congressional_districts(): - raise ValueError(f"Invalid congressional district: '{district_id}'") + raise ValueError( + f"Invalid congressional district: '{district_id}'" + ) else: raise ValueError(f"Invalid US region: '{region}'") diff --git a/policyengine_api/services/household_service.py b/policyengine_api/services/household_service.py index dafc8bc6f..4091f71d9 100644 --- a/policyengine_api/services/household_service.py +++ b/policyengine_api/services/household_service.py @@ -40,7 +40,9 @@ def get_household(self, country_id: str, household_id: int) -> dict | None: return household except Exception as e: - print(f"Error fetching household #{household_id}. Details: {str(e)}") + print( + f"Error fetching household #{household_id}. Details: {str(e)}" + ) raise e def create_household( @@ -121,8 +123,12 @@ def update_household( ) # Fetch the updated JSON back from the table - updated_household: dict = self.get_household(country_id, household_id) + updated_household: dict = self.get_household( + country_id, household_id + ) return updated_household except Exception as e: - print(f"Error updating household #{household_id}. Details: {str(e)}") + print( + f"Error updating household #{household_id}. Details: {str(e)}" + ) raise e diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index c0dba45f1..4793ae018 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -43,13 +43,17 @@ def find_existing_report_output( existing_report = None if row is not None: existing_report = dict(row) - print(f"Found existing report output with ID: {existing_report['id']}") + print( + f"Found existing report output with ID: {existing_report['id']}" + ) # Keep output as JSON string - frontend expects string format return existing_report except Exception as e: - print(f"Error checking for existing report output. Details: {str(e)}") + print( + f"Error checking for existing report output. Details: {str(e)}" + ) raise e def create_report_output( @@ -213,5 +217,7 @@ def update_report_output( return True except Exception as e: - print(f"Error updating report output #{report_id}. Details: {str(e)}") + print( + f"Error updating report output #{report_id}. Details: {str(e)}" + ) raise e diff --git a/policyengine_api/services/simulation_analysis_service.py b/policyengine_api/services/simulation_analysis_service.py index 140fe4987..8949bf2ae 100644 --- a/policyengine_api/services/simulation_analysis_service.py +++ b/policyengine_api/services/simulation_analysis_service.py @@ -29,7 +29,9 @@ def execute_analysis( relevant_parameters: list[dict], relevant_parameter_baseline_values: list[dict], audience: str | None, - ) -> tuple[Generator[str, None, None] | str, Literal["streaming", "static"]]: + ) -> tuple[ + Generator[str, None, None] | str, Literal["streaming", "static"] + ]: """ Execute AI analysis for economy-wide simulation @@ -65,7 +67,9 @@ def execute_analysis( if existing_analysis is not None: return existing_analysis, "static" - print("Found no existing AI analysis; triggering new analysis with Claude") + print( + "Found no existing AI analysis; triggering new analysis with Claude" + ) # Otherwise, pass prompt to Claude, then return streaming function try: analysis = self.trigger_ai_analysis(prompt) @@ -105,7 +109,9 @@ def _generate_simulation_analysis_prompt( } try: - prompt = ai_prompt_service.get_prompt("simulation_analysis", prompt_data) + prompt = ai_prompt_service.get_prompt( + "simulation_analysis", prompt_data + ) return prompt except Exception as e: diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index a7985cb9b..88f359ae7 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -38,7 +38,9 @@ def find_existing_simulation( existing_simulation = None if row is not None: existing_simulation = dict(row) - print(f"Found existing simulation with ID: {existing_simulation['id']}") + print( + f"Found existing simulation with ID: {existing_simulation['id']}" + ) return existing_simulation @@ -96,7 +98,9 @@ def create_simulation( print(f"Error creating simulation. Details: {str(e)}") raise e - def get_simulation(self, country_id: str, simulation_id: int) -> dict | None: + def get_simulation( + self, country_id: str, simulation_id: int + ) -> dict | None: """ Get a simulation record by ID. @@ -127,7 +131,9 @@ def get_simulation(self, country_id: str, simulation_id: int) -> dict | None: return simulation except Exception as e: - print(f"Error fetching simulation #{simulation_id}. Details: {str(e)}") + print( + f"Error fetching simulation #{simulation_id}. Details: {str(e)}" + ) raise e def update_simulation( @@ -192,5 +198,7 @@ def update_simulation( return True except Exception as e: - print(f"Error updating simulation #{simulation_id}. Details: {str(e)}") + print( + f"Error updating simulation #{simulation_id}. Details: {str(e)}" + ) raise e diff --git a/policyengine_api/services/tracer_analysis_service.py b/policyengine_api/services/tracer_analysis_service.py index 2fd072f83..5857fcef6 100644 --- a/policyengine_api/services/tracer_analysis_service.py +++ b/policyengine_api/services/tracer_analysis_service.py @@ -18,7 +18,9 @@ def execute_analysis( household_id: str, policy_id: str, variable: str, - ) -> tuple[Generator[str, None, None] | str, Literal["static", "streaming"]]: + ) -> tuple[ + Generator[str, None, None] | str, Literal["static", "streaming"] + ]: """ Executes tracer analysis for a variable in a household @@ -42,7 +44,9 @@ def execute_analysis( # Parse the tracer output for our given variable try: - tracer_segment: list[str] = self._parse_tracer_output(tracer, variable) + tracer_segment: list[str] = self._parse_tracer_output( + tracer, variable + ) except Exception as e: print(f"Error parsing tracer output: {str(e)}") raise e @@ -103,13 +107,17 @@ def _parse_tracer_output(self, tracer_output, target_variable): capturing = False # Input validation - if not isinstance(target_variable, str) or not isinstance(tracer_output, list): + if not isinstance(target_variable, str) or not isinstance( + tracer_output, list + ): return result # Create a regex pattern to match the exact variable name # This will match the variable name followed by optional whitespace, # then optional angle brackets with any content, then optional whitespace - pattern = rf"^(\s*)({re.escape(target_variable)})(?!\w)\s*(?:<[^>]*>)?\s*" + pattern = ( + rf"^(\s*)({re.escape(target_variable)})(?!\w)\s*(?:<[^>]*>)?\s*" + ) for line in tracer_output: # Count leading spaces to determine indentation level diff --git a/policyengine_api/utils/payload_validators/validate_household_payload.py b/policyengine_api/utils/payload_validators/validate_household_payload.py index c66f15e26..7b4f7d951 100644 --- a/policyengine_api/utils/payload_validators/validate_household_payload.py +++ b/policyengine_api/utils/payload_validators/validate_household_payload.py @@ -19,7 +19,9 @@ def validate_household_payload(payload): # Check that label is either string or None, if present if "label" in payload: - if payload["label"] is not None and not isinstance(payload["label"], str): + if payload["label"] is not None and not isinstance( + payload["label"], str + ): return False, "Label must be a string or None" # Check that data is a dictionary diff --git a/policyengine_api/utils/payload_validators/validate_set_policy_payload.py b/policyengine_api/utils/payload_validators/validate_set_policy_payload.py index f90f80d17..a48c75bda 100644 --- a/policyengine_api/utils/payload_validators/validate_set_policy_payload.py +++ b/policyengine_api/utils/payload_validators/validate_set_policy_payload.py @@ -8,7 +8,9 @@ def validate_set_policy_payload(payload: dict) -> tuple[bool, str | None]: # Check that label is either string or None if "label" in payload: - if payload["label"] is not None and not isinstance(payload["label"], str): + if payload["label"] is not None and not isinstance( + payload["label"], str + ): return False, "Label must be a string or None" # Check that data is a dictionary diff --git a/policyengine_api/utils/singleton.py b/policyengine_api/utils/singleton.py index 3776cb92d..28e8a0984 100644 --- a/policyengine_api/utils/singleton.py +++ b/policyengine_api/utils/singleton.py @@ -3,5 +3,7 @@ class Singleton(type): def __call__(cls, *args, **kwargs): if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + cls._instances[cls] = super(Singleton, cls).__call__( + *args, **kwargs + ) return cls._instances[cls] diff --git a/tests/env_variables/test_environment_variables.py b/tests/env_variables/test_environment_variables.py index 9bcaa2bc3..23a21ea1d 100644 --- a/tests/env_variables/test_environment_variables.py +++ b/tests/env_variables/test_environment_variables.py @@ -39,7 +39,9 @@ def test_github_microdata_auth_token(self): """Test if POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN is valid by querying GitHub user API.""" token = os.getenv("POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN") - assert token is not None, "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN is not set" + assert ( + token is not None + ), "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN is not set" headers = { "Authorization": f"Bearer {token}", diff --git a/tests/fixtures/integration/simulations.py b/tests/fixtures/integration/simulations.py index 741676047..aefddc9fe 100644 --- a/tests/fixtures/integration/simulations.py +++ b/tests/fixtures/integration/simulations.py @@ -6,9 +6,7 @@ from unittest.mock import Mock, MagicMock, patch from policyengine_api.endpoints.household import add_yearly_variables -STANDARD_AXES_COUNT = ( - 401 # Not formally defined anywhere, but this value is used throughout the API -) +STANDARD_AXES_COUNT = 401 # Not formally defined anywhere, but this value is used throughout the API SMALL_AXES_COUNT = 5 TEST_YEAR = "2025" TEST_STATE = "NY" @@ -69,6 +67,10 @@ def create_household_with_axes(base_household, axes_config): def setup_small_axes_household(base_household, small_axes_config): """Fixture to setup a household with small axes for testing""" - household_with_axes = create_household_with_axes(base_household, small_axes_config) - household_with_axes = add_yearly_variables(household_with_axes, TEST_COUNTRY_ID) + household_with_axes = create_household_with_axes( + base_household, small_axes_config + ) + household_with_axes = add_yearly_variables( + household_with_axes, TEST_COUNTRY_ID + ) return household_with_axes diff --git a/tests/fixtures/services/ai_analysis_service.py b/tests/fixtures/services/ai_analysis_service.py index 95bba3039..a2f4d21c4 100644 --- a/tests/fixtures/services/ai_analysis_service.py +++ b/tests/fixtures/services/ai_analysis_service.py @@ -39,10 +39,14 @@ def _configure(text_chunks: list[str]): # Set up mock stream mock_stream = MagicMock() - mock_client.messages.stream.return_value.__enter__.return_value = mock_stream + mock_client.messages.stream.return_value.__enter__.return_value = ( + mock_stream + ) # Configure stream to yield text events - events = [MockEvent(event_type="text", text=chunk) for chunk in text_chunks] + events = [ + MockEvent(event_type="text", text=chunk) for chunk in text_chunks + ] mock_stream.__iter__.return_value = events return mock_client @@ -63,7 +67,9 @@ def _configure(error_type: str): # Set up mock stream mock_stream = MagicMock() - mock_client.messages.stream.return_value.__enter__.return_value = mock_stream + mock_client.messages.stream.return_value.__enter__.return_value = ( + mock_stream + ) # Configure stream to yield an error event error_event = MockEvent(event_type="error", error={"type": error_type}) diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index d94ffe9b4..293b8909e 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -30,7 +30,9 @@ MOCK_MODEL_VERSION = "1.2.3" MOCK_DATA_VERSION = None -MOCK_REFORM_POLICY_JSON = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) +MOCK_REFORM_POLICY_JSON = json.dumps( + {"sample_param": {"2024-01-01.2100-12-31": 15}} +) MOCK_BASELINE_POLICY_JSON = json.dumps({}) @@ -138,7 +140,9 @@ def mock_logger(): def mock_datetime(): """Mock datetime.datetime.now().""" mock_now = datetime.datetime(2025, 6, 26, 12, 0, 0) - with patch("policyengine_api.services.economy_service.datetime.datetime") as mock: + with patch( + "policyengine_api.services.economy_service.datetime.datetime" + ) as mock: mock.now.return_value = mock_now yield mock @@ -168,11 +172,14 @@ def create_mock_reform_impact( "options_hash": MOCK_OPTIONS_HASH, "status": status, "api_version": MOCK_API_VERSION, - "reform_impact_json": reform_impact_json or json.dumps(MOCK_REFORM_IMPACT_DATA), + "reform_impact_json": reform_impact_json + or json.dumps(MOCK_REFORM_IMPACT_DATA), "execution_id": execution_id, "start_time": datetime.datetime(2025, 6, 26, 12, 0, 0), "end_time": ( - datetime.datetime(2025, 6, 26, 12, 5, 0) if status == "ok" else None + datetime.datetime(2025, 6, 26, 12, 5, 0) + if status == "ok" + else None ), } @@ -244,7 +251,9 @@ def mock_simulation_api_modal(): MOCK_US_NATIONWIDE_DATASET = "gs://policyengine-us-data/cps_2023.h5" MOCK_US_STATE_CA_DATASET = "gs://policyengine-us-data/states/CA.h5" MOCK_US_STATE_UT_DATASET = "gs://policyengine-us-data/states/UT.h5" -MOCK_US_CITY_NYC_DATASET = "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" +MOCK_US_CITY_NYC_DATASET = ( + "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" +) MOCK_US_DISTRICT_CA37_DATASET = "gs://policyengine-us-data/districts/CA-37.h5" MOCK_UK_DATASET = "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" diff --git a/tests/fixtures/services/household_fixtures.py b/tests/fixtures/services/household_fixtures.py index d68cad86a..f84d99c95 100644 --- a/tests/fixtures/services/household_fixtures.py +++ b/tests/fixtures/services/household_fixtures.py @@ -23,7 +23,9 @@ @pytest.fixture def mock_hash_object(): """Mock the hash_object function.""" - with patch("policyengine_api.services.household_service.hash_object") as mock: + with patch( + "policyengine_api.services.household_service.hash_object" + ) as mock: mock.return_value = valid_hash_value yield mock diff --git a/tests/fixtures/services/policy_service.py b/tests/fixtures/services/policy_service.py index 6c4a27f66..18ee9071e 100644 --- a/tests/fixtures/services/policy_service.py +++ b/tests/fixtures/services/policy_service.py @@ -3,7 +3,9 @@ from unittest.mock import patch valid_policy_json = { - "data": {"gov.irs.income.bracket.rates.2": {"2024-01-01.2024-12-31": 0.2433}}, + "data": { + "gov.irs.income.bracket.rates.2": {"2024-01-01.2024-12-31": 0.2433} + }, } valid_hash_value = "NgJhpeuRVnIAwgYWuJsd2fI/N88rIE6Kcj8q4TPD/i4=" diff --git a/tests/integration/test_simulations.py b/tests/integration/test_simulations.py index 37f8da106..36056f239 100644 --- a/tests/integration/test_simulations.py +++ b/tests/integration/test_simulations.py @@ -40,9 +40,13 @@ def test__given_any_number_of_axes__sim_returns_valid_arrays( print("Variable name: ", variable_name) if variable_name in FORBIDDEN_VARIABLES: continue - for period in result[entity_type][entity_id][variable_name]: + for period in result[entity_type][entity_id][ + variable_name + ]: print("Period: ", period) - value = result[entity_type][entity_id][variable_name][period] + value = result[entity_type][entity_id][variable_name][ + period + ] print(f"Value: {value}") if isinstance(value, list): # Assert no Nones diff --git a/tests/to_refactor/api/test_api.py b/tests/to_refactor/api/test_api.py index f0855a6ec..74f3e2bd6 100644 --- a/tests/to_refactor/api/test_api.py +++ b/tests/to_refactor/api/test_api.py @@ -23,7 +23,9 @@ def client(): # - expected_result: the expected result of the endpoint test_paths = [ - path for path in (Path(__file__).parent).rglob("*") if path.suffix == ".yaml" + path + for path in (Path(__file__).parent).rglob("*") + if path.suffix == ".yaml" ] test_data = [yaml.safe_load(path.read_text()) for path in test_paths] test_names = [test["name"] for test in test_data] @@ -68,4 +70,6 @@ def test_response(client, test: dict): json.loads(response.data), test.get("response", {}).get("data", {}) ) elif "html" in test.get("response", {}): - assert response.data.decode("utf-8") == test.get("response", {}).get("html", "") + assert response.data.decode("utf-8") == test.get("response", {}).get( + "html", "" + ) diff --git a/tests/to_refactor/fixtures/to_refactor_household_fixtures.py b/tests/to_refactor/fixtures/to_refactor_household_fixtures.py index 5fa6af91c..89b854f19 100644 --- a/tests/to_refactor/fixtures/to_refactor_household_fixtures.py +++ b/tests/to_refactor/fixtures/to_refactor_household_fixtures.py @@ -22,7 +22,9 @@ @pytest.fixture def mock_hash_object(): """Mock the hash_object function.""" - with patch("policyengine_api.services.household_service.hash_object") as mock: + with patch( + "policyengine_api.services.household_service.hash_object" + ) as mock: mock.return_value = valid_hash_value yield mock @@ -30,5 +32,7 @@ def mock_hash_object(): @pytest.fixture def mock_database(): """Mock the database module.""" - with patch("policyengine_api.services.household_service.database") as mock_db: + with patch( + "policyengine_api.services.household_service.database" + ) as mock_db: yield mock_db diff --git a/tests/to_refactor/python/test_ai_analysis_service_old.py b/tests/to_refactor/python/test_ai_analysis_service_old.py index 0df3928ca..aa8c825e3 100644 --- a/tests/to_refactor/python/test_ai_analysis_service_old.py +++ b/tests/to_refactor/python/test_ai_analysis_service_old.py @@ -9,7 +9,9 @@ @patch("policyengine_api.services.ai_analysis_service.local_database") def test_get_existing_analysis_found(mock_db): - mock_db.query.return_value.fetchone.return_value = {"analysis": "Existing analysis"} + mock_db.query.return_value.fetchone.return_value = { + "analysis": "Existing analysis" + } prompt = "Test prompt" output = test_ai_service.get_existing_analysis(prompt) diff --git a/tests/to_refactor/python/test_household_routes.py b/tests/to_refactor/python/test_household_routes.py index 5b3ccb812..e4ea05a1c 100644 --- a/tests/to_refactor/python/test_household_routes.py +++ b/tests/to_refactor/python/test_household_routes.py @@ -46,7 +46,9 @@ def test_get_household_invalid_id(self, rest_client): response = rest_client.get("/us/household/invalid") assert response.status_code == 404 - assert b"The requested URL was not found on the server" in response.data + assert ( + b"The requested URL was not found on the server" in response.data + ) class TestCreateHousehold: @@ -114,7 +116,9 @@ def test_update_household_success( mock_row.keys.return_value = valid_db_row.keys() mock_database.query().fetchone.return_value = mock_row - updated_household = {"people": {"person1": {"age": 31, "income": 55000}}} + updated_household = { + "people": {"person1": {"age": 31, "income": 55000}} + } updated_data = { "data": updated_household, @@ -178,7 +182,9 @@ def test_update_household_invalid_payload(self, rest_client): class TestHouseholdRouteServiceErrors: """Test handling of service-level errors in routes.""" - @patch("policyengine_api.services.household_service.HouseholdService.get_household") + @patch( + "policyengine_api.services.household_service.HouseholdService.get_household" + ) def test_get_household_service_error(self, mock_get, rest_client): """Test GET endpoint when service raises an error.""" mock_get.side_effect = Exception("Database connection failed") diff --git a/tests/to_refactor/python/test_policy_service_old.py b/tests/to_refactor/python/test_policy_service_old.py index 832816f83..a90680d80 100644 --- a/tests/to_refactor/python/test_policy_service_old.py +++ b/tests/to_refactor/python/test_policy_service_old.py @@ -30,13 +30,17 @@ def policy_service(): class TestPolicyService: - a_test_policy_id = 8 # Pre-seeded current law policies occupy IDs 1 through 5 + a_test_policy_id = ( + 8 # Pre-seeded current law policies occupy IDs 1 through 5 + ) def test_get_policy_success( self, policy_service, mock_database, sample_policy_data ): # Setup mock - mock_database.query.return_value.fetchone.return_value = sample_policy_data + mock_database.query.return_value.fetchone.return_value = ( + sample_policy_data + ) # Test result = policy_service.get_policy("us", self.a_test_policy_id) @@ -60,7 +64,9 @@ def test_get_policy_not_found(self, policy_service, mock_database): assert result is None mock_database.query.assert_called_once() - def test_get_policy_json(self, policy_service, mock_database, sample_policy_data): + def test_get_policy_json( + self, policy_service, mock_database, sample_policy_data + ): # Setup mock mock_database.query.return_value.fetchone.return_value = { "policy_json": sample_policy_data["policy_json"] @@ -125,7 +131,9 @@ def test_set_policy_existing( self, policy_service, mock_database, sample_policy_data ): # Setup mock - mock_database.query.return_value.fetchone.return_value = sample_policy_data + mock_database.query.return_value.fetchone.return_value = ( + sample_policy_data + ) # Test policy_id, message, exists = policy_service.set_policy( @@ -144,7 +152,9 @@ def test_get_unique_policy_with_label( self, policy_service, mock_database, sample_policy_data ): # Setup mock - mock_database.query.return_value.fetchone.return_value = sample_policy_data + mock_database.query.return_value.fetchone.return_value = ( + sample_policy_data + ) # Test result = policy_service._get_unique_policy_with_label( @@ -157,12 +167,16 @@ def test_get_unique_policy_with_label( assert result == sample_policy_data mock_database.query.assert_called_once() - def test_get_unique_policy_with_null_label(self, policy_service, mock_database): + def test_get_unique_policy_with_null_label( + self, policy_service, mock_database + ): # Setup mock mock_database.query.return_value.fetchone.return_value = None # Test - result = policy_service._get_unique_policy_with_label("us", "hash123", None) + result = policy_service._get_unique_policy_with_label( + "us", "hash123", None + ) # Verify assert result is None @@ -193,6 +207,8 @@ def test_error_handling(self, policy_service, mock_database, error_method): elif error_method == "set_policy": policy_service.set_policy("us", "label", {}) else: - policy_service._get_unique_policy_with_label("us", "hash", "label") + policy_service._get_unique_policy_with_label( + "us", "hash", "label" + ) assert str(exc_info.value) == "Database error" diff --git a/tests/to_refactor/python/test_simulation_analysis_routes.py b/tests/to_refactor/python/test_simulation_analysis_routes.py index f1f2ab6f1..0a4812e31 100644 --- a/tests/to_refactor/python/test_simulation_analysis_routes.py +++ b/tests/to_refactor/python/test_simulation_analysis_routes.py @@ -40,7 +40,9 @@ def test_execute_simulation_analysis_new_analysis(rest_client): ) as mock_trigger: mock_trigger.return_value = (s for s in ["New analysis"]) - response = rest_client.post("/us/simulation-analysis", json=test_json) + response = rest_client.post( + "/us/simulation-analysis", json=test_json + ) assert response.status_code == 200 assert b"New analysis" in response.data @@ -56,7 +58,9 @@ def test_execute_simulation_analysis_error(rest_client): ) as mock_trigger: mock_trigger.side_effect = Exception("Test error") - response = rest_client.post("/us/simulation-analysis", json=test_json) + response = rest_client.post( + "/us/simulation-analysis", json=test_json + ) assert response.status_code == 500 assert "Test error" in response.json.get("message") @@ -91,7 +95,9 @@ def test_execute_simulation_analysis_enhanced_cps(rest_client): with patch( "policyengine_api.services.ai_analysis_service.AIAnalysisService.trigger_ai_analysis" ) as mock_trigger: - mock_trigger.return_value = (s for s in ["Enhanced CPS analysis"]) + mock_trigger.return_value = ( + s for s in ["Enhanced CPS analysis"] + ) response = rest_client.post( "/us/simulation-analysis", json=test_json_enhanced_cps diff --git a/tests/to_refactor/python/test_tracer_analysis_routes.py b/tests/to_refactor/python/test_tracer_analysis_routes.py index 83f7bde23..f88805f8d 100644 --- a/tests/to_refactor/python/test_tracer_analysis_routes.py +++ b/tests/to_refactor/python/test_tracer_analysis_routes.py @@ -58,7 +58,8 @@ def test_execute_tracer_analysis_no_tracer(mock_db, rest_client): assert response.status_code == 404 assert ( - "No household simulation tracer found" in json.loads(response.data)["message"] + "No household simulation tracer found" + in json.loads(response.data)["message"] ) @@ -114,7 +115,9 @@ def test_invalid_variable_types(mock_db, rest_client): }, ) assert response.status_code == 400 - assert "variable must be a string" in json.loads(response.data)["message"] + assert ( + "variable must be a string" in json.loads(response.data)["message"] + ) # Test invalid country @@ -215,4 +218,7 @@ def test_validate_tracer_analysis_payload_failure(rest_client): }, ) assert response.status_code == 400 - assert "Missing required key: variable" in json.loads(response.data)["message"] + assert ( + "Missing required key: variable" + in json.loads(response.data)["message"] + ) diff --git a/tests/to_refactor/python/test_us_policy_macro.py b/tests/to_refactor/python/test_us_policy_macro.py index 9d6c20d82..03cb620d2 100644 --- a/tests/to_refactor/python/test_us_policy_macro.py +++ b/tests/to_refactor/python/test_us_policy_macro.py @@ -72,9 +72,13 @@ def utah_reform_runner(rest_client, region: str = "us"): cost = round(result["budget"]["budgetary_impact"] / 1e6, 1) assert ( cost / 95.4 - 1 - ) < 0.01, f"Expected budgetary impact to be 95.4 million, got {cost} million" + ) < 0.01, ( + f"Expected budgetary impact to be 95.4 million, got {cost} million" + ) - assert (result["intra_decile"]["all"]["Lose less than 5%"] / 0.637 - 1) < 0.01, ( + assert ( + result["intra_decile"]["all"]["Lose less than 5%"] / 0.637 - 1 + ) < 0.01, ( f"Expected 63.7% of people to lose less than 5%, got " f"{result['intra_decile']['all']['Lose less than 5%']}" ) diff --git a/tests/to_refactor/python/test_user_profile_routes.py b/tests/to_refactor/python/test_user_profile_routes.py index ec30f9eef..a3d873dbb 100644 --- a/tests/to_refactor/python/test_user_profile_routes.py +++ b/tests/to_refactor/python/test_user_profile_routes.py @@ -42,7 +42,9 @@ def test_set_and_get_record(self, rest_client): assert res.status_code == 200 assert return_object["status"] == "ok" assert return_object["result"]["auth0_id"] == self.auth0_id - assert return_object["result"]["primary_country"] == self.primary_country + assert ( + return_object["result"]["primary_country"] == self.primary_country + ) assert return_object["result"]["username"] == None user_id = return_object["result"]["user_id"] @@ -52,7 +54,9 @@ def test_set_and_get_record(self, rest_client): assert res.status_code == 200 assert return_object["status"] == "ok" - assert return_object["result"]["primary_country"] == self.primary_country + assert ( + return_object["result"]["primary_country"] == self.primary_country + ) assert return_object["result"].get("auth0_id") is None assert return_object["result"]["username"] == None @@ -73,7 +77,9 @@ def test_set_and_get_record(self, rest_client): malicious_updated_profile = {**updated_profile, "auth0_id": "BOGUS"} - res = rest_client.put("/us/user-profile", json=malicious_updated_profile) + res = rest_client.put( + "/us/user-profile", json=malicious_updated_profile + ) return_object = json.loads(res.text) assert res.status_code == 200 @@ -93,7 +99,9 @@ def test_set_and_get_record(self, rest_client): def test_non_existent_record(self, rest_client): non_existent_auth0_id = "non-existent-auth0-id" - res = rest_client.get(f"/us/user-profile?auth0_id={non_existent_auth0_id}") + res = rest_client.get( + f"/us/user-profile?auth0_id={non_existent_auth0_id}" + ) return_object = json.loads(res.text) assert res.status_code == 404 diff --git a/tests/to_refactor/python/test_validate_household_payload.py b/tests/to_refactor/python/test_validate_household_payload.py index d45363d0d..42e6a0708 100644 --- a/tests/to_refactor/python/test_validate_household_payload.py +++ b/tests/to_refactor/python/test_validate_household_payload.py @@ -14,7 +14,9 @@ class TestHouseholdRouteValidation: {"data": {}, "label": 123}, # Invalid label type ], ) - def test_post_household_invalid_payload(self, rest_client, invalid_payload): + def test_post_household_invalid_payload( + self, rest_client, invalid_payload + ): """Test POST endpoint with various invalid payloads.""" response = rest_client.post( "/us/household", @@ -38,7 +40,9 @@ def test_get_household_invalid_id(self, rest_client, invalid_id): # Default Werkzeug validation returns 404, not 400 assert response.status_code == 404 - assert b"The requested URL was not found on the server" in response.data + assert ( + b"The requested URL was not found on the server" in response.data + ) @pytest.mark.parametrize( "country_id", diff --git a/tests/to_refactor/python/test_yearly_var_removal.py b/tests/to_refactor/python/test_yearly_var_removal.py index 9e8294479..e4f463e19 100644 --- a/tests/to_refactor/python/test_yearly_var_removal.py +++ b/tests/to_refactor/python/test_yearly_var_removal.py @@ -154,14 +154,17 @@ def interface_test_household_under_policy( # Skip ignored variables if ( variable in excluded_vars - or metadata["variables"][variable]["definitionPeriod"] != "year" + or metadata["variables"][variable]["definitionPeriod"] + != "year" ): continue # Ensure that the variable exists in both # result_object and test_object if variable not in metadata["variables"]: - print(f"Failing due to variable {variable} not in metadata") + print( + f"Failing due to variable {variable} not in metadata" + ) is_test_passing = False break @@ -185,10 +188,14 @@ def interface_test_household_under_policy( results_diff = result_var_set.difference(metadata_var_set) metadata_diff = metadata_var_set.difference(result_var_set) if len(results_diff) > 0: - print("Error: The following values are only present in the result object:") + print( + "Error: The following values are only present in the result object:" + ) print(results_diff) if len(metadata_diff) > 0: - print("Error: The following values are only present in the metadata:") + print( + "Error: The following values are only present in the metadata:" + ) print(metadata_diff) is_test_passing = False @@ -200,7 +207,9 @@ def test_us_household_under_policy(): Test that a US household under current law is created correctly """ - is_test_passing = interface_test_household_under_policy("us", "2", ["members"]) + is_test_passing = interface_test_household_under_policy( + "us", "2", ["members"] + ) assert is_test_passing == True @@ -276,14 +285,17 @@ def test_get_calculate(client): # Skip ignored variables if ( variable in excluded_vars - or metadata["variables"][variable]["definitionPeriod"] != "year" + or metadata["variables"][variable]["definitionPeriod"] + != "year" ): continue # Ensure that the variable exists in both # result_object and test_object if variable not in metadata["variables"]: - print(f"Failing due to variable {variable} not in metadata") + print( + f"Failing due to variable {variable} not in metadata" + ) is_test_passing = False break @@ -307,10 +319,14 @@ def test_get_calculate(client): results_diff = result_var_set.difference(metadata_var_set) metadata_diff = metadata_var_set.difference(result_var_set) if len(results_diff) > 0: - print("Error: The following values are only present in the result object:") + print( + "Error: The following values are only present in the result object:" + ) print(results_diff) if len(metadata_diff) > 0: - print("Error: The following values are only present in the metadata:") + print( + "Error: The following values are only present in the metadata:" + ) print(metadata_diff) is_test_passing = False diff --git a/tests/unit/ai_prompts/test_simulation_analysis_prompt.py b/tests/unit/ai_prompts/test_simulation_analysis_prompt.py index 429a9ed10..05f1931e7 100644 --- a/tests/unit/ai_prompts/test_simulation_analysis_prompt.py +++ b/tests/unit/ai_prompts/test_simulation_analysis_prompt.py @@ -29,11 +29,13 @@ def test_given_valid_uk_input(self, snapshot): def test_given_dataset_is_enhanced_cps(self, snapshot): snapshot.snapshot_dir = "tests/snapshots" - valid_enhanced_cps_input_data = given_valid_data_and_dataset_is_enhanced_cps( - valid_input_us + valid_enhanced_cps_input_data = ( + given_valid_data_and_dataset_is_enhanced_cps(valid_input_us) ) - prompt = generate_simulation_analysis_prompt(valid_enhanced_cps_input_data) + prompt = generate_simulation_analysis_prompt( + valid_enhanced_cps_input_data + ) snapshot.assert_match( prompt, "simulation_analysis_prompt_dataset_enhanced_cps.txt" ) @@ -44,4 +46,6 @@ def test_given_missing_input_field(self): Exception, match="1 validation error for InboundParameters\ntime_period\n Field required", ): - generate_simulation_analysis_prompt(invalid_data_missing_input_field) + generate_simulation_analysis_prompt( + invalid_data_missing_input_field + ) diff --git a/tests/unit/data/test_congressional_districts.py b/tests/unit/data/test_congressional_districts.py index 255cfd4dc..05819916a 100644 --- a/tests/unit/data/test_congressional_districts.py +++ b/tests/unit/data/test_congressional_districts.py @@ -78,11 +78,15 @@ def test__all_state_codes_are_in_state_code_to_name(self): assert district.state_code in STATE_CODE_TO_NAME def test__california_has_52_districts(self): - ca_districts = [d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "CA"] + ca_districts = [ + d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "CA" + ] assert len(ca_districts) == 52 def test__texas_has_38_districts(self): - tx_districts = [d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "TX"] + tx_districts = [ + d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "TX" + ] assert len(tx_districts) == 38 def test__at_large_states_have_1_district(self): @@ -90,23 +94,31 @@ def test__at_large_states_have_1_district(self): at_large_states = [s for s in AT_LARGE_STATES if s != "DC"] for state_code in at_large_states: state_districts = [ - d for d in CONGRESSIONAL_DISTRICTS if d.state_code == state_code + d + for d in CONGRESSIONAL_DISTRICTS + if d.state_code == state_code ] assert len(state_districts) == 1 assert state_districts[0].number == 1 def test__dc_has_1_district(self): - dc_districts = [d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "DC"] + dc_districts = [ + d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "DC" + ] assert len(dc_districts) == 1 assert dc_districts[0].number == 1 def test__dc_comes_after_delaware(self): # Find indices de_indices = [ - i for i, d in enumerate(CONGRESSIONAL_DISTRICTS) if d.state_code == "DE" + i + for i, d in enumerate(CONGRESSIONAL_DISTRICTS) + if d.state_code == "DE" ] dc_indices = [ - i for i, d in enumerate(CONGRESSIONAL_DISTRICTS) if d.state_code == "DC" + i + for i, d in enumerate(CONGRESSIONAL_DISTRICTS) + if d.state_code == "DC" ] # DC should come after all DE districts assert min(dc_indices) > max(de_indices) @@ -132,27 +144,36 @@ def test__name_has_correct_format(self): metadata = build_congressional_district_metadata() # Check first California district ca_01 = next( - item for item in metadata if item["name"] == "congressional_district/CA-01" + item + for item in metadata + if item["name"] == "congressional_district/CA-01" ) assert ca_01 is not None def test__label_has_correct_format(self): metadata = build_congressional_district_metadata() ca_01 = next( - item for item in metadata if item["name"] == "congressional_district/CA-01" + item + for item in metadata + if item["name"] == "congressional_district/CA-01" ) assert ca_01["label"] == "California's 1st congressional district" def test__state_abbreviation_is_uppercase(self): metadata = build_congressional_district_metadata() for item in metadata: - assert item["state_abbreviation"] == item["state_abbreviation"].upper() + assert ( + item["state_abbreviation"] + == item["state_abbreviation"].upper() + ) assert len(item["state_abbreviation"]) == 2 def test__state_name_matches_abbreviation(self): metadata = build_congressional_district_metadata() ca_01 = next( - item for item in metadata if item["name"] == "congressional_district/CA-01" + item + for item in metadata + if item["name"] == "congressional_district/CA-01" ) assert ca_01["state_abbreviation"] == "CA" assert ca_01["state_name"] == "California" @@ -160,7 +181,9 @@ def test__state_name_matches_abbreviation(self): def test__dc_state_fields(self): metadata = build_congressional_district_metadata() dc_01 = next( - item for item in metadata if item["name"] == "congressional_district/DC-01" + item + for item in metadata + if item["name"] == "congressional_district/DC-01" ) assert dc_01["state_abbreviation"] == "DC" assert dc_01["state_name"] == "District of Columbia" @@ -175,25 +198,39 @@ def test__ordinal_suffixes_are_correct(self): # Find specific districts to test ordinal suffixes ca_01 = next( - item for item in metadata if item["name"] == "congressional_district/CA-01" + item + for item in metadata + if item["name"] == "congressional_district/CA-01" ) ca_02 = next( - item for item in metadata if item["name"] == "congressional_district/CA-02" + item + for item in metadata + if item["name"] == "congressional_district/CA-02" ) ca_03 = next( - item for item in metadata if item["name"] == "congressional_district/CA-03" + item + for item in metadata + if item["name"] == "congressional_district/CA-03" ) ca_11 = next( - item for item in metadata if item["name"] == "congressional_district/CA-11" + item + for item in metadata + if item["name"] == "congressional_district/CA-11" ) ca_12 = next( - item for item in metadata if item["name"] == "congressional_district/CA-12" + item + for item in metadata + if item["name"] == "congressional_district/CA-12" ) ca_21 = next( - item for item in metadata if item["name"] == "congressional_district/CA-21" + item + for item in metadata + if item["name"] == "congressional_district/CA-21" ) ca_22 = next( - item for item in metadata if item["name"] == "congressional_district/CA-22" + item + for item in metadata + if item["name"] == "congressional_district/CA-22" ) assert "1st" in ca_01["label"] @@ -208,13 +245,17 @@ def test__district_numbers_have_leading_zeros(self): metadata = build_congressional_district_metadata() # Single digit districts should have leading zero ca_01 = next( - item for item in metadata if item["name"] == "congressional_district/CA-01" + item + for item in metadata + if item["name"] == "congressional_district/CA-01" ) assert ca_01["name"] == "congressional_district/CA-01" # Double digit districts should not have leading zero ca_37 = next( - item for item in metadata if item["name"] == "congressional_district/CA-37" + item + for item in metadata + if item["name"] == "congressional_district/CA-37" ) assert ca_37["name"] == "congressional_district/CA-37" @@ -234,14 +275,18 @@ def test__at_large_states_have_at_large_label(self): def test__alaska_at_large_label(self): metadata = build_congressional_district_metadata() ak_01 = next( - item for item in metadata if item["name"] == "congressional_district/AK-01" + item + for item in metadata + if item["name"] == "congressional_district/AK-01" ) assert ak_01["label"] == "Alaska's at-large congressional district" def test__wyoming_at_large_label(self): metadata = build_congressional_district_metadata() wy_01 = next( - item for item in metadata if item["name"] == "congressional_district/WY-01" + item + for item in metadata + if item["name"] == "congressional_district/WY-01" ) assert wy_01["label"] == "Wyoming's at-large congressional district" diff --git a/tests/unit/endpoints/economy/test_compare.py b/tests/unit/endpoints/economy/test_compare.py index 8ef1eaec2..17ff66275 100644 --- a/tests/unit/endpoints/economy/test_compare.py +++ b/tests/unit/endpoints/economy/test_compare.py @@ -118,7 +118,9 @@ def test__given_non_uk_country_canada__returns_none(self): result = uk_local_authority_breakdown({}, {}, "ca") assert result is None - @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch( + "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" + ) @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_uk_country__returns_breakdown( @@ -133,7 +135,9 @@ def test__given_uk_country__returns_breakdown( # Create mock weights - 3 local authorities, 10 households mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__enter__ = MagicMock( + return_value={"2025": mock_weights} + ) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -225,7 +229,9 @@ def test__outcome_bucket_categorization_logic(self): bucket == expected_bucket ), f"Failed for {percent_change}: expected {expected_bucket}, got {bucket}" - @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch( + "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" + ) @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__outcome_buckets_are_correct( @@ -238,7 +244,9 @@ def test__outcome_buckets_are_correct( mock_weights = np.ones((1, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__enter__ = MagicMock( + return_value={"2025": mock_weights} + ) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -261,7 +269,9 @@ def test__outcome_buckets_are_correct( assert result.outcomes_by_region["uk"]["Gain more than 5%"] == 1 assert result.outcomes_by_region["uk"]["Gain less than 5%"] == 0 - @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch( + "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" + ) @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__downloads_from_correct_repos( @@ -274,7 +284,9 @@ def test__downloads_from_correct_repos( mock_weights = np.ones((1, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__enter__ = MagicMock( + return_value={"2025": mock_weights} + ) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -295,22 +307,32 @@ def test__downloads_from_correct_repos( # Verify correct repos are used calls = mock_download.call_args_list - assert calls[0][1]["repo"] == "policyengine/policyengine-uk-data-private" + assert ( + calls[0][1]["repo"] == "policyengine/policyengine-uk-data-private" + ) assert calls[0][1]["repo_filename"] == "local_authority_weights.h5" - assert calls[1][1]["repo"] == "policyengine/policyengine-uk-data-public" + assert ( + calls[1][1]["repo"] == "policyengine/policyengine-uk-data-public" + ) assert calls[1][1]["repo_filename"] == "local_authorities_2021.csv" def test__given_constituency_region__returns_none(self): """When simulating a constituency, local authority breakdown should not be computed.""" - result = uk_local_authority_breakdown({}, {}, "uk", "constituency/Aldershot") + result = uk_local_authority_breakdown( + {}, {}, "uk", "constituency/Aldershot" + ) assert result is None def test__given_constituency_region_with_code__returns_none(self): """When simulating a constituency by code, local authority breakdown should not be computed.""" - result = uk_local_authority_breakdown({}, {}, "uk", "constituency/E12345678") + result = uk_local_authority_breakdown( + {}, {}, "uk", "constituency/E12345678" + ) assert result is None - @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch( + "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" + ) @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_specific_la_region__returns_only_that_la( @@ -324,7 +346,9 @@ def test__given_specific_la_region__returns_only_that_la( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__enter__ = MagicMock( + return_value={"2025": mock_weights} + ) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -351,7 +375,9 @@ def test__given_specific_la_region__returns_only_that_la( assert "Aberdeen City" not in result.by_local_authority assert "Isle of Anglesey" not in result.by_local_authority - @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch( + "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" + ) @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_country_scotland_region__returns_only_scottish_las( @@ -365,7 +391,9 @@ def test__given_country_scotland_region__returns_only_scottish_las( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__enter__ = MagicMock( + return_value={"2025": mock_weights} + ) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -392,7 +420,9 @@ def test__given_country_scotland_region__returns_only_scottish_las( assert "Hartlepool" not in result.by_local_authority assert "Isle of Anglesey" not in result.by_local_authority - @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch( + "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" + ) @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_uk_region__returns_all_las( @@ -406,7 +436,9 @@ def test__given_uk_region__returns_all_las( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__enter__ = MagicMock( + return_value={"2025": mock_weights} + ) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -431,7 +463,9 @@ def test__given_uk_region__returns_all_las( assert "Aberdeen City" in result.by_local_authority assert "Isle of Anglesey" in result.by_local_authority - @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch( + "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" + ) @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_no_region__returns_all_las( @@ -445,7 +479,9 @@ def test__given_no_region__returns_all_las( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__enter__ = MagicMock( + return_value={"2025": mock_weights} + ) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -509,15 +545,21 @@ def test__given_non_uk_country_nigeria__returns_none(self): def test__given_local_authority_region__returns_none(self): """When simulating a local authority, constituency breakdown should not be computed.""" - result = uk_constituency_breakdown({}, {}, "uk", "local_authority/Leicester") + result = uk_constituency_breakdown( + {}, {}, "uk", "local_authority/Leicester" + ) assert result is None def test__given_local_authority_region_with_code__returns_none(self): """When simulating a local authority by code, constituency breakdown should not be computed.""" - result = uk_constituency_breakdown({}, {}, "uk", "local_authority/E06000016") + result = uk_constituency_breakdown( + {}, {}, "uk", "local_authority/E06000016" + ) assert result is None - @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch( + "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" + ) @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_specific_constituency_region__returns_only_that_constituency( @@ -532,7 +574,9 @@ def test__given_specific_constituency_region__returns_only_that_constituency( # Create mock weights - 3 constituencies, 10 households mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__enter__ = MagicMock( + return_value={"2025": mock_weights} + ) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -560,7 +604,9 @@ def test__given_specific_constituency_region__returns_only_that_constituency( assert "Edinburgh East" not in result.by_constituency assert "Cardiff South" not in result.by_constituency - @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch( + "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" + ) @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_country_scotland_region__returns_only_scottish_constituencies( @@ -574,7 +620,9 @@ def test__given_country_scotland_region__returns_only_scottish_constituencies( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__enter__ = MagicMock( + return_value={"2025": mock_weights} + ) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -591,7 +639,9 @@ def test__given_country_scotland_region__returns_only_scottish_constituencies( baseline = {"household_net_income": np.array([1000.0] * 10)} reform = {"household_net_income": np.array([1050.0] * 10)} - result = uk_constituency_breakdown(baseline, reform, "uk", "country/scotland") + result = uk_constituency_breakdown( + baseline, reform, "uk", "country/scotland" + ) assert result is not None assert len(result.by_constituency) == 1 @@ -599,7 +649,9 @@ def test__given_country_scotland_region__returns_only_scottish_constituencies( assert "Aldershot" not in result.by_constituency assert "Cardiff South" not in result.by_constituency - @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch( + "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" + ) @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_uk_region__returns_all_constituencies( @@ -613,7 +665,9 @@ def test__given_uk_region__returns_all_constituencies( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__enter__ = MagicMock( + return_value={"2025": mock_weights} + ) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -638,7 +692,9 @@ def test__given_uk_region__returns_all_constituencies( assert "Edinburgh East" in result.by_constituency assert "Cardiff South" in result.by_constituency - @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch( + "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" + ) @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_no_region__returns_all_constituencies( @@ -652,7 +708,9 @@ def test__given_no_region__returns_all_constituencies( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__enter__ = MagicMock( + return_value={"2025": mock_weights} + ) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context diff --git a/tests/unit/libs/test_simulation_api_factory.py b/tests/unit/libs/test_simulation_api_factory.py index 43d5ea339..9e243197a 100644 --- a/tests/unit/libs/test_simulation_api_factory.py +++ b/tests/unit/libs/test_simulation_api_factory.py @@ -171,7 +171,9 @@ def test__given_use_modal_env_false__then_logs_gcp_selection( # Then mock_factory_logger.log_struct.assert_called() - call_args = mock_factory_logger.log_struct.call_args[0][0] + call_args = mock_factory_logger.log_struct.call_args[ + 0 + ][0] assert "GCP" in call_args["message"] class TestGCPCredentialsError: diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index 25704e63a..4ba7d0616 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -93,7 +93,9 @@ class TestSimulationAPIModal: class TestInit: - def test__given_env_var_set__then_uses_env_url(self, mock_httpx_client): + def test__given_env_var_set__then_uses_env_url( + self, mock_httpx_client + ): # Given with patch.dict( "os.environ", @@ -105,7 +107,9 @@ def test__given_env_var_set__then_uses_env_url(self, mock_httpx_client): # Then assert api.base_url == MOCK_MODAL_BASE_URL - def test__given_env_var_not_set__then_uses_default_url(self, mock_httpx_client): + def test__given_env_var_not_set__then_uses_default_url( + self, mock_httpx_client + ): # Given with patch.dict("os.environ", {}, clear=False): import os @@ -184,7 +188,9 @@ def test__given_network_error__then_raises_exception( mock_modal_logger, ): # Given - mock_httpx_client.post.side_effect = httpx.RequestError("Connection failed") + mock_httpx_client.post.side_effect = httpx.RequestError( + "Connection failed" + ) api = SimulationAPIModal() # When/Then @@ -272,7 +278,9 @@ def test__given_job_id__then_polls_correct_endpoint( class TestGetExecutionId: - def test__given_execution__then_returns_job_id(self, mock_httpx_client): + def test__given_execution__then_returns_job_id( + self, mock_httpx_client + ): # Given api = SimulationAPIModal() execution = ModalSimulationExecution( @@ -288,7 +296,9 @@ def test__given_execution__then_returns_job_id(self, mock_httpx_client): class TestGetExecutionStatus: - def test__given_execution__then_returns_status_string(self, mock_httpx_client): + def test__given_execution__then_returns_status_string( + self, mock_httpx_client + ): # Given api = SimulationAPIModal() execution = ModalSimulationExecution( @@ -376,7 +386,9 @@ def test__given_network_error__then_returns_false( self, mock_httpx_client, mock_modal_logger ): # Given - mock_httpx_client.get.side_effect = httpx.RequestError("Connection failed") + mock_httpx_client.get.side_effect = httpx.RequestError( + "Connection failed" + ) api = SimulationAPIModal() # When diff --git a/tests/unit/services/test_ai_analysis_service.py b/tests/unit/services/test_ai_analysis_service.py index 2ff182b5c..34810cc2b 100644 --- a/tests/unit/services/test_ai_analysis_service.py +++ b/tests/unit/services/test_ai_analysis_service.py @@ -33,7 +33,8 @@ def test_trigger_ai_analysis_given_successful_streaming( for i, chunk in enumerate(results): if i < len(text_chunks): expected_chunk = ( - json.dumps({"type": "text", "stream": text_chunks[i][:5]}) + "\n" + json.dumps({"type": "text", "stream": text_chunks[i][:5]}) + + "\n" ) assert chunk == expected_chunk diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 025f490c9..1220c24b8 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -175,7 +175,9 @@ def test__given_no_previous_impact__creates_new_simulation( mock_datetime, mock_numpy_random, ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.get_all_reform_impacts.return_value = ( + [] + ) result = economy_service.get_economic_impact(**base_params) @@ -197,8 +199,8 @@ def test__given_exception__raises_error( mock_datetime, mock_numpy_random, ): - mock_reform_impacts_service.get_all_reform_impacts.side_effect = Exception( - "Database error" + mock_reform_impacts_service.get_all_reform_impacts.side_effect = ( + Exception("Database error") ) with pytest.raises(Exception) as exc_info: @@ -271,7 +273,9 @@ def test__given_existing_impacts__returns_first_impact( create_mock_reform_impact(), create_mock_reform_impact(), ] - mock_reform_impacts_service.get_all_reform_impacts.return_value = impacts + mock_reform_impacts_service.get_all_reform_impacts.return_value = ( + impacts + ) result = economy_service._get_most_recent_impact(setup_options) @@ -281,7 +285,9 @@ def test__given_no_impacts__returns_none( self, economy_service, setup_options, mock_reform_impacts_service ): # Arrange - mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.get_all_reform_impacts.return_value = ( + [] + ) # Act result = economy_service._get_most_recent_impact(setup_options) @@ -314,7 +320,9 @@ def test__given_error_status__returns_completed(self, economy_service): assert result == ImpactAction.COMPLETED - def test__given_computing_status__returns_computing(self, economy_service): + def test__given_computing_status__returns_computing( + self, economy_service + ): impact = create_mock_reform_impact(status="computing") result = economy_service._determine_impact_action(impact) @@ -410,7 +418,9 @@ def test__given_unknown_state__raises_error( economy_service._handle_execution_state( setup_options, "UNKNOWN", reform_impact ) - assert "Unexpected sim API execution state: UNKNOWN" in str(exc_info.value) + assert "Unexpected sim API execution state: UNKNOWN" in str( + exc_info.value + ) # Modal status tests def test__given_modal_complete_state__then_returns_completed_result( @@ -480,7 +490,9 @@ def test__given_modal_failed_state_with_error_message__then_includes_error_in_me # Then assert result.status == ImpactStatus.ERROR # Verify the error message was passed to the service - call_args = mock_reform_impacts_service.set_error_reform_impact.call_args + call_args = ( + mock_reform_impacts_service.set_error_reform_impact.call_args + ) assert "Simulation timed out" in call_args[1]["message"] def test__given_modal_running_state__then_returns_computing_result( @@ -620,7 +632,9 @@ class TestSetupSimOptions: """ test_country_id = "us" - test_reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) + test_reform_policy = json.dumps( + {"sample_param": {"2024-01-01.2100-12-31": 15}} + ) test_current_law_baseline_policy = json.dumps({}) test_region = "us" test_time_period = 2025 @@ -649,13 +663,16 @@ def test__given_us_nationwide__returns_correct_sim_options(self): assert sim_options["time_period"] == self.test_time_period assert sim_options["region"] == "us" assert ( - sim_options["data"] == "gs://policyengine-us-data/enhanced_cps_2024.h5" + sim_options["data"] + == "gs://policyengine-us-data/enhanced_cps_2024.h5" ) def test__given_us_state_ca__returns_correct_sim_options(self): # Test with a normalized US state (prefixed format) country_id = "us" - reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) + reform_policy = json.dumps( + {"sample_param": {"2024-01-01.2100-12-31": 15}} + ) current_law_baseline_policy = json.dumps({}) region = "state/ca" # Pre-normalized time_period = 2025 @@ -675,15 +692,21 @@ def test__given_us_state_ca__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads(current_law_baseline_policy) + assert sim_options["baseline"] == json.loads( + current_law_baseline_policy + ) assert sim_options["time_period"] == time_period assert sim_options["region"] == "state/ca" - assert sim_options["data"] == "gs://policyengine-us-data/states/CA.h5" + assert ( + sim_options["data"] == "gs://policyengine-us-data/states/CA.h5" + ) def test__given_us_state_utah__returns_correct_sim_options(self): # Test with normalized Utah state country_id = "us" - reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) + reform_policy = json.dumps( + {"sample_param": {"2024-01-01.2100-12-31": 15}} + ) current_law_baseline_policy = json.dumps({}) region = "state/ut" # Pre-normalized time_period = 2025 @@ -703,14 +726,20 @@ def test__given_us_state_utah__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads(current_law_baseline_policy) + assert sim_options["baseline"] == json.loads( + current_law_baseline_policy + ) assert sim_options["time_period"] == time_period assert sim_options["region"] == "state/ut" - assert sim_options["data"] == "gs://policyengine-us-data/states/UT.h5" + assert ( + sim_options["data"] == "gs://policyengine-us-data/states/UT.h5" + ) def test__given_cliff_target__returns_correct_sim_options(self): country_id = "us" - reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) + reform_policy = json.dumps( + {"sample_param": {"2024-01-01.2100-12-31": 15}} + ) current_law_baseline_policy = json.dumps({}) region = "us" time_period = 2025 @@ -732,17 +761,22 @@ def test__given_cliff_target__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads(current_law_baseline_policy) + assert sim_options["baseline"] == json.loads( + current_law_baseline_policy + ) assert sim_options["time_period"] == time_period assert sim_options["region"] == region assert ( - sim_options["data"] == "gs://policyengine-us-data/enhanced_cps_2024.h5" + sim_options["data"] + == "gs://policyengine-us-data/enhanced_cps_2024.h5" ) assert sim_options["include_cliffs"] is True def test__given_uk__returns_correct_sim_options(self): country_id = "uk" - reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) + reform_policy = json.dumps( + {"sample_param": {"2024-01-01.2100-12-31": 15}} + ) current_law_baseline_policy = json.dumps({}) region = "uk" time_period = 2025 @@ -771,7 +805,9 @@ def test__given_congressional_district__returns_correct_sim_options( self, ): country_id = "us" - reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) + reform_policy = json.dumps( + {"sample_param": {"2024-01-01.2100-12-31": 15}} + ) current_law_baseline_policy = json.dumps({}) region = "congressional_district/CA-37" # Pre-normalized time_period = 2025 @@ -790,7 +826,10 @@ def test__given_congressional_district__returns_correct_sim_options( sim_options = sim_options_model.model_dump() assert sim_options["region"] == "congressional_district/CA-37" - assert sim_options["data"] == "gs://policyengine-us-data/districts/CA-37.h5" + assert ( + sim_options["data"] + == "gs://policyengine-us-data/districts/CA-37.h5" + ) class TestSetupRegion: """Tests for _setup_region method. @@ -823,14 +862,18 @@ def test__given_prefixed_state_tx__returns_unchanged(self): def test__given_congressional_district__returns_unchanged(self): service = EconomyService() - result = service._setup_region("us", "congressional_district/CA-37") + result = service._setup_region( + "us", "congressional_district/CA-37" + ) assert result == "congressional_district/CA-37" def test__given_lowercase_congressional_district__returns_unchanged( self, ): service = EconomyService() - result = service._setup_region("us", "congressional_district/ca-37") + result = service._setup_region( + "us", "congressional_district/ca-37" + ) assert result == "congressional_district/ca-37" def test__given_invalid_prefixed_state__raises_value_error(self): @@ -845,13 +888,17 @@ def test__given_invalid_congressional_district__raises_value_error( service = EconomyService() with pytest.raises(ValueError) as exc_info: service._setup_region("us", "congressional_district/cruft") - assert "Invalid congressional district: 'cruft'" in str(exc_info.value) + assert "Invalid congressional district: 'cruft'" in str( + exc_info.value + ) def test__given_invalid_prefix__raises_value_error(self): service = EconomyService() with pytest.raises(ValueError) as exc_info: service._setup_region("us", "invalid_prefix/tx") - assert "Invalid US region: 'invalid_prefix/tx'" in str(exc_info.value) + assert "Invalid US region: 'invalid_prefix/tx'" in str( + exc_info.value + ) def test__given_invalid_bare_value__raises_value_error(self): # Bare values without prefix are now invalid (should be normalized first) @@ -877,7 +924,9 @@ def test__given_us_city_nyc__returns_pooled_cps(self): # Test with normalized city/nyc format service = EconomyService() result = service._setup_data("us", "city/nyc") - assert result == "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" + assert ( + result == "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" + ) def test__given_us_state_ca__returns_state_dataset(self): # Test with US state - returns state-specific dataset @@ -907,7 +956,10 @@ def test__given_uk__returns_efrs_dataset(self): # Test with UK - returns enhanced FRS dataset service = EconomyService() result = service._setup_data("uk", "uk") - assert result == "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" + assert ( + result + == "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" + ) def test__given_invalid_country__raises_value_error(self, mock_logger): # Test with invalid country @@ -951,10 +1003,14 @@ def test__given_invalid_congressional_district__raises_value_error( service = EconomyService() with pytest.raises(ValueError) as exc_info: service._validate_us_region("congressional_district/CA-99") - assert "Invalid congressional district: 'CA-99'" in str(exc_info.value) + assert "Invalid congressional district: 'CA-99'" in str( + exc_info.value + ) def test__given_nonexistent_district__raises_value_error(self): service = EconomyService() with pytest.raises(ValueError) as exc_info: service._validate_us_region("congressional_district/cruft") - assert "Invalid congressional district: 'cruft'" in str(exc_info.value) + assert "Invalid congressional district: 'cruft'" in str( + exc_info.value + ) diff --git a/tests/unit/services/test_household_service.py b/tests/unit/services/test_household_service.py index a67abfdb2..9a3ccad6d 100644 --- a/tests/unit/services/test_household_service.py +++ b/tests/unit/services/test_household_service.py @@ -27,7 +27,9 @@ def test_get_household_given_existing_record( # GIVEN an existing record... (included as fixture) # WHEN we call get_household for this record... - result = service.get_household(valid_db_row["country_id"], valid_db_row["id"]) + result = service.get_household( + valid_db_row["country_id"], valid_db_row["id"] + ) valid_household_json = valid_request_body["data"] diff --git a/tests/unit/services/test_metadata_service.py b/tests/unit/services/test_metadata_service.py index 42c266399..70ea9262e 100644 --- a/tests/unit/services/test_metadata_service.py +++ b/tests/unit/services/test_metadata_service.py @@ -127,7 +127,9 @@ def test_verify_metadata_for_given_country( ("us", ["national", "state", "city", "congressional_district"]), ], ) - def test_verify_region_types_for_given_country(self, country_id, expected_types): + def test_verify_region_types_for_given_country( + self, country_id, expected_types + ): """ Verifies that all regions for UK and US have a 'type' field with valid values. @@ -137,7 +139,9 @@ def test_verify_region_types_for_given_country(self, country_id, expected_types) regions = metadata["economy_options"]["region"] for region in regions: - assert "type" in region, f"Region '{region['name']}' missing 'type' field" + assert ( + "type" in region + ), f"Region '{region['name']}' missing 'type' field" assert ( region["type"] in expected_types ), f"Region '{region['name']}' has invalid type '{region['type']}'" diff --git a/tests/unit/services/test_policy_service.py b/tests/unit/services/test_policy_service.py index b93814fca..4530dd9d5 100644 --- a/tests/unit/services/test_policy_service.py +++ b/tests/unit/services/test_policy_service.py @@ -16,7 +16,9 @@ class TestGetPolicy: - def test_get_policy_given_existing_record(self, test_db, existing_policy_record): + def test_get_policy_given_existing_record( + self, test_db, existing_policy_record + ): # GIVEN an existing record... (included as fixture) # WHEN we call get_policy for this record... @@ -41,7 +43,9 @@ def test_get_policy_given_nonexistent_record(self, test_db): # WHEN we call get_policy for a nonexistent record NO_SUCH_RECORD_ID = 999 - result = service.get_policy(valid_policy_data["country_id"], NO_SUCH_RECORD_ID) + result = service.get_policy( + valid_policy_data["country_id"], NO_SUCH_RECORD_ID + ) # THEN the result should be None assert result is None @@ -56,7 +60,9 @@ def test_get_policy_given_str_id(self): ): # WHEN we call get_policy with the invalid ID # THEN an exception should be raised - service.get_policy(valid_policy_data["country_id"], INVALID_RECORD_ID) + service.get_policy( + valid_policy_data["country_id"], INVALID_RECORD_ID + ) def test_get_policy_given_negative_int_id(self): # GIVEN an invalid ID @@ -68,14 +74,18 @@ def test_get_policy_given_negative_int_id(self): ): # WHEN we call get_policy with the invalid ID # THEN an exception should be raised - service.get_policy(valid_policy_data["country_id"], INVALID_RECORD_ID) + service.get_policy( + valid_policy_data["country_id"], INVALID_RECORD_ID + ) def test_get_policy_given_invalid_country_id(self): # GIVEN an invalid country_id INVALID_COUNTRY_ID = "xx" # Unsupported country code # WHEN we call get_policy with the invalid country_id - result = service.get_policy(INVALID_COUNTRY_ID, valid_policy_data["id"]) + result = service.get_policy( + INVALID_COUNTRY_ID, valid_policy_data["id"] + ) # THEN the result should be None or raise an exception assert result is None @@ -226,7 +236,9 @@ def test_set_policy_existing( existing_policy = existing_policy_record # Setup mock - mock_database.query.return_value.fetchone.return_value = existing_policy + mock_database.query.return_value.fetchone.return_value = ( + existing_policy + ) # Define expected database calls - matches actual implementation expected_calls = [ @@ -265,7 +277,9 @@ def test_set_policy_given_database_insert_failure( # Setup mock to raise exception on insert mock_database.query.return_value.fetchone.side_effect = [ None, # First call: policy does not exist - Exception("Database insertion failed"), # Second call: insertion fails + Exception( + "Database insertion failed" + ), # Second call: insertion fails ] # WHEN we call set_policy @@ -286,7 +300,9 @@ def test_set_policy_given_invalid_country_id(self, mock_hash_object): # THEN an exception should be raised service.set_policy(INVALID_COUNTRY_ID, test_label, test_policy) - def test_set_policy_given_empty_label(self, mock_database, mock_hash_object): + def test_set_policy_given_empty_label( + self, mock_database, mock_hash_object + ): # GIVEN an empty label EMPTY_LABEL = "" test_policy = {"param": "value"} diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index c1f6b3e55..15f6b8576 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -13,7 +13,9 @@ class TestFindExistingReportOutput: """Test finding existing report outputs in the database.""" - def test_find_existing_report_output_found(self, test_db, existing_report_record): + def test_find_existing_report_output_found( + self, test_db, existing_report_record + ): """Test finding an existing report output.""" # GIVEN an existing report record (from fixture) @@ -27,7 +29,10 @@ def test_find_existing_report_output_found(self, test_db, existing_report_record # THEN the result should contain the existing report assert result is not None assert result["id"] == existing_report_record["id"] - assert result["simulation_1_id"] == existing_report_record["simulation_1_id"] + assert ( + result["simulation_1_id"] + == existing_report_record["simulation_1_id"] + ) assert result["status"] == existing_report_record["status"] def test_find_existing_report_output_not_found(self, test_db): @@ -243,7 +248,10 @@ def test_get_report_output_existing(self, test_db, existing_report_record): # THEN the correct report should be returned assert result is not None assert result["id"] == existing_report_record["id"] - assert result["simulation_1_id"] == existing_report_record["simulation_1_id"] + assert ( + result["simulation_1_id"] + == existing_report_record["simulation_1_id"] + ) assert result["status"] == existing_report_record["status"] def test_get_report_output_nonexistent(self, test_db): @@ -327,15 +335,21 @@ def test_duplicate_report_returns_existing(self, test_db): # THEN the same report should be returned (no duplicate created) assert first_report["id"] == second_report["id"] assert first_report["country_id"] == second_report["country_id"] - assert first_report["simulation_1_id"] == second_report["simulation_1_id"] - assert first_report["simulation_2_id"] == second_report["simulation_2_id"] + assert ( + first_report["simulation_1_id"] == second_report["simulation_1_id"] + ) + assert ( + first_report["simulation_2_id"] == second_report["simulation_2_id"] + ) assert first_report["year"] == second_report["year"] class TestUpdateReportOutput: """Test updating report outputs in the database.""" - def test_update_report_output_to_complete(self, test_db, existing_report_record): + def test_update_report_output_to_complete( + self, test_db, existing_report_record + ): """Test updating a report to complete status with output.""" # GIVEN an existing pending report report_id = existing_report_record["id"] @@ -360,7 +374,9 @@ def test_update_report_output_to_complete(self, test_db, existing_report_record) assert result["status"] == "complete" assert result["output"] == test_output_json - def test_update_report_output_to_error(self, test_db, existing_report_record): + def test_update_report_output_to_error( + self, test_db, existing_report_record + ): """Test updating a report to error status with message.""" # GIVEN an existing pending report report_id = existing_report_record["id"] @@ -384,7 +400,9 @@ def test_update_report_output_to_error(self, test_db, existing_report_record): assert result["status"] == "error" assert result["error_message"] == error_msg - def test_update_report_output_partial_update(self, test_db, existing_report_record): + def test_update_report_output_partial_update( + self, test_db, existing_report_record + ): """Test that partial updates work correctly.""" # GIVEN an existing report report_id = existing_report_record["id"] @@ -406,7 +424,9 @@ def test_update_report_output_partial_update(self, test_db, existing_report_reco assert result["status"] == "complete" assert result["output"] is None # Should remain unchanged - def test_update_report_output_no_fields(self, test_db, existing_report_record): + def test_update_report_output_no_fields( + self, test_db, existing_report_record + ): """Test that update with no optional fields still updates API version.""" # GIVEN an existing report diff --git a/tests/unit/services/test_simulation_service.py b/tests/unit/services/test_simulation_service.py index ac1fbccf6..49c8654a3 100644 --- a/tests/unit/services/test_simulation_service.py +++ b/tests/unit/services/test_simulation_service.py @@ -31,7 +31,9 @@ def test_find_existing_simulation_given_existing_record( assert result is not None assert result["id"] == existing_simulation_record["id"] assert result["country_id"] == valid_simulation_data["country_id"] - assert result["population_id"] == valid_simulation_data["population_id"] + assert ( + result["population_id"] == valid_simulation_data["population_id"] + ) assert result["policy_id"] == valid_simulation_data["policy_id"] def test_find_existing_simulation_given_no_match(self, test_db): @@ -152,7 +154,9 @@ def test_create_simulation_retrieves_correct_id(self, test_db): class TestGetSimulation: """Test retrieving simulations from the database.""" - def test_get_simulation_existing(self, test_db, existing_simulation_record): + def test_get_simulation_existing( + self, test_db, existing_simulation_record + ): """Test retrieving an existing simulation.""" # GIVEN an existing simulation record @@ -177,7 +181,9 @@ def test_get_simulation_nonexistent(self, test_db): # THEN None should be returned assert result is None - def test_get_simulation_wrong_country(self, test_db, existing_simulation_record): + def test_get_simulation_wrong_country( + self, test_db, existing_simulation_record + ): """Test that simulations are country-specific.""" # GIVEN an existing simulation for 'us' @@ -228,6 +234,11 @@ def test_duplicate_simulation_returns_existing(self, test_db): # THEN the same simulation should be returned (no duplicate created) assert first_simulation["id"] == second_simulation["id"] - assert first_simulation["country_id"] == second_simulation["country_id"] - assert first_simulation["population_id"] == second_simulation["population_id"] + assert ( + first_simulation["country_id"] == second_simulation["country_id"] + ) + assert ( + first_simulation["population_id"] + == second_simulation["population_id"] + ) assert first_simulation["policy_id"] == second_simulation["policy_id"] diff --git a/tests/unit/services/test_tracer_analysis_service.py b/tests/unit/services/test_tracer_analysis_service.py index fd1ba8364..1e87c41a6 100644 --- a/tests/unit/services/test_tracer_analysis_service.py +++ b/tests/unit/services/test_tracer_analysis_service.py @@ -78,7 +78,9 @@ def test_tracer_output_for_empty_tracer(): valid_target_variable = "snap" # When: Extracting from an empty output - result = test_service._parse_tracer_output(empty_tracer, valid_target_variable) + result = test_service._parse_tracer_output( + empty_tracer, valid_target_variable + ) # Then: It should return an empty list since there is no data to parse expected_output = empty_tracer @@ -136,7 +138,9 @@ def test_tracer_output_for_variable_that_is_substring_of_another(): target_variable = "snap_net_income" # When: Extracting the segment for this variable - result = test_service._parse_tracer_output(valid_tracer_output, target_variable) + result = test_service._parse_tracer_output( + valid_tracer_output, target_variable + ) # Then: It should return only the exact match for "snap_net_income", not "snap_net_income_fpg_ratio" diff --git a/tests/unit/services/test_tracer_service.py b/tests/unit/services/test_tracer_service.py index 84ece8df3..e5436d476 100644 --- a/tests/unit/services/test_tracer_service.py +++ b/tests/unit/services/test_tracer_service.py @@ -58,4 +58,6 @@ def test_get_tracer_database_error(test_db): valid_api_version, ] with pytest.raises(Exception): - tracer_service.get_tracer(*missing_parameter_causing_database_exception) + tracer_service.get_tracer( + *missing_parameter_causing_database_exception + ) diff --git a/tests/unit/services/test_update_profile_service.py b/tests/unit/services/test_update_profile_service.py index f240ccac0..f9fd607b7 100644 --- a/tests/unit/services/test_update_profile_service.py +++ b/tests/unit/services/test_update_profile_service.py @@ -11,7 +11,9 @@ class TestUpdateProfile: - def test_update_profile_given_existing_record(self, test_db, existing_user_profile): + def test_update_profile_given_existing_record( + self, test_db, existing_user_profile + ): # GIVEN an existing profile record (from fixture) # WHEN we call update_profile with new data @@ -52,7 +54,9 @@ def test_update_profile_given_nonexistent_record(self, test_db): # THEN the result should be False assert result is False - def test_update_profile_with_partial_fields(self, test_db, existing_user_profile): + def test_update_profile_with_partial_fields( + self, test_db, existing_user_profile + ): # GIVEN an existing profile record (from fixture) # WHEN we call update_profile with only some fields provided @@ -89,7 +93,9 @@ def test_update_profile_with_database_error( def mock_db_query_error(*args, **kwargs): raise Exception("Database error") - monkeypatch.setattr("policyengine_api.data.database.query", mock_db_query_error) + monkeypatch.setattr( + "policyengine_api.data.database.query", mock_db_query_error + ) # WHEN we call update_profile # THEN an exception should be raised diff --git a/tests/unit/services/test_user_service.py b/tests/unit/services/test_user_service.py index 49072a34a..75fe4c834 100644 --- a/tests/unit/services/test_user_service.py +++ b/tests/unit/services/test_user_service.py @@ -33,7 +33,9 @@ def test_get_profile_nonexistent_record(self): def test_get_profile_auth0_id(self, existing_user_profile): # WHEN we call get_profile with auth0_id - result = service.get_profile(auth0_id=existing_user_profile["auth0_id"]) + result = service.get_profile( + auth0_id=existing_user_profile["auth0_id"] + ) # THEN returns record assert result == existing_user_profile diff --git a/tests/unit/test_country.py b/tests/unit/test_country.py index 55a1f7c70..b57e8ceee 100644 --- a/tests/unit/test_country.py +++ b/tests/unit/test_country.py @@ -30,7 +30,9 @@ def test__uk_has_360_local_authorities(self, uk_regions): ] assert len(local_authority_regions) == 360 - def test__local_authority_regions_have_correct_name_format(self, uk_regions): + def test__local_authority_regions_have_correct_name_format( + self, uk_regions + ): """Verify local authority region names have the correct prefix.""" local_authority_regions = [ r for r in uk_regions if r.get("type") == "local_authority" @@ -119,7 +121,9 @@ def test__coordinates_are_numeric(self, local_authorities_df): assert local_authorities_df["x"].dtype in ["float64", "int64"] assert local_authorities_df["y"].dtype in ["float64", "int64"] - def test__english_local_authorities_have_e_prefix(self, local_authorities_df): + def test__english_local_authorities_have_e_prefix( + self, local_authorities_df + ): """Verify English local authorities have E prefix codes.""" english_las = local_authorities_df[ local_authorities_df["code"].str.startswith("E") @@ -127,7 +131,9 @@ def test__english_local_authorities_have_e_prefix(self, local_authorities_df): # England has 296 local authorities (majority of the 360 total) assert len(english_las) == 296 - def test__scottish_local_authorities_have_s_prefix(self, local_authorities_df): + def test__scottish_local_authorities_have_s_prefix( + self, local_authorities_df + ): """Verify Scottish local authorities have S prefix codes.""" scottish_las = local_authorities_df[ local_authorities_df["code"].str.startswith("S") @@ -135,7 +141,9 @@ def test__scottish_local_authorities_have_s_prefix(self, local_authorities_df): # Scotland has 32 council areas assert len(scottish_las) == 32 - def test__welsh_local_authorities_have_w_prefix(self, local_authorities_df): + def test__welsh_local_authorities_have_w_prefix( + self, local_authorities_df + ): """Verify Welsh local authorities have W prefix codes.""" welsh_las = local_authorities_df[ local_authorities_df["code"].str.startswith("W")