Skip to content

Commit 77e5d02

Browse files
committed
[ENH] Added distance-dependent cross validation
1 parent fbdf9a3 commit 77e5d02

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

netneurotools/stats.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
import warnings
77

88
import numpy as np
9+
import random
910
from tqdm import tqdm
1011
from itertools import combinations
1112
from scipy import optimize, spatial, special, stats as sstats
1213
from scipy.stats.stats import _chk2_asarray
14+
from scipy.spatial.distance import squareform, pdist
1315
from sklearn.utils.validation import check_random_state
1416
from sklearn.linear_model import LinearRegression
1517

@@ -925,3 +927,101 @@ def get_reg_r_sq(X, y):
925927
model_metrics["full_r_sq"] = model_r_sq[tuple(range(n_predictor))]
926928

927929
return model_metrics, model_r_sq
930+
931+
932+
def _cv_rsq(y_true, y_pred, n_predictors, adjusted):
    """
    Return the (optionally adjusted) r-squared of `y_pred` against `y_true`

    Parameters
    ----------
    y_true : (N,) np.ndarray
        Observed values
    y_pred : (N,) np.ndarray
        Predicted values
    n_predictors : int
        Number of predictors in the model; only used when `adjusted` is True
    adjusted : bool
        Whether to apply the degrees-of-freedom adjustment

    Returns
    -------
    r_squared : float
        Coefficient of determination, adjusted if requested
    """
    ss_residual = np.sum((y_true - y_pred) ** 2)
    ss_total = np.sum((y_true - np.mean(y_true)) ** 2)
    r_squared = 1 - float(ss_residual) / ss_total
    if not adjusted:
        return r_squared
    n_samples = len(y_true)
    return 1 - (1 - r_squared) * (n_samples - 1) / \
        (n_samples - n_predictors - 1)


def cv_distance_dependent(X, y, coords, train_pct=.75, nsplits=1000,
                          metric='rsq', use_adjusted_rsq=True):
    """
    Distance-dependent cross-validation of regression equation `y ~ X`

    For every split a source node is drawn at random; the `train_pct` of
    nodes closest to it (Euclidean distance, source node included) form the
    training set and the remaining nodes form the test set. A linear
    regression is fit on the training set and scored on both sets.

    Parameters
    ----------
    X : (N[, R]) array_like
        Coefficient matrix of `R` variables for `N` brain regions
    y : (N,) array_like
        Dependent variable vector for `N` brain regions
    coords : (N, 3) array_like
        Coordinate matrix for `N` brain regions
    train_pct : float, optional
        Percentage of brain regions in the training set. 0 < train_pct < 1
        Default: 75%
    nsplits : int, optional
        Number of train/test splits. Default: 1000
    metric : {'rsq', 'corr'}, optional
        Metric of model assessment. 'rsq' will return the r-squared (adjusted
        if `use_adjusted_rsq` is True) of the train and test set model
        performance for each split. 'corr' will return the correlation
        between predicted and empirical observations for the train and test
        set. Default: 'rsq'
    use_adjusted_rsq : bool, optional
        Whether to use adjusted r-squared. Only relevant if metric is 'rsq'.
        Applies to both the train and the test metric. Default: True

    Returns
    -------
    train_metric : (nsplits,) list
        List of length `nsplits` of performance metric on the training set.
    test_metric : (nsplits,) list
        List of length `nsplits` of performance metric on the test set.

    Raises
    ------
    ValueError
        If `metric` is not 'rsq' or 'corr', if `train_pct` is not strictly
        between 0 and 1, if `X`, `y` and `coords` do not describe the same
        number of regions, if the train or test set would be empty, or if
        adjusted r-squared is requested but a set has too few regions for
        the number of predictors.
    """
    if metric not in ('rsq', 'corr'):
        raise ValueError("metric must be 'rsq' or 'corr', got {!r}"
                         .format(metric))
    if not 0 < train_pct < 1:
        raise ValueError('train_pct must satisfy 0 < train_pct < 1, got {}'
                         .format(train_pct))

    # accept any array_like input; a 1-D `X` is a single predictor
    X = np.asarray(X)
    if X.ndim == 1:
        X = X[:, np.newaxis]
    y = np.asarray(y)
    coords = np.asarray(coords)

    n_nodes, n_predictors = len(coords), X.shape[1]
    if len(X) != n_nodes or len(y) != n_nodes:
        raise ValueError('X, y and coords must have the same number of '
                         'regions, got {}, {} and {}'
                         .format(len(X), len(y), n_nodes))

    # the split sizes are identical for every split, so compute them once
    n_train = int(np.floor(train_pct * n_nodes))
    n_test = n_nodes - n_train
    if n_train < 1 or n_test < 1:
        raise ValueError('train_pct={} with {} regions leaves an empty '
                         'train or test set'.format(train_pct, n_nodes))

    adjusted = metric == 'rsq' and use_adjusted_rsq
    if adjusted and min(n_train, n_test) - n_predictors - 1 <= 0:
        raise ValueError('adjusted r-squared is undefined: train ({}) and '
                         'test ({}) sets must each contain more than {} '
                         'regions'.format(n_train, n_test, n_predictors + 1))

    # pairwise Euclidean distance between all nodes
    dist = squareform(pdist(coords, metric="euclidean"))
    train_metric = []
    test_metric = []

    for _ in range(nsplits):
        # randomly chosen source node; this draws from the stdlib `random`
        # module, so `random.seed` controls reproducibility
        source_node = random.choice(range(n_nodes))

        # nodes ordered by their distance from the source node
        order = np.argsort(dist[source_node, :])

        # train_pct of nodes closest to source node comprise the training set
        # the remaining (1 - train_pct) of nodes comprise the test set
        train_idx, test_idx = order[:n_train], order[n_train:]

        # linear regression fit on the training set only
        mdl = LinearRegression()
        mdl.fit(X[train_idx, :], y[train_idx])
        yhat_train = mdl.predict(X[train_idx, :])
        yhat_test = mdl.predict(X[test_idx, :])

        if metric == 'rsq':
            # the same (adjusted or plain) r-squared is used for both sets
            train_metric.append(_cv_rsq(y[train_idx], yhat_train,
                                        n_predictors, use_adjusted_rsq))
            test_metric.append(_cv_rsq(y[test_idx], yhat_test,
                                       n_predictors, use_adjusted_rsq))
        else:
            # correlation between predicted and empirical observations
            rho_train, _ = sstats.pearsonr(yhat_train, y[train_idx])
            rho_test, _ = sstats.pearsonr(yhat_test, y[test_idx])
            train_metric.append(rho_train)
            test_metric.append(rho_test)

    return train_metric, test_metric

0 commit comments

Comments
 (0)