Skip to content

Commit ad06f73

Browse files
committed
run black
1 parent fa0a4fc commit ad06f73

40 files changed

+3821
-2693
lines changed

demos/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
DEMOS - Demographic Micro-Simulator
33
44
A microsimulation framework for demographic and economic modeling.
5-
"""
5+
"""

demos/config.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,17 @@
1010

1111
CONFIG = None
1212

13+
1314
class HHRebalancingModuleConfig(BaseModel):
1415
"""
1516
Configuration for Household Rebalancing module
1617
"""
18+
1719
control_table: str
1820
control_col: str
1921
geoid_col: str
2022

23+
2124
class EmploymentModuleConfig(BaseModel):
2225
simultaneous_calibration_config: Optional[SimultaneousCalibrationConfig] = None
2326
enter_model_calibration_procedure: Optional[CalibrationConfig] = None
@@ -29,24 +32,31 @@ def check_calibration_config_exclusivity(self):
2932
enter_cal = self.enter_model_calibration_procedure is not None
3033
exit_cal = self.exit_model_calibration_procedure is not None
3134
if sim_cal and (enter_cal or exit_cal):
32-
raise ValueError(f"Simultaneous calibration cannot be used at the same time as " + \
33-
f"individual model calibration. Simultaneous selected: {sim_cal}, " + \
34-
f"EnterModel selected: {enter_cal}, ExitModel selected: {exit_cal}")
35+
raise ValueError(
36+
f"Simultaneous calibration cannot be used at the same time as "
37+
+ f"individual model calibration. Simultaneous selected: {sim_cal}, "
38+
+ f"EnterModel selected: {enter_cal}, ExitModel selected: {exit_cal}"
39+
)
3540
return self
3641

42+
3743
class HHReorgModuleConfig(BaseModel):
3844
simultaneous_calibration_config: Optional[SimultaneousCalibrationConfig] = None
3945
geoid_col: Optional[str] = None
4046

47+
4148
class MortalityModuleConfig(BaseModel):
4249
calibration_procedure: Optional[CalibrationConfig] = None
4350

51+
4452
class BirthModuleConfig(BaseModel):
4553
calibration_procedure: Optional[CalibrationConfig] = None
4654

55+
4756
class KidsMovingModuleConfig(BaseModel):
4857
geoid_col: str
4958

59+
5060
class AgingModuleConfig(BaseModel):
5161
#: Age at which a person qualifies as senior
5262
senior_age: int = 65
@@ -56,6 +66,7 @@ class DEMOSConfig(BaseModel):
5666
"""
5767
Global configuration for DEMOS. Individual fields in this class control the configuration of each module.
5868
"""
69+
5970
random_seed: int
6071

6172
#: Year represented in synthetic population input
@@ -82,32 +93,44 @@ class DEMOSConfig(BaseModel):
8293

8394
# Module-specific config
8495
aging_module_config: AgingModuleConfig = Field(default_factory=AgingModuleConfig)
85-
employment_module_config: EmploymentModuleConfig = Field(default_factory=EmploymentModuleConfig)
86-
hh_reorg_module_config: HHReorgModuleConfig = Field(default_factory=HHReorgModuleConfig)
87-
mortality_module_config: MortalityModuleConfig = Field(default_factory=MortalityModuleConfig)
96+
employment_module_config: EmploymentModuleConfig = Field(
97+
default_factory=EmploymentModuleConfig
98+
)
99+
hh_reorg_module_config: HHReorgModuleConfig = Field(
100+
default_factory=HHReorgModuleConfig
101+
)
102+
mortality_module_config: MortalityModuleConfig = Field(
103+
default_factory=MortalityModuleConfig
104+
)
88105
birth_module_config: BirthModuleConfig = Field(default_factory=BirthModuleConfig)
89-
hh_rebalancing_module_config: HHRebalancingModuleConfig = Field(default_factory=HHRebalancingModuleConfig)
90-
kids_moving_module_config: KidsMovingModuleConfig = Field(default_factory=KidsMovingModuleConfig)
91-
106+
hh_rebalancing_module_config: HHRebalancingModuleConfig = Field(
107+
default_factory=HHRebalancingModuleConfig
108+
)
109+
kids_moving_module_config: KidsMovingModuleConfig = Field(
110+
default_factory=KidsMovingModuleConfig
111+
)
112+
92113
def model_post_init(self, __context) -> None:
93114
if self.output_fname is None:
94-
self.output_fname = f"{self.output_dir}/demos_output_{self.forecast_year}.h5"
115+
self.output_fname = (
116+
f"{self.output_dir}/demos_output_{self.forecast_year}.h5"
117+
)
95118
os.makedirs(self.output_dir, exist_ok=True)
96119
logger.info(f"Output file set to default: {self.output_fname}")
97-
120+
98121
if self.output_tables is None:
99122
self.output_tables = []
100-
123+
101124
if self.initialize_empty_tables is None:
102125
self.initialize_empty_tables = []
103-
126+
104127
# Load all table datasources
105128
for t in self.tables:
106129
t.load_into_orca()
107130

