Commit 39b4a17

update spvim
1 parent 7b43865 commit 39b4a17

File tree

1 file changed: +19 additions, -21 deletions


vimpy/spvim.py

Lines changed: 19 additions & 21 deletions
@@ -4,9 +4,9 @@
 ## import required libraries
 import numpy as np
 from scipy.stats import norm
-from .predictiveness_measures import cv_predictiveness
+from .predictiveness_measures import cv_predictiveness, compute_ic
 from .spvim_ic import shapley_influence_function, shapley_se
-from .vimpy_utils import get_measure_function
+from .vimpy_utils import get_measure_function, choose, shapley_hyp_test


 class spvim:
@@ -15,20 +15,21 @@ class spvim:
     """
     @param y the outcome
     @param x the feature data
-    @param pred_func the function that predicts outcome given features
-    @param V the number of cross-validation folds
     @param measure_type the predictiveness measure to use (a function)
-    @param na_rm remove NAs prior to computing predictiveness?
+    @param V the number of cross-validation folds
+    @param pred_func the function that predicts outcome given features
+    @param ensemble is pred_func an ensemble (True) or a single function (False, default)
+    @param na_rm remove NAs prior to computing predictiveness? (defaults to False)
     """
-    def __init__(self, y, x, pred_func, V, measure_type, na_rm):
+    def __init__(self, y, x, measure_type, V, pred_func, ensemble = False, na_rm = False):
         self.y_ = y
         self.x_ = x
         self.n_ = y.shape[0]
         self.p_ = x.shape[0]
         self.pred_func_ = pred_func
         self.V_ = V
         self.measure_type_ = measure_type
-        self.measure_ = uts.get_measure_function(measure_type)
+        self.measure_ = get_measure_function(measure_type)
         self.ics_ = []
         self.vimp_ = []
         self.lambdas_ = []
@@ -48,31 +49,30 @@ def __init__(self, y, x, pred_func, V, measure_type, na_rm):
         self.folds_outer_ = np.random.choice(a = np.arange(2), size = self.n_, replace = True, p = np.array([0.25, 0.75]))
         ## if only two unique values in y, assume binary
         self.binary_ = (np.unique(y).shape[0] == 2)
-
+        self.ensemble_ = ensemble

     ## calculate the point estimates
     def get_point_est(self, gamma = 1):
         self.gamma_ = gamma
         ## sample subsets, set up Z
         max_subset = np.array(list(range(self.p_)))
-        sampling_weights = np.append(np.append(1, [uts.choose(self.p_ - 2, s - 1) ** (-1) for s in range(1, self.p_)]), 1)
+        sampling_weights = np.append(np.append(1, [choose(self.p_ - 2, s - 1) ** (-1) for s in range(1, self.p_)]), 1)
         subset_sizes = np.random.choice(np.arange(0, self.p_ + 1), p = sampling_weights / sum(sampling_weights), size = self.gamma_ * self.x_.shape[0], replace = True)
         S_lst_all = [np.sort(np.random.choice(np.arange(0, self.p_), subset_size, replace = False)) for subset_size in list(subset_sizes)]
         ## only need to continue with the unique subsets S
         Z_lst_all = [np.in1d(max_subset, S).astype(np.float64) for S in S_lst_all]
         Z, z_counts = np.unique(np.array(Z_lst_all), axis = 0, return_counts = True)
         Z_lst = list(Z)
-        Z_aug_lst = [np.append(1, Z) for Z in Z_lst]
-        S_lst = [max_subset[Z.astype(bool).tolist()] for Z in Z_lst]
+        Z_aug_lst = [np.append(1, z) for z in Z_lst]
+        S_lst = [max_subset[z.astype(bool).tolist()] for z in Z_lst]
         ## get v, preds, ic for null set
         preds_none = np.repeat(np.mean(self.y_[self.folds_outer_ == 1]), np.sum(self.folds_outer_ == 1))
         v_none = self.measure_(self.y_[self.folds_outer_ == 1], preds_none)
-        ic_none = mp.compute_ic(self.y_[self.folds_outer_ == 1], preds_none, self.measure_.__name__)
+        ic_none = compute_ic(self.y_[self.folds_outer_ == 1], preds_none, self.measure_.__name__)
         ## get v, preds, ic for remaining non-null groups in S
-        v_lst, preds_lst, ic_lst = zip(*(mp.cv_predictiveness(self.x_[self.folds_outer_ == 1, :], self.y_[self.folds_outer_ == 1], s, self.measure_, self.pred_func_, V = self.V_, stratified = self.binary_, na_rm = self.na_rm_) for s in S_lst[1:]))
+        v_lst, preds_lst, ic_lst = zip(*(cv_predictiveness(self.x_[self.folds_outer_ == 1, :], self.y_[self.folds_outer_ == 1], s, self.measure_, self.pred_func_, V = self.V_, stratified = self.binary_, na_rm = self.na_rm_) for s in S_lst[1:]))
         ## set up full lists
         v_lst_all = [v_none] + list(v_lst)
-        preds_lst_all = [preds_none] + list(preds_lst)
         ic_lst_all = [ic_none] + list(ic_lst)
         self.Z_ = np.array(Z_aug_lst)
         self.W_ = np.diag(z_counts / np.sum(z_counts))
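
The sampling_weights line above encodes the subset-size distribution for the sampled Shapley approximation: the empty and full subsets get weight 1, and a subset of size s (0 < s < p) gets weight 1 / choose(p - 2, s - 1) before normalization, so middle-sized subsets are drawn less often. A minimal numeric sketch of that weighting, using scipy.special.comb as a stand-in for the package's choose helper and a hypothetical p = 4:

    import numpy as np
    from scipy.special import comb

    p = 4  # hypothetical number of features
    ## same construction as in get_point_est, with comb playing the role of choose
    sampling_weights = np.append(np.append(1, [comb(p - 2, s - 1) ** (-1) for s in range(1, p)]), 1)
    probs = sampling_weights / sum(sampling_weights)
    print(sampling_weights)  # [1.  1.  0.5 1.  1. ]
    print(probs)             # about [0.222 0.222 0.111 0.222 0.222] over subset sizes 0 through 4
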
@@ -90,19 +90,19 @@ def get_point_est(self, gamma = 1):
         ls_matrix = np.vstack((2 * A_W.transpose().dot(v_W.reshape((len(v_W), 1))), c_n.reshape((c_n.shape[0], 1))))
         ls_solution = np.linalg.inv(kkt_matrix).dot(ls_matrix)
         self.vimp_ = ls_solution[0:(self.p_ + 1), :]
-        self.lambdas_ = ls_solution[(self.p_+1):ls_solution.shape[0], :]
+        self.lambdas_ = ls_solution[(self.p_ + 1):ls_solution.shape[0], :]

         return(self)

     ## calculate the influence function
     def get_influence_functions(self):
         c_n = np.array([self.v_[0], self.v_[self.v.shape[0]] - self.v_[0]])
-        self.ics_ = sic.shapley_influence_function(self.Z_, self.z_counts_, self.W_, self.v_, self.vimp_, self.G_, c_n, np.array(self.v_ics_), self.measure_.__name__)
+        self.ics_ = shapley_influence_function(self.Z_, self.z_counts_, self.W_, self.v_, self.vimp_, self.G_, c_n, np.array(self.v_ics_), self.measure_.__name__)
         return self

     ## calculate standard errors
     def get_ses(self):
-        ses = [sic.shapley_se(self.ics_, idx, self.gamma_) for idx in range(self.p_)]
+        ses = [shapley_se(self.ics_, idx, self.gamma_) for idx in range(self.p_)]
         self.ses_ = np.array(ses)
         return self

@@ -113,8 +113,6 @@ def get_cis(self, level = 0.95):
         a = np.array([a, 1 - a])
         ## calculate the quantiles
         fac = norm.ppf(a)
-        ## set up the ci array
-        ci = np.zeros((self.vimp_.shape[0], 2))
         ## create it
         self.ci_ = self.vimp_ + np.outer((self.ses_), fac)
         return self
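
The deleted ci array appears unused in get_cis, so dropping it is a no-op; self.ci_ is built directly as a Wald-type interval. Assuming the earlier, unshown line of the method sets a = (1 - level) / 2, the arithmetic for the default level = 0.95 works out as in this small sketch:

    import numpy as np
    from scipy.stats import norm

    level = 0.95
    a = (1 - level) / 2                     # assumed from the surrounding method, not shown in this hunk
    fac = norm.ppf(np.array([a, 1 - a]))    # roughly [-1.96, 1.96]
    vimp = np.array([[0.10], [0.25]])       # hypothetical point estimates, stored as a column like self.vimp_
    ses = np.array([0.02, 0.03])            # hypothetical standard errors
    ci = vimp + np.outer(ses, fac)          # each row is [lower, upper]
    print(ci)                               # approximately [[0.061 0.139], [0.191 0.309]]
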
@@ -124,10 +122,10 @@ def hypothesis_test(self, alpha = 0.05, delta = 0):
         ## null predictiveness
         preds_none_0 = np.repeat(np.mean(self.y_[self.folds_outer_ == 0]), np.sum(self.folds_outer_ == 0))
         v_none_0 = self.measure_(self.y_[self.folds_outer_ == 0], preds_none_0)
-        ic_none_0 = mp.compute_ic(self.y_[self.folds_outer_ == 0], preds_none_0, self.measure_.__name__)
+        ic_none_0 = compute_ic(self.y_[self.folds_outer_ == 0], preds_none_0, self.measure_.__name__)
         sigma_none_0 = np.sqrt(np.mean((ic_none_0) ** 2)) / np.sqrt(np.sum(self.folds_outer_ == 0))
         ## get shapley values + null predictiveness on first split
         shapley_vals_plus = self.vimp_ + self.vimp_[0]
         sigmas_one = np.sqrt(self.ses_ ** 2 + sigma_none_0 ** 2)
-        self.test_statistics_, self.p_values_, self.hyp_tests_ = uts.shapley_hyp_test(shapley_vals_plus[1:], v_none_0, sigmas_one, sigma_none_0, level = alpha, delta = delta, p = self.p_)
+        self.test_statistics_, self.p_values_, self.hyp_tests_ = shapley_hyp_test(shapley_vals_plus[1:], v_none_0, sigmas_one, sigma_none_0, level = alpha, delta = delta, p = self.p_)
         return self
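
Because __init__ now takes its arguments in a different order, downstream call sites need updating. A minimal end-to-end sketch of the new calling convention follows; the vimpy import path, the "r_squared" measure_type string, and the use of a scikit-learn regressor as pred_func are assumptions not confirmed by this diff:

    import numpy as np
    from sklearn.ensemble import GradientBoostingRegressor  # assumed to be an acceptable pred_func
    from vimpy import spvim                                  # import path assumed

    np.random.seed(4747)
    x = np.random.normal(size = (200, 3))
    y = x[:, 0] + 0.5 * x[:, 1] + np.random.normal(size = 200)

    ## note the reordered arguments: measure_type and V now come before pred_func
    est = spvim(y = y, x = x, measure_type = "r_squared", V = 5,
                pred_func = GradientBoostingRegressor(), ensemble = False, na_rm = False)

    est.get_point_est(gamma = 1)                  # sampled-subset Shapley point estimates
    est.get_influence_functions()                 # influence functions for each estimate
    est.get_ses()                                 # standard errors
    est.get_cis(level = 0.95)                     # Wald-type confidence intervals
    est.hypothesis_test(alpha = 0.05, delta = 0)  # tests against the null predictiveness
    print(est.vimp_, est.ci_, est.p_values_)
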

0 commit comments