Skip to content

Commit b18680d

Browse files
committed
adding matrix builder improvements
1 parent 89ce2c8 commit b18680d

File tree

3 files changed

+220
-30
lines changed

3 files changed

+220
-30
lines changed

policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -192,15 +192,21 @@ def get_calculated_variables(sim) -> List[str]:
192192
"""
193193
Return variables that should be cleared for state-swap recalculation.
194194
195-
Includes variables with formulas, adds, or subtracts.
196-
197-
Excludes ID variables (person_id, household_id, etc.) because:
198-
1. They have formulas that generate sequential IDs (0, 1, 2, ...)
199-
2. We need the original H5 values, not regenerated sequences
200-
3. PolicyEngine's random() function uses entity IDs as seeds:
201-
seed = abs(entity_id * 100 + count_random_calls)
202-
If IDs change, random-dependent variables (SSI resource test,
203-
WIC nutritional risk, WIC takeup) produce different results.
195+
Includes variables with formulas, or adds/subtracts that are lists.
196+
197+
Excludes:
198+
1. ID variables (person_id, household_id, etc.) - needed for random seeds
199+
2. Variables with string adds/subtracts (parameter paths) - these are
200+
pseudo-inputs stored in H5 that would recalculate differently using
201+
parameter lookups. Examples: pre_tax_contributions.
202+
3. Variables in input_variables (have stored H5 values) even if they
203+
have formulas - the stored values represent original survey data
204+
that should be preserved. Examples: cdcc_relevant_expenses, rent.
205+
206+
The exclusions are critical because:
207+
- The H5 file stores pre-computed values from original CPS processing
208+
- If deleted, recalculation produces different values, corrupting
209+
downstream calculations like income_tax
204210
"""
205211
exclude_ids = {
206212
"person_id",
@@ -210,16 +216,36 @@ def get_calculated_variables(sim) -> List[str]:
210216
"family_id",
211217
"marital_unit_id",
212218
}
213-
return [
214-
name
215-
for name, var in sim.tax_benefit_system.variables.items()
216-
if (
217-
var.formulas
218-
or getattr(var, "adds", None)
219-
or getattr(var, "subtracts", None)
220-
)
221-
and name not in exclude_ids
222-
]
219+
220+
# Get stored input variables to exclude
221+
input_vars = set(sim.input_variables)
222+
223+
result = []
224+
for name, var in sim.tax_benefit_system.variables.items():
225+
if name in exclude_ids:
226+
continue
227+
228+
# Exclude variables that have stored values (input_variables)
229+
# These represent original survey data that should be preserved
230+
if name in input_vars:
231+
continue
232+
233+
# Include if has formulas
234+
if var.formulas:
235+
result.append(name)
236+
continue
237+
238+
# Include if adds/subtracts is a list (explicit component aggregation)
239+
# Exclude if adds/subtracts is a string (parameter path - pseudo-input)
240+
adds = getattr(var, "adds", None)
241+
subtracts = getattr(var, "subtracts", None)
242+
243+
if adds and isinstance(adds, list):
244+
result.append(name)
245+
elif subtracts and isinstance(subtracts, list):
246+
result.append(name)
247+
248+
return result
223249

224250

225251
def get_pseudo_input_variables(sim) -> set:

policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py