108131
for n in self.initialize_empty_tables:
109132
orca.add_table(n, pd.DataFrame())
110-
133+
111134
if self.modules is None:
112135
self.modules = [
113136
"aging",
@@ -118,21 +141,27 @@ def model_post_init(self, __context) -> None:
118141
"birth_model",
119142
"education_model",
120143
"household_rebalancing",
121-
"update_income"
144+
"update_income",
122145
]
123-
124146

125-
@model_validator(mode='after')
147+
@model_validator(mode="after")
126148
def require_persons_and_households(self):
127149
loaded_table_names = [t.table_name for t in self.tables]
128-
if "persons" not in loaded_table_names or "households" not in loaded_table_names:
129-
raise ValueError(f"Both 'persons' and 'households' tables are required. Tables defined: {loaded_table_names}")
150+
if (
151+
"persons" not in loaded_table_names
152+
or "households" not in loaded_table_names
153+
):
154+
raise ValueError(
155+
f"Both 'persons' and 'households' tables are required. Tables defined: {loaded_table_names}"
156+
)
130157
return self
131158

159+
132160
def load_config_file(dir: str) -> DEMOSConfig:
133161
global CONFIG
134162
CONFIG = DEMOSConfig(**toml.load(dir))
135163

164+
136165
def get_config():
137166
global CONFIG
138167
if CONFIG is None:

demos/datasources.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
class CSVTableSource(BaseModel):
99
""""""
10-
file_type: Literal['csv']
10+
11+
file_type: Literal["csv"]
1112
#: Path to source file
1213
filepath: str
1314
#: Column in the file to be used as index (e.g. `person_id`)
@@ -21,24 +22,31 @@ class CSVTableSource(BaseModel):
2122

2223
def load_into_orca(self):
2324
logger.info(f"Loading CSV '{self.table_name}' table from {self.filepath}")
24-
df = pd.read_csv(self.filepath, delimiter=self.delimiter,
25-
dtype=self.custom_dtype_casting).set_index(self.index_col)
25+
df = pd.read_csv(
26+
self.filepath, delimiter=self.delimiter, dtype=self.custom_dtype_casting
27+
).set_index(self.index_col)
2628
orca.add_table(self.table_name, df)
2729

2830

2931
class H5TableSource(BaseModel):
3032
""""""
31-
file_type: Literal['h5']
33+
34+
file_type: Literal["h5"]
3235
#: Path to source file
3336
filepath: str
3437
#: key in the source HDF5 to be loaded
3538
h5_key: str
3639
#: Identifier of the table in orca
3740
table_name: str
38-
41+
3942
def load_into_orca(self):
40-
logger.info(f"Loading HDF5 '{self.table_name}' table from {self.filepath}/{self.h5_key}")
43+
logger.info(
44+
f"Loading HDF5 '{self.table_name}' table from {self.filepath}/{self.h5_key}"
45+
)
4146
df = pd.read_hdf(self.filepath, key=self.h5_key)
4247
orca.add_table(self.table_name, df)
4348

44-
DataSourceModel = Annotated[H5TableSource | CSVTableSource, Field(discriminator="file_type")]
49+
50+
DataSourceModel = Annotated[
51+
H5TableSource | CSVTableSource, Field(discriminator="file_type")
52+
]

demos/logging_logic.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,16 @@ def flush(self):
2525
logger.opt(depth=1).log(self.level, f"{self.prefix}{self._buf.rstrip()}")
2626
self._buf = ""
2727

28+
2829
def log_execution_time(start_time, year, module_name):
2930
now = time.time()
30-
run_table = orca.get_table('run_times')
31-
run_table.local = pd.concat([run_table.local,
32-
pd.DataFrame([[year, module_name, now - start_time]],
33-
columns=["year", "module", "walltime"])
34-
])
31+
run_table = orca.get_table("run_times")
32+
run_table.local = pd.concat(
33+
[
34+
run_table.local,
35+
pd.DataFrame(
36+
[[year, module_name, now - start_time]],
37+
columns=["year", "module", "walltime"],
38+
),
39+
]
40+
)

demos/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@
1010
from .rebalancing import *
1111
from .income_adjustment import *
1212
from .export import *
13-
from .main import *
13+
from .main import *

demos/models/aging.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
from config import DEMOSConfig, AgingModuleConfig, get_config
66

