Skip to content

Commit 35b7237

Browse files
author
Michael Smit
committed
DO NOT MERGE.
1 parent 882b614 commit 35b7237

File tree

7 files changed

+212
-73
lines changed

7 files changed

+212
-73
lines changed

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
2+
changes:
3+
fixed:
4+
- TODO

policyengine/outputs/macro/single/calculate_average_earnings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from policyengine import Simulation
1+
from policyengine.simulation_results import MacroContext
22

33

4-
def calculate_average_earnings(simulation: Simulation) -> float:
4+
def calculate_average_earnings(simulation: MacroContext) -> float:
55
"""Calculate average earnings."""
66
employment_income = simulation.baseline_simulation.calculate(
77
"employment_income"

policyengine/outputs/macro/single/calculate_single_economy.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
from typing import Literal
1414
from microdf import MicroSeries
1515

16+
from policyengine.simulation_results import (
17+
AbstractSimulationResults,
18+
MacroContext,
19+
)
20+
1621

1722
class SingleEconomy(BaseModel):
1823
total_net_income: float
@@ -78,7 +83,7 @@ class UKPrograms:
7883

7984

8085
class GeneralEconomyTask:
81-
def __init__(self, simulation: Microsimulation, country_id: str):
86+
def __init__(self, simulation: AbstractSimulationResults, country_id: str):
8287
self.simulation = simulation
8388
self.country_id = country_id
8489
self.household_count_people = self.simulation.calculate(
@@ -288,15 +293,11 @@ def calculate_labor_supply_responses(self):
288293
return result
289294

290295
def _has_behavioral_response(self) -> bool:
291-
return (
296+
return self.simulation.variable_exists(
292297
"employment_income_behavioral_response"
293-
in self.simulation.tax_benefit_system.variables
294-
and any(
295-
self.simulation.calculate(
296-
"employment_income_behavioral_response"
297-
)
298-
!= 0
299-
)
298+
) and any(
299+
self.simulation.calculate("employment_income_behavioral_response")
300+
!= 0
300301
)
301302

302303
def calculate_lsr_working_hours(self):
@@ -332,8 +333,8 @@ def calculate_uk_programs(self) -> Dict[str, float]:
332333
}
333334

334335
def calculate_cliffs(self):
335-
cliff_gap: MicroSeries = self.simulation.calculate("cliff_gap")
336-
is_on_cliff: MicroSeries = self.simulation.calculate("is_on_cliff")
336+
cliff_gap: Series = self.simulation.calculate("cliff_gap")
337+
is_on_cliff: Series = self.simulation.calculate("is_on_cliff")
337338
total_cliff_gap: float = cliff_gap.sum()
338339
total_adults: float = self.simulation.calculate("is_adult").sum()
339340
cliff_share: float = is_on_cliff.sum() / total_adults
@@ -349,15 +350,20 @@ class CliffImpactInSimulation(BaseModel):
349350

350351

351352
def calculate_single_economy(
352-
simulation: Simulation, reform: bool = False
353+
simulation: MacroContext, reform: bool = False
353354
) -> Dict:
354355
include_cliffs = simulation.options.include_cliffs
356+
country_simulation = (
357+
simulation.baseline_simulation
358+
if not reform
359+
else simulation.reform_simulation
360+
)
361+
if country_simulation is None:
362+
raise ValueError(
363+
"Simulation data is not available for the specified context."
364+
)
355365
task_manager = GeneralEconomyTask(
356-
(
357-
simulation.baseline_simulation
358-
if not reform
359-
else simulation.reform_simulation
360-
),
366+
country_simulation,
361367
simulation.options.country,
362368
)
363369
country_id = simulation.options.country
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .simulation import Simulation as Simulation
2+
from .simulation_options import SimulationOptions as SimulationOptions
Lines changed: 89 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,26 @@
11
"""Simulate tax-benefit policy and derive society-level output statistics."""
22

3+
from copy import deepcopy
34
import sys
45
from pydantic import BaseModel, Field
56
from typing import Literal
6-
from .utils.data.datasets import (
7+
8+
from .simulation_options import (
9+
CountryType,
10+
DataType,
11+
ReformType,
12+
RegionType,
13+
ScopeType,
14+
SimulationOptions,
15+
SubsampleType,
16+
TimePeriodType,
17+
)
18+
19+
from policyengine.simulation_results import (
20+
AbstractSimulationResults,
21+
MacroContext,
22+
)
23+
from policyengine.utils.data.datasets import (
724
get_default_dataset,
825
process_gs_path,
926
POLICYENGINE_DATASETS,
@@ -13,8 +30,8 @@
1330
from policyengine_core.simulations import (
1431
Microsimulation as CountryMicrosimulation,
1532
)
16-
from .utils.reforms import ParametricReform
17-
from policyengine_core.reforms import Reform as StructuralReform
33+
from policyengine.utils.reforms import ParametricReform
34+
1835
from policyengine_core.data import Dataset
1936
from policyengine_us import (
2037
Simulation as USSimulation,
@@ -37,54 +54,6 @@
3754

3855
logger = logging.getLogger(__file__)
3956

40-
CountryType = Literal["uk", "us"]
41-
ScopeType = Literal["household", "macro"]
42-
DataType = (
43-
str | dict[Any, Any] | Any | None
44-
) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason.
45-
TimePeriodType = int
46-
ReformType = ParametricReform | Type[StructuralReform] | None
47-
RegionType = Optional[str]
48-
SubsampleType = Optional[int]
49-
50-
51-
class SimulationOptions(BaseModel):
52-
country: CountryType = Field(..., description="The country to simulate.")
53-
scope: ScopeType = Field(..., description="The scope of the simulation.")
54-
data: DataType = Field(None, description="The data to simulate.")
55-
time_period: TimePeriodType = Field(
56-
2025, description="The time period to simulate."
57-
)
58-
reform: ReformType = Field(None, description="The reform to simulate.")
59-
baseline: ReformType = Field(None, description="The baseline to simulate.")
60-
region: RegionType = Field(
61-
None, description="The region to simulate within the country."
62-
)
63-
subsample: SubsampleType = Field(
64-
None,
65-
description="How many, if a subsample, households to randomly simulate.",
66-
)
67-
title: Optional[str] = Field(
68-
"[Analysis title]",
69-
description="The title of the analysis (for charts). If not provided, a default title will be generated.",
70-
)
71-
include_cliffs: Optional[bool] = Field(
72-
False,
73-
description="Whether to include tax-benefit cliffs in the simulation analyses. If True, cliffs will be included.",
74-
)
75-
model_version: Optional[str] = Field(
76-
None,
77-
description="The version of the country model used in the simulation. If not provided, the current package version will be used. If provided, this package will throw an error if the package version does not match. Use this as an extra safety check.",
78-
)
79-
data_version: Optional[str] = Field(
80-
None,
81-
description="The version of the data used in the simulation. If not provided, the current data version will be used. If provided, this package will throw an error if the data version does not match. Use this as an extra safety check.",
82-
)
83-
84-
model_config = {
85-
"arbitrary_types_allowed": True,
86-
}
87-
8857

8958
class Simulation:
9059
"""Simulate tax-benefit policy and derive society-level output statistics."""
@@ -98,9 +67,10 @@ class Simulation:
9867
data_version: Optional[str] = None
9968
"""The version of the data used in the simulation."""
10069
model_version: Optional[str] = None
70+
options: SimulationOptions
10171

102-
def __init__(self, **options: SimulationOptions):
103-
self.options = SimulationOptions(**options)
72+
def __init__(self, **kwargs):
73+
self.options = SimulationOptions.model_validate(kwargs)
10474
self.check_model_version()
10575
if not isinstance(self.options.data, dict) and not isinstance(
10676
self.options.data, Dataset
@@ -115,7 +85,8 @@ def __init__(self, **options: SimulationOptions):
11585
logging.info("Output functions loaded")
11686

11787
def _add_output_functions(self):
118-
folder = Path(__file__).parent / "outputs"
88+
logger.debug("Adding output functions to simulation")
89+
folder = Path(__file__).parent.parent / "outputs"
11990

12091
for module in folder.glob("**/*.py"):
12192
if module.stem == "__init__":
@@ -128,13 +99,18 @@ def _add_output_functions(self):
12899
)
129100
module = importlib.import_module("policyengine." + python_module)
130101
for name in dir(module):
102+
logging.debug(f"Looking for modules in {python_module}.{name}")
131103
func = getattr(module, name)
132104
if isinstance(func, Callable):
105+
logging.debug(f"Found function {name} in {python_module}")
133106
if hasattr(func, "__annotations__"):
134107
if (
135108
func.__annotations__.get("simulation")
136109
== Simulation
137110
):
111+
logging.info(
112+
f"Function {name} is an old macro function"
113+
)
138114
wrapped_func = wraps(func)(
139115
partial(func, simulation=self)
140116
)
@@ -144,6 +120,28 @@ def _add_output_functions(self):
144120
func.__name__,
145121
wrapped_func,
146122
)
123+
elif (
124+
func.__annotations__.get("simulation")
125+
== MacroContext
126+
):
127+
logging.info(
128+
f"Function {name} is a new macro function"
129+
)
130+
wrapped_func = wraps(func)(
131+
partial(
132+
func, simulation=self
133+
) # _macro_context(self))
134+
)
135+
wrapped_func.__annotations__ = func.__annotations__
136+
setattr(
137+
self,
138+
func.__name__,
139+
wrapped_func,
140+
)
141+
else:
142+
logging.debug(
143+
f"Function {name} is not a macro function, skipping"
144+
)
147145

148146
def _set_data(self, file_address: str | None = None) -> None:
149147

@@ -410,3 +408,40 @@ def _set_data_from_gs(self, file_address: str) -> tuple[str, str | None]:
410408
)
411409

412410
return filename, version
411+
412+
413+
class SimpleSimulationResults(AbstractSimulationResults):
414+
def __init__(self, simulation: CountrySimulation):
415+
self._country_simulation = simulation
416+
417+
def calculate(
418+
self,
419+
variable_name: str,
420+
period: pd.Period | None = None,
421+
map_to: str | None = None,
422+
decode_enums: bool = False,
423+
) -> pd.Series:
424+
"""
425+
Calculate a variable from the simulation results.
426+
"""
427+
return self._country_simulation.calculate(
428+
variable_name, period=period, map_to=map_to, decode_enums=decode_enums # type: ignore
429+
)
430+
431+
def variable_exists(self, variable_name: str) -> bool:
432+
return (
433+
variable_name
434+
in self._country_simulation.tax_benefit_system.variables
435+
)
436+
437+
438+
def _macro_context(simulation: Simulation):
439+
return MacroContext(
440+
simulation.options,
441+
SimpleSimulationResults(simulation.baseline_simulation),
442+
(
443+
SimpleSimulationResults(simulation.reform_simulation)
444+
if simulation.reform_simulation is not None
445+
else None
446+
),
447+
)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from typing import Any, Literal, Optional, Type
2+
3+
from pydantic import BaseModel, Field
4+
5+
from policyengine.utils.reforms import ParametricReform
6+
from policyengine_core.reforms import Reform as StructuralReform
7+
8+
9+
CountryType = Literal["uk", "us"]
10+
ScopeType = Literal["household", "macro"]
11+
DataType = (
12+
str | dict[Any, Any] | Any | None
13+
) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason.
14+
TimePeriodType = int
15+
ReformType = ParametricReform | Type[StructuralReform] | None
16+
RegionType = Optional[str]
17+
SubsampleType = Optional[int]
18+
19+
20+
class SimulationOptions(BaseModel):
21+
country: CountryType = Field(..., description="The country to simulate.")
22+
scope: ScopeType = Field(..., description="The scope of the simulation.")
23+
data: DataType = Field(None, description="The data to simulate.")
24+
time_period: TimePeriodType = Field(
25+
2025, description="The time period to simulate."
26+
)
27+
reform: ReformType = Field(None, description="The reform to simulate.")
28+
baseline: ReformType = Field(None, description="The baseline to simulate.")
29+
region: RegionType = Field(
30+
None, description="The region to simulate within the country."
31+
)
32+
subsample: SubsampleType = Field(
33+
None,
34+
description="How many, if a subsample, households to randomly simulate.",
35+
)
36+
title: Optional[str] = Field(
37+
"[Analysis title]",
38+
description="The title of the analysis (for charts). If not provided, a default title will be generated.",
39+
)
40+
include_cliffs: Optional[bool] = Field(
41+
False,
42+
description="Whether to include tax-benefit cliffs in the simulation analyses. If True, cliffs will be included.",
43+
)
44+
model_version: Optional[str] = Field(
45+
None,
46+
description="The version of the country model used in the simulation. If not provided, the current package version will be used. If provided, this package will throw an error if the package version does not match. Use this as an extra safety check.",
47+
)
48+
data_version: Optional[str] = Field(
49+
None,
50+
description="The version of the data used in the simulation. If not provided, the current data version will be used. If provided, this package will throw an error if the data version does not match. Use this as an extra safety check.",
51+
)
52+
53+
model_config = {
54+
"arbitrary_types_allowed": True,
55+
}

policyengine/simulation_results.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from abc import ABC, abstractmethod
2+
from numpy.typing import ArrayLike
3+
import pandas
4+
5+
from policyengine.simulation.simulation_options import SimulationOptions
6+
7+
8+
class AbstractSimulationResults(ABC):
9+
@abstractmethod
10+
def calculate(
11+
self,
12+
variable_name: str,
13+
period: pandas.Period | None = None,
14+
map_to: str | None = None,
15+
decode_enums: bool = False,
16+
) -> pandas.Series:
17+
pass
18+
19+
@abstractmethod
20+
def variable_exists(self, variable_name: str) -> bool:
21+
pass
22+
23+
24+
class MacroContext:
25+
options: SimulationOptions
26+
baseline_simulation: AbstractSimulationResults
27+
reform_simulation: AbstractSimulationResults | None = None
28+
29+
def __init__(
30+
self,
31+
options: SimulationOptions,
32+
baseline: AbstractSimulationResults,
33+
reform: AbstractSimulationResults | None = None,
34+
):
35+
self.options = options
36+
self.baseline_simulation = baseline
37+
self.reform_simulation = reform

0 commit comments

Comments
 (0)