diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29..1a814ec 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + changed: + - Filter nationwide subnational results when user is running sim at subnational level \ No newline at end of file diff --git a/policyengine/outputs/macro/comparison/calculate_economy_comparison.py b/policyengine/outputs/macro/comparison/calculate_economy_comparison.py index dd0123a..68be458 100644 --- a/policyengine/outputs/macro/comparison/calculate_economy_comparison.py +++ b/policyengine/outputs/macro/comparison/calculate_economy_comparison.py @@ -3,6 +3,10 @@ from microdf import MicroSeries import numpy as np from policyengine.utils.data_download import download +from policyengine.utils.uk_geography import ( + should_zero_constituency, + should_zero_local_authority, +) import pandas as pd import h5py from pydantic import BaseModel @@ -692,7 +696,10 @@ class UKConstituencyBreakdownWithValues(BaseModel): def uk_constituency_breakdown( - baseline: SingleEconomy, reform: SingleEconomy, country_id: str + baseline: SingleEconomy, + reform: SingleEconomy, + country_id: str, + region: str | None = None, ) -> UKConstituencyBreakdown: if country_id != "uk": return None @@ -701,8 +708,8 @@ def uk_constituency_breakdown( "by_constituency": {}, "outcomes_by_region": {}, } - for region in ["uk", "england", "scotland", "wales", "northern_ireland"]: - output["outcomes_by_region"][region] = { + for region_ in ["uk", "england", "scotland", "wales", "northern_ireland"]: + output["outcomes_by_region"][region_] = { "Gain more than 5%": 0, "Gain less than 5%": 0, "No change": 0, @@ -732,15 +739,21 @@ def uk_constituency_breakdown( for i in range(len(constituency_names)): name: str = constituency_names.iloc[i]["name"] code: str = constituency_names.iloc[i]["code"] - weight: np.ndarray = weights[i] - baseline_income = MicroSeries(baseline_hnet, weights=weight) - reform_income = MicroSeries(reform_hnet, weights=weight) - average_household_income_change: float = ( - reform_income.sum() - baseline_income.sum() - ) / baseline_income.count() - percent_household_income_change: float = ( - reform_income.sum() / baseline_income.sum() - 1 - ) + + if should_zero_constituency(region, code, name): + average_household_income_change = 0.0 + percent_household_income_change = 0.0 + else: + weight: np.ndarray = weights[i] + baseline_income = MicroSeries(baseline_hnet, weights=weight) + reform_income = MicroSeries(reform_hnet, weights=weight) + average_household_income_change = ( + reform_income.sum() - baseline_income.sum() + ) / baseline_income.count() + percent_household_income_change = ( + reform_income.sum() / baseline_income.sum() - 1 + ) + output["by_constituency"][name] = { "average_household_income_change": average_household_income_change, "relative_household_income_change": percent_household_income_change, @@ -748,29 +761,31 @@ def uk_constituency_breakdown( "y": int(constituency_names.iloc[i]["y"]), } - regions = ["uk"] - if "E" in code: - regions.append("england") - elif "S" in code: - regions.append("scotland") - elif "W" in code: - regions.append("wales") - elif "N" in code: - regions.append("northern_ireland") - - if percent_household_income_change > 0.05: - bucket = "Gain more than 5%" - elif percent_household_income_change > 1e-3: - bucket = "Gain less than 5%" - elif percent_household_income_change > -1e-3: - bucket = "No change" - elif percent_household_income_change > -0.05: - bucket = "Lose less than 5%" - else: - bucket = "Lose more than 5%" + # Only count non-zeroed constituencies in outcomes_by_region + if not should_zero_constituency(region, code, name): + regions = ["uk"] + if "E" in code: + regions.append("england") + elif "S" in code: + regions.append("scotland") + elif "W" in code: + regions.append("wales") + elif "N" in code: + regions.append("northern_ireland") + + if percent_household_income_change > 0.05: + bucket = "Gain more than 5%" + elif percent_household_income_change > 1e-3: + bucket = "Gain less than 5%" + elif percent_household_income_change > -1e-3: + bucket = "No change" + elif percent_household_income_change > -0.05: + bucket = "Lose less than 5%" + else: + bucket = "Lose more than 5%" - for region_ in regions: - output["outcomes_by_region"][region_][bucket] += 1 + for region_ in regions: + output["outcomes_by_region"][region_][bucket] += 1 return UKConstituencyBreakdownWithValues(**output) @@ -790,7 +805,10 @@ class UKLocalAuthorityBreakdownWithValues(BaseModel): def uk_local_authority_breakdown( - baseline: SingleEconomy, reform: SingleEconomy, country_id: str + baseline: SingleEconomy, + reform: SingleEconomy, + country_id: str, + region: str | None = None, ) -> UKLocalAuthorityBreakdown: if country_id != "uk": return None @@ -822,15 +840,21 @@ def uk_local_authority_breakdown( for i in range(len(local_authority_names)): name: str = local_authority_names.iloc[i]["name"] code: str = local_authority_names.iloc[i]["code"] - weight: np.ndarray = weights[i] - baseline_income = MicroSeries(baseline_hnet, weights=weight) - reform_income = MicroSeries(reform_hnet, weights=weight) - average_household_income_change: float = ( - reform_income.sum() - baseline_income.sum() - ) / baseline_income.count() - percent_household_income_change: float = ( - reform_income.sum() / baseline_income.sum() - 1 - ) + + if should_zero_local_authority(region, code, name): + average_household_income_change = 0.0 + percent_household_income_change = 0.0 + else: + weight: np.ndarray = weights[i] + baseline_income = MicroSeries(baseline_hnet, weights=weight) + reform_income = MicroSeries(reform_hnet, weights=weight) + average_household_income_change = ( + reform_income.sum() - baseline_income.sum() + ) / baseline_income.count() + percent_household_income_change = ( + reform_income.sum() / baseline_income.sum() - 1 + ) + output["by_local_authority"][name] = { "average_household_income_change": average_household_income_change, "relative_household_income_change": percent_household_income_change, @@ -840,8 +864,6 @@ def uk_local_authority_breakdown( "y": int(local_authority_names.iloc[i]["y"]), } - # Note: Country-level aggregation and bucketing logic removed for local authorities - return UKLocalAuthorityBreakdownWithValues(**output) @@ -901,10 +923,12 @@ def calculate_economy_comparison( intra_decile_impact_data = intra_decile_impact(baseline, reform) labor_supply_response_data = labor_supply_response(baseline, reform) constituency_impact_data: UKConstituencyBreakdown = ( - uk_constituency_breakdown(baseline, reform, country_id) + uk_constituency_breakdown(baseline, reform, country_id, options.region) ) local_authority_impact_data: UKLocalAuthorityBreakdown = ( - uk_local_authority_breakdown(baseline, reform, country_id) + uk_local_authority_breakdown( + baseline, reform, country_id, options.region + ) ) wealth_decile_impact_data = wealth_decile_impact( baseline, reform, country_id diff --git a/policyengine/utils/uk_geography.py b/policyengine/utils/uk_geography.py new file mode 100644 index 0000000..caff9f2 --- /dev/null +++ b/policyengine/utils/uk_geography.py @@ -0,0 +1,90 @@ +"""Utilities for UK geographic region filtering.""" + +from typing import Literal + +UKRegionType = Literal["uk", "country", "constituency", "local_authority"] + +UK_REGION_TYPES: tuple[UKRegionType, ...] = ( + "uk", + "country", + "constituency", + "local_authority", +) + + +def determine_uk_region_type(region: str | None) -> UKRegionType: + """ + Determine the type of UK region from a region string. + + Args: + region: A region string (e.g., "country/scotland", "constituency/Aberdeen North", + "local_authority/leicester") or None. + + Returns: + One of "uk", "country", "constituency", or "local_authority". + + Raises: + ValueError: If the region prefix is not a valid UK region type. + """ + if region is None: + return "uk" + + prefix = region.split("/")[0] + if prefix not in UK_REGION_TYPES: + raise ValueError( + f"Invalid UK region type: '{prefix}'. " + f"Expected one of: {list(UK_REGION_TYPES)}" + ) + + return prefix + + +def get_country_from_code(code: str) -> str | None: + """Get country name from geographic code prefix (E, S, W, N).""" + prefix_map = { + "E": "england", + "S": "scotland", + "W": "wales", + "N": "northern_ireland", + } + return prefix_map.get(code[0]) + + +def should_zero_constituency(region: str | None, code: str, name: str) -> bool: + """Return True if this constituency's impacts should be zeroed out.""" + region_type = determine_uk_region_type(region) + + if region_type == "uk": + return False + # region is guaranteed to be non-None for non-uk region types + assert region is not None + if region_type == "country": + target = region.split("/")[1] + return get_country_from_code(code) != target + if region_type == "constituency": + target = region.split("/")[1] + return code != target and name != target + if region_type == "local_authority": + return True + return False + + +def should_zero_local_authority( + region: str | None, code: str, name: str +) -> bool: + """Return True if this local authority's impacts should be zeroed out.""" + region_type = determine_uk_region_type(region) + + if region_type == "uk": + return False + # region is guaranteed to be non-None for non-uk region types + assert region is not None + if region_type == "country": + target = region.split("/")[1] + return get_country_from_code(code) != target + if region_type == "local_authority": + target = region.split("/")[1] + return code != target and name != target + if region_type == "constituency": + return True + return False diff --git a/tests/utils/test_uk_geography.py b/tests/utils/test_uk_geography.py new file mode 100644 index 0000000..5820dd1 --- /dev/null +++ b/tests/utils/test_uk_geography.py @@ -0,0 +1,333 @@ +"""Tests for UK geographic region filtering utilities.""" + +import pytest + +from policyengine.utils.uk_geography import ( + UK_REGION_TYPES, + UKRegionType, + determine_uk_region_type, + get_country_from_code, + should_zero_constituency, + should_zero_local_authority, +) + + +class TestUKRegionTypes: + """Tests for UK_REGION_TYPES constant.""" + + def test_contains_expected_types(self): + """Test that UK_REGION_TYPES contains all expected region types.""" + assert "uk" in UK_REGION_TYPES + assert "country" in UK_REGION_TYPES + assert "constituency" in UK_REGION_TYPES + assert "local_authority" in UK_REGION_TYPES + + def test_has_four_types(self): + """Test that UK_REGION_TYPES has exactly four types.""" + assert len(UK_REGION_TYPES) == 4 + + +class TestDetermineUKRegionType: + """Tests for determine_uk_region_type function.""" + + def test_none_returns_uk(self): + """Test that None region returns 'uk' type.""" + result = determine_uk_region_type(None) + assert result == "uk" + + def test_country_region(self): + """Test parsing country region strings.""" + assert determine_uk_region_type("country/scotland") == "country" + assert determine_uk_region_type("country/england") == "country" + assert determine_uk_region_type("country/wales") == "country" + assert ( + determine_uk_region_type("country/northern_ireland") == "country" + ) + + def test_constituency_region(self): + """Test parsing constituency region strings.""" + assert ( + determine_uk_region_type("constituency/Aberdeen North") + == "constituency" + ) + assert ( + determine_uk_region_type("constituency/S14000001") + == "constituency" + ) + + def test_local_authority_region(self): + """Test parsing local authority region strings.""" + assert ( + determine_uk_region_type("local_authority/leicester") + == "local_authority" + ) + assert ( + determine_uk_region_type("local_authority/E06000016") + == "local_authority" + ) + + def test_invalid_region_raises_error(self): + """Test that invalid region prefixes raise ValueError.""" + with pytest.raises(ValueError) as exc_info: + determine_uk_region_type("invalid/test") + assert "Invalid UK region type: 'invalid'" in str(exc_info.value) + assert "Expected one of:" in str(exc_info.value) + + def test_empty_prefix_raises_error(self): + """Test that empty prefix raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + determine_uk_region_type("/test") + assert "Invalid UK region type: ''" in str(exc_info.value) + + def test_unknown_prefix_raises_error(self): + """Test that unknown prefixes raise ValueError.""" + with pytest.raises(ValueError): + determine_uk_region_type("state/california") + with pytest.raises(ValueError): + determine_uk_region_type("region/north_west") + + +class TestGetCountryFromCode: + """Tests for get_country_from_code function.""" + + def test_english_code(self): + """Test that E prefix returns england.""" + assert get_country_from_code("E14000001") == "england" + assert get_country_from_code("E06000016") == "england" + + def test_scottish_code(self): + """Test that S prefix returns scotland.""" + assert get_country_from_code("S14000001") == "scotland" + assert get_country_from_code("S12000033") == "scotland" + + def test_welsh_code(self): + """Test that W prefix returns wales.""" + assert get_country_from_code("W07000041") == "wales" + + def test_northern_ireland_code(self): + """Test that N prefix returns northern_ireland.""" + assert get_country_from_code("N06000001") == "northern_ireland" + + def test_unknown_code_returns_none(self): + """Test that unknown code prefixes return None.""" + assert get_country_from_code("X12345678") is None + assert get_country_from_code("123456789") is None + + +class TestShouldZeroConstituency: + """Tests for should_zero_constituency function.""" + + def test_uk_wide_never_zeros(self): + """Test that UK-wide (None) region never zeros any constituency.""" + assert ( + should_zero_constituency(None, "S14000001", "Aberdeen North") + is False + ) + assert ( + should_zero_constituency(None, "E14000001", "Some English") + is False + ) + assert ( + should_zero_constituency(None, "W07000001", "Some Welsh") is False + ) + + def test_country_filter_keeps_matching_country(self): + """Test that country filter keeps constituencies in that country.""" + # Scottish constituency in scotland filter -> keep + assert ( + should_zero_constituency( + "country/scotland", "S14000001", "Aberdeen North" + ) + is False + ) + # English constituency in england filter -> keep + assert ( + should_zero_constituency( + "country/england", "E14000001", "Some English" + ) + is False + ) + + def test_country_filter_zeros_other_countries(self): + """Test that country filter zeros constituencies in other countries.""" + # English constituency in scotland filter -> zero + assert ( + should_zero_constituency( + "country/scotland", "E14000001", "Some English" + ) + is True + ) + # Scottish constituency in england filter -> zero + assert ( + should_zero_constituency( + "country/england", "S14000001", "Aberdeen North" + ) + is True + ) + # Welsh constituency in scotland filter -> zero + assert ( + should_zero_constituency( + "country/scotland", "W07000001", "Some Welsh" + ) + is True + ) + + def test_constituency_filter_keeps_matching_by_code(self): + """Test that constituency filter keeps matching constituency by code.""" + assert ( + should_zero_constituency( + "constituency/S14000001", "S14000001", "Aberdeen North" + ) + is False + ) + + def test_constituency_filter_keeps_matching_by_name(self): + """Test that constituency filter keeps matching constituency by name.""" + assert ( + should_zero_constituency( + "constituency/Aberdeen North", "S14000001", "Aberdeen North" + ) + is False + ) + + def test_constituency_filter_zeros_non_matching(self): + """Test that constituency filter zeros non-matching constituencies.""" + assert ( + should_zero_constituency( + "constituency/Aberdeen North", "S14000002", "Aberdeen South" + ) + is True + ) + assert ( + should_zero_constituency( + "constituency/S14000001", "E14000001", "Some English" + ) + is True + ) + + def test_local_authority_filter_zeros_all_constituencies(self): + """Test that local authority filter zeros all constituencies.""" + assert ( + should_zero_constituency( + "local_authority/leicester", "S14000001", "Aberdeen North" + ) + is True + ) + assert ( + should_zero_constituency( + "local_authority/leicester", "E14000001", "Some English" + ) + is True + ) + assert ( + should_zero_constituency( + "local_authority/E06000016", "W07000001", "Some Welsh" + ) + is True + ) + + +class TestShouldZeroLocalAuthority: + """Tests for should_zero_local_authority function.""" + + def test_uk_wide_never_zeros(self): + """Test that UK-wide (None) region never zeros any local authority.""" + assert ( + should_zero_local_authority(None, "E06000016", "Leicester") + is False + ) + assert ( + should_zero_local_authority(None, "S12000033", "Aberdeen") is False + ) + assert ( + should_zero_local_authority(None, "W06000001", "Some Welsh LA") + is False + ) + + def test_country_filter_keeps_matching_country(self): + """Test that country filter keeps local authorities in that country.""" + # English LA in england filter -> keep + assert ( + should_zero_local_authority( + "country/england", "E06000016", "Leicester" + ) + is False + ) + # Scottish LA in scotland filter -> keep + assert ( + should_zero_local_authority( + "country/scotland", "S12000033", "Aberdeen" + ) + is False + ) + + def test_country_filter_zeros_other_countries(self): + """Test that country filter zeros local authorities in other countries.""" + # Scottish LA in england filter -> zero + assert ( + should_zero_local_authority( + "country/england", "S12000033", "Aberdeen" + ) + is True + ) + # English LA in scotland filter -> zero + assert ( + should_zero_local_authority( + "country/scotland", "E06000016", "Leicester" + ) + is True + ) + + def test_local_authority_filter_keeps_matching_by_code(self): + """Test that local authority filter keeps matching LA by code.""" + assert ( + should_zero_local_authority( + "local_authority/E06000016", "E06000016", "Leicester" + ) + is False + ) + + def test_local_authority_filter_keeps_matching_by_name(self): + """Test that local authority filter keeps matching LA by name.""" + assert ( + should_zero_local_authority( + "local_authority/Leicester", "E06000016", "Leicester" + ) + is False + ) + + def test_local_authority_filter_zeros_non_matching(self): + """Test that local authority filter zeros non-matching local authorities.""" + assert ( + should_zero_local_authority( + "local_authority/Leicester", "E06000017", "Rutland" + ) + is True + ) + assert ( + should_zero_local_authority( + "local_authority/E06000016", "S12000033", "Aberdeen" + ) + is True + ) + + def test_constituency_filter_zeros_all_local_authorities(self): + """Test that constituency filter zeros all local authorities.""" + assert ( + should_zero_local_authority( + "constituency/Aberdeen North", "E06000016", "Leicester" + ) + is True + ) + assert ( + should_zero_local_authority( + "constituency/Aberdeen North", "S12000033", "Aberdeen" + ) + is True + ) + assert ( + should_zero_local_authority( + "constituency/S14000001", "W06000001", "Some Welsh LA" + ) + is True + )