77
STEP_NAME = "aging"
8-
REQUIRED_COLUMNS = [
9-
"persons.age"
10-
]
8+
REQUIRED_COLUMNS = ["persons.age"]
9+
1110

1211
@orca.step(STEP_NAME)
1312
def aging(persons):
@@ -78,6 +77,7 @@ def senior(data="persons.age"):
7877

7978
return (data >= aging_config.senior_age).astype(int)
8079

80+
8181
@orca.column(table_name="persons")
8282
def age_group(data="persons.age"):
8383
"""
@@ -96,5 +96,7 @@ def age_group(data="persons.age"):
9696
Categorical age group labels as strings.
9797
"""
9898
age_intervals = [0, 20, 30, 40, 50, 65, 900]
99-
age_labels = ['lte19', '20-29', '30-39', '40-49', '50-64', 'gte65']
100-
return pd.cut(data, bins=age_intervals, labels=age_labels, include_lowest=True).astype(str)
99+
age_labels = ["lte19", "20-29", "30-39", "40-49", "50-64", "gte65"]
100+
return pd.cut(
101+
data, bins=age_intervals, labels=age_labels, include_lowest=True
102+
).astype(str)

demos/models/birth.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,23 @@
77
from logging_logic import log_execution_time
88
from config import DEMOSConfig, get_config
99

10+
1011
@orca.injectable(autocall=False)
1112
def get_new_person_id(n):
1213
persons = orca.get_table("persons")
1314
graveyard = orca.get_table("graveyard")
1415
rebalanced_persons = orca.get_table("rebalanced_persons")
1516

16-
current_max = max([persons.local.index.max(), graveyard.local.index.max(), rebalanced_persons.local.index.max()])
17+
current_max = max(
18+
[
19+
persons.local.index.max(),
20+
graveyard.local.index.max(),
21+
rebalanced_persons.local.index.max(),
22+
]
23+
)
1724
return (
18-
np.arange(n) # = [0, 1, 2 ...] up to the number of people
19-
+ current_max # = [max_person_id, max_person_id + 1, ...]
25+
np.arange(n) # = [0, 1, 2 ...] up to the number of people
26+
+ current_max # = [max_person_id, max_person_id + 1, ...]
2027
+ 1
2128
)
2229

@@ -61,19 +68,25 @@ def birth_model(persons, households, observed_births_data, get_new_person_id, ye
6168

6269
# Set race of babies
6370
# TODO: There is duplication of information between `race_id` and `race`
64-
hh_races = (persons.local.groupby("household_id")
65-
.agg(num_races=("race_id", "nunique"))
66-
.reset_index()
67-
.merge(
68-
households.to_frame(["hh_race_of_head", "hh_race_id_of_head", "household_id"])
69-
.reset_index(),
70-
on="household_id")).set_index("household_id")
71+
hh_races = (
72+
persons.local.groupby("household_id")
73+
.agg(num_races=("race_id", "nunique"))
74+
.reset_index()
75+
.merge(
76+
households.to_frame(
77+
["hh_race_of_head", "hh_race_id_of_head", "household_id"]
78+
).reset_index(),
79+
on="household_id",
80+
)
81+
).set_index("household_id")
7182
one_race_hh_filter = (hh_races.loc[babies.household_id]["num_races"] == 1).values
7283
babies["race_id"] = 9
73-
babies.loc[one_race_hh_filter, "race_id"] = hh_races.loc[babies.loc[one_race_hh_filter, "household_id"], "hh_race_id_of_head"].values
84+
babies.loc[one_race_hh_filter, "race_id"] = hh_races.loc[
85+
babies.loc[one_race_hh_filter, "household_id"], "hh_race_id_of_head"
86+
].values
7487
babies["race"] = babies["race_id"].map({1: "white", 2: "black"})
7588
babies["race"].fillna("other", inplace=True)
76-
89+
7790
# Finally add babies to persons table
7891
persons.local = pd.concat([persons.local, babies])
7992

@@ -87,13 +100,15 @@ def run_and_calibrate_birth_model(persons, households):
87100
# Load calibration config
88101
demos_config: DEMOSConfig = get_config()
89102
calibration_procedure = demos_config.birth_module_config.calibration_procedure
90-
103+
91104
# Get model data
92105
birth_model = mm.get_step("birth")
93106
birth_model_variables = columns_in_formula(birth_model.model_expression)
94107
birth_model_data = households.to_frame(birth_model_variables).loc[ELIGIBLE_HH]
95-
108+
96109
# Calibrate if needed
97110
if calibration_procedure is not None:
98-
return calibration_procedure.calibrate_and_run_model(birth_model, birth_model_data)
111+
return calibration_procedure.calibrate_and_run_model(
112+
birth_model, birth_model_data
113+
)
99114
return birth_model.predict(birth_model_data)

0 commit comments

Comments
 (0)