Lines changed: 128 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,105 @@ def __init__(
3838
self.time_period = time_period
3939
self.cds_to_calibrate = cds_to_calibrate
4040
self.dataset_path = dataset_path
41+
self._entity_rel_cache = None
42+
43+
def _build_entity_relationship(self, sim) -> pd.DataFrame:
44+
"""
45+
Build entity relationship DataFrame mapping persons to all entity IDs.
46+
47+
This is used to evaluate constraints at the person level and then
48+
aggregate to household level, handling variables defined at different
49+
entity levels (person, tax_unit, household, spm_unit).
50+
51+
Returns:
52+
DataFrame with person_id, household_id, tax_unit_id, spm_unit_id
53+
"""
54+
if self._entity_rel_cache is not None:
55+
return self._entity_rel_cache
56+
57+
self._entity_rel_cache = pd.DataFrame(
58+
{
59+
"person_id": sim.calculate(
60+
"person_id", map_to="person"
61+
).values,
62+
"household_id": sim.calculate(
63+
"household_id", map_to="person"
64+
).values,
65+
"tax_unit_id": sim.calculate(
66+
"tax_unit_id", map_to="person"
67+
).values,
68+
"spm_unit_id": sim.calculate(
69+
"spm_unit_id", map_to="person"
70+
).values,
71+
}
72+
)
73+
return self._entity_rel_cache
74+
75+
def _evaluate_constraints_entity_aware(
76+
self, state_sim, constraints: List[dict], n_households: int
77+
) -> np.ndarray:
78+
"""
79+
Evaluate non-geographic constraints at person level, aggregate to
80+
household level using .any().
81+
82+
This properly handles constraints on variables defined at different
83+
entity levels (e.g., tax_unit_is_filer at tax_unit level). Instead of
84+
summing values at household level (which would give 2, 3, etc. for
85+
households with multiple tax units), we evaluate at person level and
86+
use .any() aggregation ("does this household have at least one person
87+
satisfying all constraints?").
88+
89+
Args:
90+
state_sim: Microsimulation with state_fips set
91+
constraints: List of constraint dicts with variable, operation,
92+
value keys (geographic constraints should be pre-filtered)
93+
n_households: Number of households
94+
95+
Returns:
96+
Boolean mask array of length n_households
97+
"""
98+
if not constraints:
99+
return np.ones(n_households, dtype=bool)
100+
101+
entity_rel = self._build_entity_relationship(state_sim)
102+
n_persons = len(entity_rel)
103+
104+
person_mask = np.ones(n_persons, dtype=bool)
105+
106+
for c in constraints:
107+
var = c["variable"]
108+
op = c["operation"]
109+
val = c["value"]
110+
111+
# Calculate constraint variable at person level
112+
constraint_values = state_sim.calculate(
113+
var, map_to="person"
114+
).values
115+
116+
# Apply operation at person level
117+
person_mask &= apply_op(constraint_values, op, val)
118+
119+
# Aggregate to household level using .any()
120+
# "At least one person in this household satisfies ALL constraints"
121+
entity_rel_with_mask = entity_rel.copy()
122+
entity_rel_with_mask["satisfies"] = person_mask
123+
124+
household_mask_series = entity_rel_with_mask.groupby("household_id")[
125+
"satisfies"
126+
].any()
127+
128+
# Ensure we return a mask aligned with household order
129+
household_ids = state_sim.calculate(
130+
"household_id", map_to="household"
131+
).values
132+
household_mask = np.array(
133+
[
134+
household_mask_series.get(hh_id, False)
135+
for hh_id in household_ids
136+
]
137+
)
138+
139+
return household_mask
41140

42141
def _query_targets(self, target_filter: dict) -> pd.DataFrame:
43142
"""Query targets based on filter criteria using OR logic."""
@@ -166,6 +265,9 @@ def build_matrix(
166265
cds_by_state[state].append((cd_idx, cd))
167266

168267
for state, cd_list in cds_by_state.items():
268+
# Clear entity relationship cache when creating new simulation
269+
self._entity_rel_cache = None
270+
169271
if self.dataset_path:
170272
state_sim = self._create_state_sim(state, n_households)
171273
else:
@@ -184,27 +286,43 @@ def build_matrix(
184286
for row_idx, (_, target) in enumerate(targets_df.iterrows()):
185287
constraints = self._get_constraints(target["stratum_id"])
186288

187-
mask = np.ones(n_households, dtype=bool)
289+
geo_constraints = []
290+
non_geo_constraints = []
188291
for c in constraints:
292+
if c["variable"] in (
293+
"state_fips",
294+
"congressional_district_geoid",
295+
):
296+
geo_constraints.append(c)
297+
else:
298+
non_geo_constraints.append(c)
299+
300+
# Check geographic constraints first (quick fail)
301+
geo_mask = np.ones(n_households, dtype=bool)
302+
for c in geo_constraints:
189303
if c["variable"] == "congressional_district_geoid":
190304
if (
191305
c["operation"] in ("==", "=")
192306
and c["value"] != cd
193307
):
194-
mask[:] = False
308+
geo_mask[:] = False
195309
elif c["variable"] == "state_fips":
196310
if (
197311
c["operation"] in ("==", "=")
198312
and int(c["value"]) != state
199313
):
200-
mask[:] = False
201-
else:
202-
values = state_sim.calculate(
203-
c["variable"], map_to="household"
204-
).values
205-
mask &= apply_op(
206-
values, c["operation"], c["value"]
207-
)
314+
geo_mask[:] = False
315+
316+
if not geo_mask.any():
317+
continue
318+
319+
# Evaluate non-geographic constraints at entity level
320+
entity_mask = self._evaluate_constraints_entity_aware(
321+
state_sim, non_geo_constraints, n_households
322+
)
323+
324+
# Combine geographic and entity-aware masks
325+
mask = geo_mask & entity_mask
208326

209327
if not mask.any():
210328
continue

policyengine_us_data/tests/test_local_area_calibration/conftest.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,22 @@
3737
("tanf", 1e-2),
3838
("tip_income", 1e-2),
3939
("unemployment_compensation", 1e-2),
40+
("income_tax", 1e-2),
41+
("income_tax", 1e-2),
42+
("qualified_business_income_deduction", 1e-2),
43+
("taxable_social_security", 1e-2),
44+
("taxable_pension_income", 1e-2),
45+
("taxable_ira_distributions", 1e-2),
46+
("taxable_interest_income", 1e-2),
47+
("tax_exempt_interest_income", 1e-2),
48+
("self_employment_income", 1e-2),
49+
("salt", 1e-2),
50+
("refundable_ctc", 1e-2),
51+
("real_estate_taxes", 1e-2),
52+
("qualified_dividend_income", 1e-2),
53+
("dividend_income", 1e-2),
54+
("adjusted_gross_income", 1e-2),
55+
("eitc", 1e-2),
4056
]
4157

4258
# Combined filter config to build matrix with all variables at once
@@ -45,6 +61,20 @@
4561
4, # SNAP targets
4662
5, # Medicaid targets
4763
112, # Unemployment compensation targets
64+
117, # Income tax targets
65+
100, # QBID targets
66+
111, # Taxable social security targets
67+
114, # Taxable pension income targets
68+
105, # Taxable IRA distributions targets
69+
106, # Taxable interest income targets
70+
107, # Tax exempt interest income targets
71+
101, # Self-employment income targets
72+
116, # Salt targets
73+
115, # Refundable CTC targets
74+
103, # Real estate taxes targets
75+
109, # Qualified dividend income targets
76+
108, # Dividend income targets
77+
3, # Adjusted gross income targets
4878
],
4979
"variables": [
5080
"snap",
@@ -60,14 +90,30 @@
6090
"tanf",
6191
"tip_income",
6292
"unemployment_compensation",
93+
"income_tax",
94+
"income_tax",
95+
"qualified_business_income_deduction",
96+
"taxable_social_security",
97+
"taxable_pension_income",
98+
"taxable_ira_distributions",
99+
"taxable_interest_income",
100+
"tax_exempt_interest_income",
101+
"self_employment_income",
102+
"salt",
103+
"refundable_ctc",
104+
"real_estate_taxes",
105+
"qualified_dividend_income",
106+
"dividend_income",
107+
"adjusted_gross_income",
108+
"eitc",
63109
],
64110
}
65111

66112
# Maximum allowed mismatch rate for state-level value comparison
67113
MAX_MISMATCH_RATE = 0.02
68114

69115
# Number of samples for cell-level verification tests
70-
N_VERIFICATION_SAMPLES = 200
116+
N_VERIFICATION_SAMPLES = 2000
71117

72118

73119
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)