Skip to content

Commit 1d83e1f

Browse files
committed
change to try predict_proba; if not available, use predict
1 parent 2519472 commit 1d83e1f

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

vimpy/predictiveness_measures.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,11 @@ def cv_predictiveness(x, y, S, measure, pred_func, V = 5, stratified = True, na_
4646
if ensemble:
4747
preds_v = np.mean(pred_func.transform(x_train[:, S]))
4848
else:
49-
if measure.__name__ in ["r_squared"]:
50-
preds_v = pred_func.predict(x_train[:, S])
51-
else:
49+
try:
5250
preds_v = pred_func.predict_proba(x_train[:, S])[:, 1]
51+
except AttributeError:
52+
preds_v = pred_func.predict(x_train[:, S])
53+
5354
preds[cc_cond] = preds_v
5455
vs[0] = measure(y_train, preds_v)
5556
ics[cc_cond] = compute_ic(y_train, preds_v, measure.__name__)
@@ -62,10 +63,11 @@ def cv_predictiveness(x, y, S, measure, pred_func, V = 5, stratified = True, na_
6263
if ensemble:
6364
preds_v = np.mean(pred_func.transform(x_test[:, S]))
6465
else:
65-
if measure.__name__ in ["r_squared"]:
66-
preds_v = pred_func.predict(x_test[:, S])
67-
else:
66+
try:
6867
preds_v = pred_func.predict_proba(x_test[:, S])[:, 1]
68+
except AttributeError:
69+
preds_v = pred_func.predict(x_test[:, S])
70+
6971
preds[cc_cond[fold_cond]] = preds_v
7072
vs[v] = measure(y_test, preds_v)
7173
ics[cc_cond[fold_cond]] = compute_ic(y_test, preds_v, measure.__name__)

0 commit comments

Comments
 (0)