Skip to content

Commit b4d37af

Browse files
authored
Merge pull request #63 from eurunuela/surrogate_histograms
Change how surrogate histograms are calculated
2 parents 555b4e2 + 1526086 commit b4d37af

File tree

3 files changed

+28
-46
lines changed

3 files changed

+28
-46
lines changed

connPFM/connectivity/connectivity_utils.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import logging
33

44
import numpy as np
5-
from joblib import Parallel, delayed
65
from scipy.stats import zscore
76

87
LGR = logging.getLogger(__name__)
@@ -19,7 +18,7 @@ def calculate_ets(y, n):
1918
return ets, u, v
2019

2120

22-
def rss_surr(z_ts, u, v, surrprefix, sursufix, masker, irand):
21+
def rss_surr(z_ts, u, v, surrprefix, sursufix, masker, irand, nbins, hist_range=(0, 1)):
2322
"""Calculate RSS on surrogate data."""
2423
[t, n] = z_ts.shape
2524

@@ -42,7 +41,10 @@ def rss_surr(z_ts, u, v, surrprefix, sursufix, masker, irand):
4241
# calcuate rss
4342
rssr = np.sqrt(np.sum(np.square(etsr), axis=1))
4443

45-
return (rssr, etsr, np.min(etsr), np.max(etsr))
44+
# Calculate histogram
45+
ets_hist, bin_edges = np.histogram(etsr.flatten(), bins=nbins, range=hist_range)
46+
47+
return (rssr, etsr, ets_hist, bin_edges)
4648

4749

4850
def remove_neighboring_peaks(rss, idx):
@@ -136,35 +138,20 @@ def calculate_hist_threshold(hist, bins, percentile=95):
136138
return thr
137139

138140

139-
def surrogates_histogram(
140-
surrprefix,
141-
sursufix,
142-
masker,
143-
hist_range,
144-
numrand=100,
145-
nbins=500,
146-
percentile=95,
141+
def sum_histograms(
142+
hist_list,
147143
):
148144
"""
149-
Read AUCs of surrogates, calculate histogram and sum of all histograms to
145+
Get histograms of all surrogates and sum them to
150146
obtain a single histogram that summarizes the data.
151147
"""
152-
ets_hist = np.zeros((numrand, nbins))
153-
154-
# calculate histogram for each surrogate
155-
hist = Parallel(n_jobs=-1, backend="multiprocessing")(
156-
delayed(calculate_hist)(surrprefix, sursufix, irand, masker, hist_range, nbins)
157-
for irand in range(numrand)
158-
)
159148

160-
for irand in range(numrand):
161-
ets_hist[irand, :] = hist[irand][0]
149+
# Initialize matrix to store surrogate histograms
150+
all_hists = np.zeros((len(hist_list), hist_list[0][3].shape[0] - 1))
162151

163-
bin_edges = hist[0][1]
152+
for rand_idx in range(len(hist_list)):
153+
all_hists[rand_idx, :] = hist_list[rand_idx][2]
164154

165-
ets_hist_sum = np.sum(ets_hist, axis=0)
155+
ets_hist_sum = np.sum(all_hists, axis=0)
166156

167-
# calculate histogram threshold
168-
thr = calculate_hist_threshold(ets_hist_sum, bin_edges, percentile)
169-
170-
return thr
157+
return ets_hist_sum

connPFM/connectivity/ev.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def event_detection(
2121
nsur=100,
2222
segments=True,
2323
peak_detection="rss",
24+
nbins=1000,
2425
):
2526
"""Perform event detection on given data."""
2627
masker = NiftiLabelsMasker(
@@ -55,19 +56,14 @@ def event_detection(
5556
rssr = np.zeros([t, nsur])
5657

5758
# calculate ets and rss of surrogate data
59+
LGR.info("Calculating edge-time matrix, RSS and histograms for surrogates...")
5860
surrogate_events = Parallel(n_jobs=-1, backend="multiprocessing")(
59-
delayed(connectivity_utils.rss_surr)(z_ts, u, v, surrprefix, sursufix, masker, irand)
61+
delayed(connectivity_utils.rss_surr)(
62+
z_ts, u, v, surrprefix, sursufix, masker, irand, nbins
63+
)
6064
for irand in range(nsur)
6165
)
6266

63-
hist_ranges = np.zeros((2, nsur))
64-
for irand in range(nsur):
65-
hist_ranges[0, irand] = surrogate_events[irand][2]
66-
hist_ranges[1, irand] = surrogate_events[irand][3]
67-
68-
hist_min = np.min(hist_ranges, axis=1)[0]
69-
hist_max = np.max(hist_ranges, axis=1)[1]
70-
7167
# Make selection of points with RSS
7268
if "rss" in peak_detection:
7369
LGR.info("Selecting points with RSS...")
@@ -117,13 +113,13 @@ def event_detection(
117113
LGR.info("Selecting points with edge time-series matrix...")
118114
if peak_detection == "ets":
119115
LGR.info("Reading AUC of surrogates to perform the thresholding step...")
120-
thr = connectivity_utils.surrogates_histogram(
121-
surrprefix,
122-
sursufix,
123-
masker,
124-
hist_range=(hist_min, hist_max),
125-
numrand=nsur,
116+
hist_sum = connectivity_utils.sum_histograms(
117+
surrogate_events,
126118
)
119+
thr = connectivity_utils.calculate_hist_threshold(
120+
hist_sum, surrogate_events[0][3][:-1], percentile=95
121+
)
122+
127123
elif peak_detection == "ets_time":
128124
# Initialize array for threshold
129125
thr = np.zeros(t)
@@ -137,9 +133,7 @@ def event_detection(
137133
sur_ets_at_time[sur_idx, :] = surrogate_events[sur_idx][1][time_idx, :]
138134

139135
# calculate histogram of all surrogate ets at time point
140-
hist, bins = np.histogram(
141-
sur_ets_at_time.flatten(), bins=500, range=(hist_min, hist_max)
142-
)
136+
hist, bins = np.histogram(sur_ets_at_time.flatten(), bins=nbins, range=(0, 1))
143137

144138
# calculate threshold for time point
145139
thr[time_idx] = connectivity_utils.calculate_hist_threshold(

connPFM/tests/test_ev.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@ def test_rss_surr(AUC_file, atlas_file, surr_dir, rssr_auc_file):
2626
standardize=False,
2727
strategy="mean",
2828
)
29+
nbins = 1000
2930

3031
AUC_img = masker.fit_transform(AUC_file)
3132
_, u, v = connectivity_utils.calculate_ets(AUC_img, AUC_img.shape[1])
3233
rssr, _, _, _ = connectivity_utils.rss_surr(
33-
AUC_img, u, v, join(surr_dir, "surrogate_AUC_"), "", masker, 0
34+
AUC_img, u, v, join(surr_dir, "surrogate_AUC_"), "", masker, 0, nbins
3435
)
3536
rssr_auc = np.loadtxt(rssr_auc_file)
3637
assert np.all(np.isclose(rssr, rssr_auc))

0 commit comments

Comments
 (0)