Skip to content

Commit a064863

Browse files
Run lite versions of imputations and calibrations in tests (#251)
* Run lite versions of imputations and calibrations in tests Fixes #250 * Downsample SCF data too * Add more logging * Don't tune hyperparameters on PR tests * Use correct env var * Increase epochs * Impute all extended CPS variables
1 parent 42a4802 commit a064863

File tree

7 files changed

+44
-12
lines changed

7 files changed

+44
-12
lines changed

.github/workflows/pr_code_changes.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ jobs:
4242
- name: Build datasets
4343
run: make data
4444
env:
45-
LITE_MODE: true
45+
TEST_LITE: true
4646
- name: Run tests
4747
run: pytest
4848
- name: Test documentation builds

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
2+
changes:
3+
fixed:
4+
- Runtime for tests reduced.

policyengine_us_data/datasets/cps/cps.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from policyengine_us_data.utils import QRF
1515
import logging
1616

17+
test_lite = os.environ.get("TEST_LITE")
18+
1719

1820
class CPS(Dataset):
1921
name = "cps"
@@ -49,21 +51,33 @@ def generate(self):
4951
raw_data[entity] for entity in ENTITIES
5052
]
5153

54+
logging.info("Adding ID variables")
5255
add_id_variables(cps, person, tax_unit, family, spm_unit, household)
56+
logging.info("Adding personal variables")
5357
add_personal_variables(cps, person)
58+
logging.info("Adding personal income variables")
5459
add_personal_income_variables(cps, person, self.raw_cps.time_period)
60+
logging.info("Adding previous year income variables")
5561
add_previous_year_income(self, cps)
62+
logging.info("Adding SSN card type")
5663
add_ssn_card_type(cps, person)
64+
logging.info("Adding family variables")
5765
add_spm_variables(cps, spm_unit)
66+
logging.info("Adding household variables")
5867
add_household_variables(cps, household)
68+
logging.info("Adding rent")
5969
add_rent(self, cps, person, household)
70+
logging.info("Adding auto loan balance")
6071
add_auto_loan_balance(self, cps)
72+
logging.info("Adding tips")
6173
add_tips(self, cps)
74+
logging.info("Added all variables")
6275

6376
raw_data.close()
6477
self.save_dataset(cps)
65-
78+
logging.info("Adding takeup")
6679
add_takeup(self)
80+
logging.info("Downsampling")
6781

6882
# Downsample
6983
if self.frac is not None and self.frac < 1.0:
@@ -146,7 +160,9 @@ def add_rent(self, cps: h5py.File, person: DataFrame, household: DataFrame):
146160
},
147161
na_action="ignore",
148162
).fillna(train_df.tenure_type)
149-
train_df = train_df[train_df.is_household_head].sample(100_000)
163+
train_df = train_df[train_df.is_household_head].sample(
164+
100_000 if not test_lite else 1_000
165+
)
150166
inference_df = cps_sim.calculate_dataframe(PREDICTORS)
151167
mask = inference_df.is_household_head.values
152168
inference_df = inference_df[mask]
@@ -290,7 +306,7 @@ def add_auto_loan_balance(self, cps: h5py.File) -> None:
290306
donor_data = donor_data.loc[
291307
np.random.choice(
292308
donor_data.index,
293-
size=100_000,
309+
size=100_000 if not test_lite else 1_000,
294310
replace=True,
295311
p=donor_data.household_weight / donor_data.household_weight.sum(),
296312
)
@@ -303,7 +319,7 @@ def add_auto_loan_balance(self, cps: h5py.File) -> None:
303319
X_train=donor_data,
304320
predictors=PREDICTORS,
305321
imputed_variables=IMPUTED_VARIABLES,
306-
tune_hyperparameters=True,
322+
tune_hyperparameters=not test_lite,
307323
)
308324

309325
imputations = fitted_model.predict(X_test=receiver_data)

policyengine_us_data/datasets/cps/enhanced_cps.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
CPS_2019,
1515
CPS_2024,
1616
)
17+
import os
1718

1819
try:
1920
import torch
@@ -66,7 +67,7 @@ def dropout_weights(weights, p):
6667

6768
start_loss = None
6869

69-
iterator = trange(5_000)
70+
iterator = trange(5_000 if not os.environ.get("TEST_LITE") else 1_000)
7071
for i in iterator:
7172
optimizer.zero_grad()
7273
weights_ = dropout_weights(weights, dropout_rate)
@@ -88,6 +89,9 @@ def train_previous_year_income_model():
8889

8990
sim = Microsimulation(dataset=CPS_2019)
9091

92+
if os.environ.get("TEST_LITE"):
93+
sim.subsample(1_000)
94+
9195
VARIABLES = [
9296
"previous_year_income_available",
9397
"employment_income",

policyengine_us_data/datasets/cps/extended_cps.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,6 @@
7171
"deductible_mortgage_interest",
7272
]
7373

74-
if os.environ.get("TEST_LITE"):
75-
IMPUTED_VARIABLES = IMPUTED_VARIABLES[:7]
76-
7774

7875
class ExtendedCPS(Dataset):
7976
cps: Type[CPS]
@@ -86,8 +83,8 @@ def generate(self):
8683
cps_sim = Microsimulation(dataset=self.cps)
8784
puf_sim = Microsimulation(dataset=self.puf)
8885

89-
if os.environ.get("LITE_MODE"):
90-
puf_sim.subsample(10_000)
86+
if os.environ.get("TEST_LITE"):
87+
puf_sim.subsample(1_000)
9188

9289
INPUTS = [
9390
"age",

policyengine_us_data/datasets/puf/puf.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from policyengine_us_data.utils.uprating import (
1010
create_policyengine_uprating_factors_table,
1111
)
12+
import os
1213

1314
rng = np.random.default_rng(seed=64)
1415

@@ -18,6 +19,8 @@ def impute_pension_contributions_to_puf(puf_df):
1819
from policyengine_us_data.datasets.cps import CPS_2021
1920

2021
cps = Microsimulation(dataset=CPS_2021)
22+
if os.environ.get("TEST_LITE"):
23+
cps.subsample(1_000)
2124
cps_df = cps.calculate_dataframe(
2225
["employment_income", "household_weight", "pre_tax_contributions"]
2326
)
@@ -46,6 +49,11 @@ def impute_missing_demographics(
4649
.fillna(0)
4750
)
4851

52+
if os.environ.get("TEST_LITE"):
53+
puf_with_demographics = puf_with_demographics.sample(
54+
n=1_000, random_state=0
55+
)
56+
4957
DEMOGRAPHIC_VARIABLES = [
5058
"AGEDP1",
5159
"AGEDP2",

policyengine_us_data/datasets/sipp/sipp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from policyengine_us_data.storage import STORAGE_FOLDER
77
import pickle
88
from huggingface_hub import hf_hub_download
9+
import os
10+
11+
test_lite = os.environ.get("TEST_LITE")
912

1013

1114
def train_tip_model():
@@ -100,7 +103,7 @@ def train_tip_model():
100103
sipp = sipp.loc[
101104
np.random.choice(
102105
sipp.index,
103-
size=100_000,
106+
size=100_000 if not test_lite else 1_000,
104107
replace=True,
105108
p=sipp.household_weight / sipp.household_weight.sum(),
106109
)

0 commit comments

Comments
 (0)