@@ -47,9 +47,27 @@ def __init__(self, y, x, measure_type, V, pred_func, ensemble = False, na_rm = F
4747 self .G_ = np .vstack ((np .append (1 , np .zeros (self .p_ )), np .ones (self .p_ + 1 ) - np .append (1 , np .zeros (self .p_ ))))
4848 ## set up outer folds for hypothesis testing
4949 self .folds_outer_ = np .random .choice (a = np .arange (2 ), size = self .n_ , replace = True , p = np .array ([0.25 , 0.75 ]))
50+ self .folds_inner_ = []
5051 ## if only two unique values in y, assume binary
5152 self .binary_ = (np .unique (y ).shape [0 ] == 2 )
5253 self .ensemble_ = ensemble
54+ self .cc_ = []
55+
56+ def _get_kkt_matrix (self ):
57+ # kkt matrix for constrained wls
58+ A_W = np .sqrt (self .W_ ).dot (self .Z_ )
59+ kkt_matrix_11 = 2 * A_W .transpose ().dot (A_W )
60+ kkt_matrix_12 = self .G_ .transpose ()
61+ kkt_matrix_21 = self .G_
62+ kkt_matrix_22 = np .zeros ((kkt_matrix_21 .shape [0 ], kkt_matrix_12 .shape [1 ]))
63+ kkt_matrix = np .vstack ((np .hstack ((kkt_matrix_11 , kkt_matrix_12 )), np .hstack ((kkt_matrix_21 , kkt_matrix_22 ))))
64+ return (kkt_matrix )
65+
66+ def _get_ls_matrix (self , c_n ):
67+ A_W = np .sqrt (self .W_ ).dot (self .Z_ )
68+ v_W = np .sqrt (self .W_ ).dot (self .v_ )
69+ ls_matrix = np .vstack ((2 * A_W .transpose ().dot (v_W .reshape ((len (v_W ), 1 ))), c_n .reshape ((c_n .shape [0 ], 1 ))))
70+ return (ls_matrix )
5371
5472 ## calculate the point estimates
5573 def get_point_est (self , gamma = 1 ):
@@ -70,7 +88,7 @@ def get_point_est(self, gamma = 1):
7088 v_none = self .measure_ (self .y_ [self .folds_outer_ == 1 ], preds_none )
7189 ic_none = compute_ic (self .y_ [self .folds_outer_ == 1 ], preds_none , self .measure_ .__name__ )
7290 ## get v, preds, ic for remaining non-null groups in S
73- v_lst , preds_lst , ic_lst = zip (* (cv_predictiveness (self .x_ [self .folds_outer_ == 1 , :], self .y_ [self .folds_outer_ == 1 ], s , self .measure_ , self .pred_func_ , V = self .V_ , stratified = self .binary_ , na_rm = self .na_rm_ ) for s in S_lst [1 :]))
91+ v_lst , preds_lst , ic_lst , self . folds_inner_ , self . cc_ = zip (* (cv_predictiveness (self .x_ [self .folds_outer_ == 1 , :], self .y_ [self .folds_outer_ == 1 ], s , self .measure_ , self .pred_func_ , V = self .V_ , stratified = self .binary_ , na_rm = self .na_rm_ ) for s in S_lst [1 :]))
7492 ## set up full lists
7593 v_lst_all = [v_none ] + list (v_lst )
7694 ic_lst_all = [ic_none ] + list (ic_lst )
@@ -79,19 +97,11 @@ def get_point_est(self, gamma = 1):
7997 self .v_ = np .array (v_lst_all )
8098 self .v_ics_ = ic_lst_all
8199 c_n = np .array ([v_none , v_lst_all [len (v_lst )] - v_none ])
82- ## do constrained ls
83- A_W = np .sqrt (self .W_ ).dot (self .Z_ )
84- v_W = np .sqrt (self .W_ ).dot (self .v_ )
85- kkt_matrix_11 = 2 * A_W .transpose ().dot (A_W )
86- kkt_matrix_12 = self .G_ .transpose ()
87- kkt_matrix_21 = self .G_
88- kkt_matrix_22 = np .zeros ((kkt_matrix_21 .shape [0 ], kkt_matrix_12 .shape [1 ]))
89- kkt_matrix = np .vstack ((np .hstack ((kkt_matrix_11 , kkt_matrix_12 )), np .hstack ((kkt_matrix_21 , kkt_matrix_22 ))))
90- ls_matrix = np .vstack ((2 * A_W .transpose ().dot (v_W .reshape ((len (v_W ), 1 ))), c_n .reshape ((c_n .shape [0 ], 1 ))))
100+ kkt_matrix = self ._get_kkt_matrix ()
101+ ls_matrix = self ._get_ls_matrix (c_n )
91102 ls_solution = np .linalg .inv (kkt_matrix ).dot (ls_matrix )
92103 self .vimp_ = ls_solution [0 :(self .p_ + 1 ), :]
93104 self .lambdas_ = ls_solution [(self .p_ + 1 ):ls_solution .shape [0 ], :]
94-
95105 return (self )
96106
97107 ## calculate the influence function
0 commit comments