Skip to content

Commit 243eb26

Browse files
committed
break out ls and kkt into new functions
1 parent 39b4a17 commit 243eb26

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

vimpy/spvim.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,27 @@ def __init__(self, y, x, measure_type, V, pred_func, ensemble = False, na_rm = F
4747
self.G_ = np.vstack((np.append(1, np.zeros(self.p_)), np.ones(self.p_ + 1) - np.append(1, np.zeros(self.p_))))
4848
## set up outer folds for hypothesis testing
4949
self.folds_outer_ = np.random.choice(a = np.arange(2), size = self.n_, replace = True, p = np.array([0.25, 0.75]))
50+
self.folds_inner_ = []
5051
## if only two unique values in y, assume binary
5152
self.binary_ = (np.unique(y).shape[0] == 2)
5253
self.ensemble_ = ensemble
54+
self.cc_ = []
55+
56+
def _get_kkt_matrix(self):
57+
# kkt matrix for constrained wls
58+
A_W = np.sqrt(self.W_).dot(self.Z_)
59+
kkt_matrix_11 = 2 * A_W.transpose().dot(A_W)
60+
kkt_matrix_12 = self.G_.transpose()
61+
kkt_matrix_21 = self.G_
62+
kkt_matrix_22 = np.zeros((kkt_matrix_21.shape[0], kkt_matrix_12.shape[1]))
63+
kkt_matrix = np.vstack((np.hstack((kkt_matrix_11, kkt_matrix_12)), np.hstack((kkt_matrix_21, kkt_matrix_22))))
64+
return(kkt_matrix)
65+
66+
def _get_ls_matrix(self, c_n):
67+
A_W = np.sqrt(self.W_).dot(self.Z_)
68+
v_W = np.sqrt(self.W_).dot(self.v_)
69+
ls_matrix = np.vstack((2 * A_W.transpose().dot(v_W.reshape((len(v_W), 1))), c_n.reshape((c_n.shape[0], 1))))
70+
return(ls_matrix)
5371

5472
## calculate the point estimates
5573
def get_point_est(self, gamma = 1):
@@ -70,7 +88,7 @@ def get_point_est(self, gamma = 1):
7088
v_none = self.measure_(self.y_[self.folds_outer_ == 1], preds_none)
7189
ic_none = compute_ic(self.y_[self.folds_outer_ == 1], preds_none, self.measure_.__name__)
7290
## get v, preds, ic for remaining non-null groups in S
73-
v_lst, preds_lst, ic_lst = zip(*(cv_predictiveness(self.x_[self.folds_outer_ == 1, :], self.y_[self.folds_outer_ == 1], s, self.measure_, self.pred_func_, V = self.V_, stratified = self.binary_, na_rm = self.na_rm_) for s in S_lst[1:]))
91+
v_lst, preds_lst, ic_lst, self.folds_inner_, self.cc_ = zip(*(cv_predictiveness(self.x_[self.folds_outer_ == 1, :], self.y_[self.folds_outer_ == 1], s, self.measure_, self.pred_func_, V = self.V_, stratified = self.binary_, na_rm = self.na_rm_) for s in S_lst[1:]))
7492
## set up full lists
7593
v_lst_all = [v_none] + list(v_lst)
7694
ic_lst_all = [ic_none] + list(ic_lst)
@@ -79,19 +97,11 @@ def get_point_est(self, gamma = 1):
7997
self.v_ = np.array(v_lst_all)
8098
self.v_ics_ = ic_lst_all
8199
c_n = np.array([v_none, v_lst_all[len(v_lst)] - v_none])
82-
## do constrained ls
83-
A_W = np.sqrt(self.W_).dot(self.Z_)
84-
v_W = np.sqrt(self.W_).dot(self.v_)
85-
kkt_matrix_11 = 2 * A_W.transpose().dot(A_W)
86-
kkt_matrix_12 = self.G_.transpose()
87-
kkt_matrix_21 = self.G_
88-
kkt_matrix_22 = np.zeros((kkt_matrix_21.shape[0], kkt_matrix_12.shape[1]))
89-
kkt_matrix = np.vstack((np.hstack((kkt_matrix_11, kkt_matrix_12)), np.hstack((kkt_matrix_21, kkt_matrix_22))))
90-
ls_matrix = np.vstack((2 * A_W.transpose().dot(v_W.reshape((len(v_W), 1))), c_n.reshape((c_n.shape[0], 1))))
100+
kkt_matrix = self._get_kkt_matrix()
101+
ls_matrix = self._get_ls_matrix(c_n)
91102
ls_solution = np.linalg.inv(kkt_matrix).dot(ls_matrix)
92103
self.vimp_ = ls_solution[0:(self.p_ + 1), :]
93104
self.lambdas_ = ls_solution[(self.p_ + 1):ls_solution.shape[0], :]
94-
95105
return(self)
96106

97107
## calculate the influence function

0 commit comments

Comments
 (0)