Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
- bump: patch
changes:
added:
- A method to disable downsampling within the base CPS dataset generation class
- Non-downsampled versions of the 2021, 2022, and 2023 CPS datasets
changed:
- Pooled 3-Year CPS generation uses the non-downsampled versions of the 2021, 2022, and 2023 CPS datasets
- Downsampling method attempts to preserve each variable's original dtype
86 changes: 80 additions & 6 deletions policyengine_us_data/datasets/cps/cps.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class CPS(Dataset):
raw_cps: Type[CensusCPS] = None
previous_year_raw_cps: Type[CensusCPS] = None
data_format = Dataset.ARRAYS
downsample_by_half: bool = True

def generate(self):
"""Generates the Current Population Survey dataset for PolicyEngine US microsimulations.
Expand Down Expand Up @@ -58,18 +59,55 @@ def generate(self):

# Downsample

self.downsample(fraction=0.5)
if self.downsample_by_half:
self.downsample(fraction=0.5)

# def downsample(self, fraction: float = 0.5):
# from policyengine_us import Microsimulation

# sim = Microsimulation(dataset=self)
# sim.subsample(frac=fraction)
# original_data: dict = self.load_dataset()
# for key in original_data:
# if key not in sim.tax_benefit_system.variables:
# continue
# original_data[key] = sim.calculate(key).values

# self.save_dataset(original_data)

def downsample(self, fraction: float = 0.5):
    """Subsample the saved dataset in place, keeping each variable's dtype.

    Loads the stored arrays, builds a ``Microsimulation`` on this dataset,
    subsamples it with ``frac=fraction``, recalculates every stored key that
    is a variable of the simulation's tax-benefit system, and writes the
    result back with ``save_dataset``. Keys that are not tax-benefit-system
    variables are written back unchanged.

    Args:
        fraction: Passed to ``Microsimulation.subsample`` as ``frac``.
    """
    # Kept as a function-scope import, as in the original method.
    from policyengine_us import Microsimulation

    # Load once and remember each array's dtype before anything is replaced.
    data: dict = self.load_dataset()
    original_dtypes = {key: data[key].dtype for key in data}

    sim = Microsimulation(dataset=self)
    sim.subsample(frac=fraction)
    variables = sim.tax_benefit_system.variables

    for key in data:
        if key not in variables:
            continue

        values = sim.calculate(key).values
        target_dtype = original_dtypes[key]

        # Only cast when the recalculated values expose a dtype that differs
        # from the stored one; objects without a dtype are stored as-is.
        if getattr(values, "dtype", target_dtype) != target_dtype:
            try:
                values = values.astype(target_dtype)
            except (TypeError, ValueError, OverflowError):
                # Best effort: keep the recalculated values in their own
                # dtype instead of aborting the whole downsample. A narrow
                # except lets KeyboardInterrupt/SystemExit propagate.
                print(
                    f"Warning: Could not convert {key} back to {target_dtype}"
                )

        data[key] = values

    self.save_dataset(data)

Expand Down Expand Up @@ -673,6 +711,39 @@ class CPS_2024(CPS):
url = "release://policyengine/policyengine-us-data/1.13.0/cps_2024.h5"


# The below datasets are a very naïve way of preventing downsampling in the
# Pooled 3-Year CPS. They should be replaced by a more sustainable approach.
# If these are still here on July 1, 2025, please open an issue and raise at standup.
class CPS_2021_Not_Downsampled(CPS):
    """CPS 2021 variant that opts out of the 50% downsampling step."""

    # Identification and period.
    label = "CPS 2021 (not downsampled)"
    name = "cps_2021_not_downsampled"
    time_period = 2021

    # Census source datasets for this year and the previous one.
    previous_year_raw_cps = CensusCPS_2020
    raw_cps = CensusCPS_2021

    # Stored separately from the regular (downsampled) 2021 file.
    file_path = STORAGE_FOLDER / "cps_2021_not_downsampled.h5"

    # Checked by CPS.generate before it calls downsample(fraction=0.5).
    downsample_by_half = False


class CPS_2022_Not_Downsampled(CPS):
    """CPS 2022 variant that opts out of the 50% downsampling step."""

    # Identification and period.
    label = "CPS 2022 (not downsampled)"
    name = "cps_2022_not_downsampled"
    time_period = 2022

    # Census source datasets for this year and the previous one.
    previous_year_raw_cps = CensusCPS_2021
    raw_cps = CensusCPS_2022

    # Stored separately from the regular (downsampled) 2022 file.
    file_path = STORAGE_FOLDER / "cps_2022_not_downsampled.h5"

    # Checked by CPS.generate before it calls downsample(fraction=0.5).
    downsample_by_half = False


class CPS_2023_Not_Downsampled(CPS):
    """CPS 2023 variant that opts out of the 50% downsampling step."""

    # Identification and period.
    label = "CPS 2023 (not downsampled)"
    name = "cps_2023_not_downsampled"
    time_period = 2023

    # Census source datasets for this year and the previous one.
    previous_year_raw_cps = CensusCPS_2022
    raw_cps = CensusCPS_2023

    # Stored separately from the regular (downsampled) 2023 file.
    file_path = STORAGE_FOLDER / "cps_2023_not_downsampled.h5"

    # Checked by CPS.generate before it calls downsample(fraction=0.5).
    downsample_by_half = False


class PooledCPS(Dataset):
data_format = Dataset.ARRAYS
input_datasets: list
Expand Down Expand Up @@ -724,9 +795,9 @@ class Pooled_3_Year_CPS_2023(PooledCPS):
name = "pooled_3_year_cps_2023"
file_path = STORAGE_FOLDER / "pooled_3_year_cps_2023.h5"
input_datasets = [
CPS_2021,
CPS_2022,
CPS_2023,
CPS_2021_Not_Downsampled,
CPS_2022_Not_Downsampled,
CPS_2023_Not_Downsampled,
]
time_period = 2023
url = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5"
Expand All @@ -737,4 +808,7 @@ class Pooled_3_Year_CPS_2023(PooledCPS):
CPS_2022().generate()
CPS_2023().generate()
CPS_2024().generate()
CPS_2021_Not_Downsampled().generate()
CPS_2022_Not_Downsampled().generate()
CPS_2023_Not_Downsampled().generate()
Pooled_3_Year_CPS_2023().generate()
Loading