@@ -38,6 +38,105 @@ def __init__(
3838 self .time_period = time_period
3939 self .cds_to_calibrate = cds_to_calibrate
4040 self .dataset_path = dataset_path
41+ self ._entity_rel_cache = None
42+
43+ def _build_entity_relationship (self , sim ) -> pd .DataFrame :
44+ """
45+ Build entity relationship DataFrame mapping persons to all entity IDs.
46+
47+ This is used to evaluate constraints at the person level and then
48+ aggregate to household level, handling variables defined at different
49+ entity levels (person, tax_unit, household, spm_unit).
50+
51+ Returns:
52+ DataFrame with person_id, household_id, tax_unit_id, spm_unit_id
53+ """
54+ if self ._entity_rel_cache is not None :
55+ return self ._entity_rel_cache
56+
57+ self ._entity_rel_cache = pd .DataFrame (
58+ {
59+ "person_id" : sim .calculate (
60+ "person_id" , map_to = "person"
61+ ).values ,
62+ "household_id" : sim .calculate (
63+ "household_id" , map_to = "person"
64+ ).values ,
65+ "tax_unit_id" : sim .calculate (
66+ "tax_unit_id" , map_to = "person"
67+ ).values ,
68+ "spm_unit_id" : sim .calculate (
69+ "spm_unit_id" , map_to = "person"
70+ ).values ,
71+ }
72+ )
73+ return self ._entity_rel_cache
74+
75+ def _evaluate_constraints_entity_aware (
76+ self , state_sim , constraints : List [dict ], n_households : int
77+ ) -> np .ndarray :
78+ """
79+ Evaluate non-geographic constraints at person level, aggregate to
80+ household level using .any().
81+
82+ This properly handles constraints on variables defined at different
83+ entity levels (e.g., tax_unit_is_filer at tax_unit level). Instead of
84+ summing values at household level (which would give 2, 3, etc. for
85+ households with multiple tax units), we evaluate at person level and
86+ use .any() aggregation ("does this household have at least one person
87+ satisfying all constraints?").
88+
89+ Args:
90+ state_sim: Microsimulation with state_fips set
91+ constraints: List of constraint dicts with variable, operation,
92+ value keys (geographic constraints should be pre-filtered)
93+ n_households: Number of households
94+
95+ Returns:
96+ Boolean mask array of length n_households
97+ """
98+ if not constraints :
99+ return np .ones (n_households , dtype = bool )
100+
101+ entity_rel = self ._build_entity_relationship (state_sim )
102+ n_persons = len (entity_rel )
103+
104+ person_mask = np .ones (n_persons , dtype = bool )
105+
106+ for c in constraints :
107+ var = c ["variable" ]
108+ op = c ["operation" ]
109+ val = c ["value" ]
110+
111+ # Calculate constraint variable at person level
112+ constraint_values = state_sim .calculate (
113+ var , map_to = "person"
114+ ).values
115+
116+ # Apply operation at person level
117+ person_mask &= apply_op (constraint_values , op , val )
118+
119+ # Aggregate to household level using .any()
120+ # "At least one person in this household satisfies ALL constraints"
121+ entity_rel_with_mask = entity_rel .copy ()
122+ entity_rel_with_mask ["satisfies" ] = person_mask
123+
124+ household_mask_series = entity_rel_with_mask .groupby ("household_id" )[
125+ "satisfies"
126+ ].any ()
127+
128+ # Ensure we return a mask aligned with household order
129+ household_ids = state_sim .calculate (
130+ "household_id" , map_to = "household"
131+ ).values
132+ household_mask = np .array (
133+ [
134+ household_mask_series .get (hh_id , False )
135+ for hh_id in household_ids
136+ ]
137+ )
138+
139+ return household_mask
41140
42141 def _query_targets (self , target_filter : dict ) -> pd .DataFrame :
43142 """Query targets based on filter criteria using OR logic."""
@@ -166,6 +265,9 @@ def build_matrix(
166265 cds_by_state [state ].append ((cd_idx , cd ))
167266
168267 for state , cd_list in cds_by_state .items ():
268+ # Clear entity relationship cache when creating new simulation
269+ self ._entity_rel_cache = None
270+
169271 if self .dataset_path :
170272 state_sim = self ._create_state_sim (state , n_households )
171273 else :
@@ -184,27 +286,43 @@ def build_matrix(
184286 for row_idx , (_ , target ) in enumerate (targets_df .iterrows ()):
185287 constraints = self ._get_constraints (target ["stratum_id" ])
186288
187- mask = np .ones (n_households , dtype = bool )
289+ geo_constraints = []
290+ non_geo_constraints = []
188291 for c in constraints :
292+ if c ["variable" ] in (
293+ "state_fips" ,
294+ "congressional_district_geoid" ,
295+ ):
296+ geo_constraints .append (c )
297+ else :
298+ non_geo_constraints .append (c )
299+
300+ # Check geographic constraints first (quick fail)
301+ geo_mask = np .ones (n_households , dtype = bool )
302+ for c in geo_constraints :
189303 if c ["variable" ] == "congressional_district_geoid" :
190304 if (
191305 c ["operation" ] in ("==" , "=" )
192306 and c ["value" ] != cd
193307 ):
194- mask [:] = False
308+ geo_mask [:] = False
195309 elif c ["variable" ] == "state_fips" :
196310 if (
197311 c ["operation" ] in ("==" , "=" )
198312 and int (c ["value" ]) != state
199313 ):
200- mask [:] = False
201- else :
202- values = state_sim .calculate (
203- c ["variable" ], map_to = "household"
204- ).values
205- mask &= apply_op (
206- values , c ["operation" ], c ["value" ]
207- )
314+ geo_mask [:] = False
315+
316+ if not geo_mask .any ():
317+ continue
318+
319+ # Evaluate non-geographic constraints at entity level
320+ entity_mask = self ._evaluate_constraints_entity_aware (
321+ state_sim , non_geo_constraints , n_households
322+ )
323+
324+ # Combine geographic and entity-aware masks
325+ mask = geo_mask & entity_mask
208326
209327 if not mask .any ():
210328 continue
0 commit comments