@@ -28,6 +28,7 @@ def reweight(
2828 targets_array ,
2929 dropout_rate = 0.05 ,
3030 log_path = "calibration_log.csv" ,
31+ epochs = 150 ,
3132):
3233 target_names = np .array (loss_matrix .columns )
3334 is_national = loss_matrix .columns .str .startswith ("nation/" )
@@ -45,7 +46,7 @@ def reweight(
4546 np .log (original_weights ), requires_grad = True , dtype = torch .float32
4647 )
4748
48- # TODO: replace this with a call to the python reweight.py package.
49+ # TODO: replace this functionality with the equivalent from the microcalibrate package.
4950 def loss (weights ):
5051 # Check for NaNs in either the weights or the loss matrix
5152 if torch .isnan (weights ).any ():
@@ -78,7 +79,7 @@ def dropout_weights(weights, p):
7879
7980 start_loss = None
8081
81- iterator = trange (500 )
82+ iterator = trange (epochs )
8283 performance = pd .DataFrame ()
8384 for i in iterator :
8485 optimizer .zero_grad ()
@@ -178,18 +179,71 @@ def generate(self):
178179 original_weights = original_weights .values + np .random .normal (
179180 1 , 0.1 , len (original_weights )
180181 )
182+
183+ bad_targets = [
184+ "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household" ,
185+ "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household" ,
186+ "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse" ,
187+ "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse" ,
188+ "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household" ,
189+ "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household" ,
190+ "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse" ,
191+ "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse" ,
192+ "state/RI/adjusted_gross_income/amount/-inf_1" ,
193+ "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household" ,
194+ "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household" ,
195+ "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse" ,
196+ "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse" ,
197+ "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household" ,
198+ "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household" ,
199+ "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse" ,
200+ "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse" ,
201+ "state/RI/adjusted_gross_income/amount/-inf_1" ,
202+ "nation/irs/exempt interest/count/AGI in -inf-inf/taxable/All" ,
203+ ]
204+
205+ # Run the optimization procedure to get (close to) minimum loss weights
181206 for year in range (self .start_year , self .end_year + 1 ):
182207 loss_matrix , targets_array = build_loss_matrix (
183208 self .input_dataset , year
184209 )
210+ zero_mask = np .isclose (targets_array , 0.0 , atol = 0.1 )
211+ bad_mask = loss_matrix .columns .isin (bad_targets )
212+ keep_mask_bool = ~ (zero_mask | bad_mask )
213+ keep_idx = np .where (keep_mask_bool )[0 ]
214+ loss_matrix_clean = loss_matrix .iloc [:, keep_idx ]
215+ targets_array_clean = targets_array [keep_idx ]
216+ assert loss_matrix_clean .shape [1 ] == targets_array_clean .size
217+
185218 optimised_weights = reweight (
186219 original_weights ,
187- loss_matrix ,
188- targets_array ,
220+ loss_matrix_clean ,
221+ targets_array_clean ,
189222 log_path = "calibration_log.csv" ,
223+ epochs = 150 ,
190224 )
191225 data ["household_weight" ][year ] = optimised_weights
192226
227+ print ("\n \n ---reweighting quick diagnostics----\n " )
228+ estimate = optimised_weights @ loss_matrix_clean
229+ rel_error = (
230+ ((estimate - targets_array_clean ) + 1 )
231+ / (targets_array_clean + 1 )
232+ ) ** 2
233+ print (
234+ f"rel_error: min: { np .min (rel_error ):.2f} , "
235+ f"max: { np .max (rel_error ):.2f} "
236+ f"mean: { np .mean (rel_error ):.2f} , "
237+ f"median: { np .median (rel_error ):.2f} "
238+ )
239+ print ("Relative error over 100% for:" )
240+ for i in np .where (rel_error > 1 )[0 ]:
241+ print (f"target_name: { loss_matrix_clean .columns [i ]} " )
242+ print (f"target_value: { targets_array_clean [i ]} " )
243+ print (f"estimate_value: { estimate [i ]} " )
244+ print (f"has rel_error: { rel_error [i ]:.2f} \n " )
245+ print ("---End of reweighting quick diagnostics------" )
246+
193247 self .save_dataset (data )
194248
195249
0 commit comments