Skip to content

Commit 7b43865

Browse files
committed
use only predicted prob of class 1
1 parent a5f057f commit 7b43865

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

vimpy/predictiveness_measures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def cv_predictiveness(x, y, S, measure, pred_func, V = 5, stratified = True, na_
6565
if measure.__name__ in ["r_squared"]:
6666
preds_v = pred_func.predict(x_test[:, S])
6767
else:
68-
preds_v = pred_func.predict_proba(x_test[:, S])
68+
preds_v = pred_func.predict_proba(x_test[:, S])[:, 1]
6969
preds[cc_cond[fold_cond]] = preds_v
7070
vs[v] = measure(y_test, preds_v)
7171
ics[cc_cond[fold_cond]] = compute_ic(y_test, preds_v, measure.__name__)

vimpy/spvim.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
## import required libraries
55
import numpy as np
66
from scipy.stats import norm
7-
import predictiveness_measures as mp
8-
import spvim_ic as sic
9-
import utils as uts
7+
from .predictiveness_measures import cv_predictiveness
8+
from .spvim_ic import shapley_influence_function, shapley_se
9+
from .vimpy_utils import get_measure_function
10+
1011

1112
class spvim:
1213

vimpy/vim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from scipy.stats import norm
77
from .predictiveness_measures import cv_predictiveness, cv_predictiveness_precomputed
8-
from .vimpy_utils import get_measure_function, make_folds
8+
from .vimpy_utils import get_measure_function
99

1010

1111
class vim:

0 commit comments

Comments
 (0)