Skip to content

Commit d5fb50e

Browse files
Increase sample size of the CPS (#325)
* Use full CPS * Changelog * Reduce prod epochs * Cut down test time * Cut test time * Remove redundant CPS test
1 parent 58f1f6a commit d5fb50e

File tree

5 files changed

+26
-59
lines changed

5 files changed

+26
-59
lines changed

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
2+
changes:
3+
fixed:
4+
- Use full CPS by default.

policyengine_us_data/datasets/cps/cps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1923,7 +1923,7 @@ class CPS_2024(CPS):
19231923
file_path = STORAGE_FOLDER / "cps_2024.h5"
19241924
time_period = 2024
19251925
url = "release://policyengine/policyengine-us-data/1.13.0/cps_2024.h5"
1926-
frac = 0.5
1926+
frac = 1
19271927

19281928

19291929
# The below datasets are a very naïve way of preventing downsampling in the

policyengine_us_data/datasets/cps/enhanced_cps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ def dropout_weights(weights, p):
6363
masked_weights[mask] = mean
6464
return masked_weights
6565

66-
optimizer = torch.optim.Adam([weights], lr=1e-1)
66+
optimizer = torch.optim.Adam([weights], lr=3e-1)
6767
from tqdm import trange
6868

6969
start_loss = None
7070

71-
iterator = trange(5_000 if not os.environ.get("TEST_LITE") else 1000)
71+
iterator = trange(500 if not os.environ.get("TEST_LITE") else 500)
7272
performance = pd.DataFrame()
7373
for i in iterator:
7474
optimizer.zero_grad()

policyengine_us_data/tests/test_datasets/test_cps.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,6 @@
22
import numpy as np
33

44

5-
@pytest.mark.parametrize("year", [2022])
6-
def test_policyengine_cps_generates(year: int):
7-
from policyengine_us_data.datasets.cps.cps import CPS_2022
8-
9-
dataset_by_year = {
10-
2022: CPS_2022,
11-
}
12-
13-
dataset_by_year[year](require=True)
14-
15-
16-
@pytest.mark.parametrize("year", [2022])
17-
def test_policyengine_cps_loads(year: int):
18-
from policyengine_us_data.datasets.cps.cps import CPS_2022
19-
20-
dataset_by_year = {
21-
2022: CPS_2022,
22-
}
23-
24-
dataset = dataset_by_year[year]
25-
26-
from policyengine_us import Microsimulation
27-
28-
sim = Microsimulation(dataset=dataset)
29-
30-
assert not sim.calculate("household_net_income").isna().any()
31-
32-
335
def test_cps_has_auto_loan_interest():
346
from policyengine_us_data.datasets.cps import CPS_2024
357
from policyengine_us import Microsimulation

policyengine_us_data/tests/test_datasets/test_enhanced_cps.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,6 @@
11
import pytest
22

33

4-
@pytest.mark.parametrize("year", [2024])
5-
def test_policyengine_cps_generates(year: int):
6-
from policyengine_us_data.datasets.cps import EnhancedCPS_2024
7-
8-
dataset_by_year = {
9-
2024: EnhancedCPS_2024,
10-
}
11-
12-
dataset_by_year[year](require=True)
13-
14-
15-
@pytest.mark.parametrize("year", [2024])
16-
def test_policyengine_cps_loads(year: int):
17-
from policyengine_us_data.datasets.cps import EnhancedCPS_2024
18-
19-
dataset_by_year = {
20-
2024: EnhancedCPS_2024,
21-
}
22-
23-
dataset = dataset_by_year[year]
24-
25-
from policyengine_us import Microsimulation
26-
27-
sim = Microsimulation(dataset=dataset)
28-
29-
assert not sim.calculate("household_net_income").isna().any()
30-
31-
324
def test_ecps_has_mortgage_interest():
335
from policyengine_us_data.datasets.cps import EnhancedCPS_2024
346
from policyengine_us import Microsimulation
@@ -50,6 +22,25 @@ def test_ecps_has_tips():
5022

5123

5224
def test_ecps_replicates_jct_tax_expenditures():
25+
import pandas as pd
26+
27+
calibration_log = pd.read_csv(
28+
"calibration_log.csv",
29+
)
30+
31+
jct_rows = calibration_log[
32+
(calibration_log["target_name"].str.contains("jct/"))
33+
& (calibration_log["epoch"] == calibration_log["epoch"].max())
34+
]
35+
36+
assert (
37+
jct_rows.rel_abs_error.max() < 0.4
38+
), "JCT tax expenditure targets not met (see the calibration log for details). Max relative error: {:.2%}".format(
39+
jct_rows.rel_abs_error.max()
40+
)
41+
42+
43+
def deprecated_test_ecps_replicates_jct_tax_expenditures_full():
5344
from policyengine_us import Microsimulation
5445
from policyengine_core.reforms import Reform
5546
from policyengine_us_data.datasets import EnhancedCPS_2024

0 commit comments

Comments
 (0)