Skip to content

Commit ea27979

Browse files
authored
Merge pull request #230 from neurodsp-tools/rhythm
[ENH] - Update rhythm code & docs
2 parents a385e73 + 5a9b8b6 commit ea27979

File tree

7 files changed

+430
-166
lines changed

7 files changed

+430
-166
lines changed

neurodsp/plts/rhythm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,16 @@ def plot_swm_pattern(pattern, ax=None, **kwargs):
2424
--------
2525
Plot the average pattern from a sliding window matching analysis:
2626
27+
>>> import numpy as np
2728
>>> from neurodsp.sim import sim_combined
2829
>>> from neurodsp.rhythm import sliding_window_matching
2930
>>> sig = sim_combined(n_seconds=10, fs=500,
3031
... components={'sim_powerlaw': {'f_range': (2, None)},
3132
... 'sim_bursty_oscillation': {'freq': 20,
3233
... 'enter_burst': .25,
3334
... 'leave_burst': .25}})
34-
>>> avg_window, _, _ = sliding_window_matching(sig, fs=500, win_len=0.05, win_spacing=0.5)
35+
>>> windows, _ = sliding_window_matching(sig, fs=500, win_len=0.05, win_spacing=0.5)
36+
>>> avg_window = np.mean(windows)
3537
>>> plot_swm_pattern(avg_window)
3638
"""
3739

neurodsp/rhythm/lc.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@ def compute_lagged_coherence(sig, fs, freqs, n_cycles=3, return_spectrum=False):
2323
fs : float
2424
Sampling rate, in Hz.
2525
freqs : 1d array or list of float
26-
If array, frequency values to estimate with morlet wavelets.
26+
The frequency values at which to estimate lagged coherence.
27+
If array, defines the frequency values to use.
2728
If list, define the frequency range, as [freq_start, freq_stop, freq_step].
2829
The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value.
2930
n_cycles : float or list of float, default: 3
30-
Number of cycles of each frequency to use to compute lagged coherence.
31-
If a single value, the same number of cycles is used for each frequency value.
32-
If a list or list_like, then should be a n_cycles corresponding to each frequency.
31+
The number of cycles to use to compute lagged coherence, for each frequency.
32+
If a single value, the same number of cycles is used for each frequency.
33+
If a list or list_like, there should be a value corresponding to each frequency.
3334
return_spectrum : bool, optional, default: False
3435
If True, return the lagged coherence for all frequency values.
3536
Otherwise, only the mean lagged coherence value across the frequency range is returned.
@@ -87,7 +88,7 @@ def compute_lagged_coherence(sig, fs, freqs, n_cycles=3, return_spectrum=False):
8788

8889

8990
def lagged_coherence_1freq(sig, fs, freq, n_cycles):
90-
"""Compute the lagged coherence of a frequency using the hanning-taper FFT method.
91+
"""Compute the lagged coherence at a particular frequency.
9192
9293
Parameters
9394
----------
@@ -98,12 +99,17 @@ def lagged_coherence_1freq(sig, fs, freq, n_cycles):
9899
freq : float
99100
The frequency at which to estimate lagged coherence.
100101
n_cycles : float
101-
Number of cycles at the examined frequency to use to compute lagged coherence.
102+
The number of cycles of the given frequency to use to compute lagged coherence.
102103
103104
Returns
104105
-------
105-
float
106+
lc : float
106107
The computed lagged coherence value.
108+
109+
Notes
110+
-----
111+
- Lagged coherence is computed using hanning-tapered FFTs.
112+
- The returned lagged coherence value is bound between 0 and 1.
107113
"""
108114

109115
# Determine number of samples to be used in each window to compute lagged coherence
@@ -113,20 +119,26 @@ def lagged_coherence_1freq(sig, fs, freq, n_cycles):
113119
chunks = split_signal(sig, n_samps)
114120
n_chunks = len(chunks)
115121

116-
# For each chunk, calculate the Fourier coefficients at the frequency of interest
122+
# Create the window to apply to each chunk
117123
hann_window = hann(n_samps)
124+
125+
# Create the frequency vector, finding the frequency value of interest
118126
fft_freqs = np.fft.fftfreq(n_samps, 1 / float(fs))
119127
fft_freqs_idx = np.argmin(np.abs(fft_freqs - freq))
120128

129+
# Calculate the Fourier coefficients across chunks for the frequency of interest
121130
fft_coefs = np.zeros(n_chunks, dtype=complex)
122131
for ind, chunk in enumerate(chunks):
123132
fourier_coef = np.fft.fft(chunk * hann_window)
124133
fft_coefs[ind] = fourier_coef[fft_freqs_idx]
125134

126-
# Compute the lagged coherence value
135+
# Compute lagged coherence across data segments
127136
lcs_num = 0
128137
for ind in range(n_chunks - 1):
129138
lcs_num += fft_coefs[ind] * np.conj(fft_coefs[ind + 1])
130139
lcs_denom = np.sqrt(np.sum(np.abs(fft_coefs[:-1])**2) * np.sum(np.abs(fft_coefs[1:])**2))
131140

132-
return np.abs(lcs_num / lcs_denom)
141+
# Normalize the lagged coherence value
142+
lc_val = np.abs(lcs_num / lcs_denom)
143+
144+
return lc_val

neurodsp/rhythm/swm.py

Lines changed: 134 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""The sliding window matching algorithm for identifying rhythmic components of a neural signal."""
1+
"""The sliding window matching algorithm for identifying recurring patterns in a neural signal."""
22

33
import numpy as np
44

@@ -8,8 +8,8 @@
88
###################################################################################################
99

1010
@multidim()
11-
def sliding_window_matching(sig, fs, win_len, win_spacing, max_iterations=500,
12-
temperature=1, window_starts_custom=None):
11+
def sliding_window_matching(sig, fs, win_len, win_spacing, max_iterations=100,
12+
window_starts_custom=None, var_thresh=None):
1313
"""Find recurring patterns in a time series using the sliding window matching algorithm.
1414
1515
Parameters
@@ -19,150 +19,194 @@ def sliding_window_matching(sig, fs, win_len, win_spacing, max_iterations=500,
1919
fs : float
2020
Sampling rate, in Hz.
2121
win_len : float
22-
Window length, in seconds.
22+
Window length, in seconds. This is L in the original paper.
2323
win_spacing : float
24-
Minimum window spacing, in seconds.
25-
max_iterations : int
24+
Minimum spacing between windows, in seconds. This is G in the original paper.
25+
max_iterations : int, optional, default: 100
2626
Maximum number of iterations of potential changes in window placement.
27-
temperature : float
28-
Temperature parameter. Controls probability of accepting a new window.
29-
window_starts_custom : 1d array, optional
27+
window_starts_custom : 1d array, optional, default: None
3028
Custom pre-set locations of initial windows.
29+
var_thresh : float, optional, default: None
30+
Removes initial windows with variance below a set threshold. This speeds up
31+
runtime proportional to the number of low variance windows in the data.
3132
3233
Returns
3334
-------
34-
avg_window : 1d array
35-
The average waveform in 'sig' in the frequency 'f_range' triggered on 'trigger'.
35+
windows : 2d array
36+
Putative patterns discovered in the input signal.
3637
window_starts : 1d array
3738
Indices at which each window begins for the final set of windows.
38-
costs : 1d array
39-
Cost function value at each iteration.
4039
4140
References
4241
----------
4342
.. [1] Gips, B., Bahramisharif, A., Lowet, E., Roberts, M. J., de Weerd, P., Jensen, O., &
4443
van der Eerden, J. (2017). Discovering recurring patterns in electrophysiological
4544
recordings. Journal of Neuroscience Methods, 275, 66-79.
4645
DOI: 10.1016/j.jneumeth.2016.11.001
47-
Matlab Code: https://github.com/bartgips/SWM
46+
.. [2] Matlab Code implementation: https://github.com/bartgips/SWM
4847
4948
Notes
5049
-----
51-
- Apply a highpass filter if looking at high frequency activity, so that it does
52-
not converge on a low frequency motif.
53-
- Parameters `win_len` and `win_spacing` should be chosen to be about the size of the
54-
motif of interest, and the N derived should be about the number of occurrences.
50+
- The `win_len` parameter should be chosen to be about the size of the motif of interest.
51+
The larger this window size, the more likely the pattern to reflect slower patterns.
52+
- The `win_spacing` parameter also determines the number of windows that are used.
53+
- If looking at high frequency activity, you may want to apply a highpass filter,
54+
so that the algorithm does not converge on a low frequency motif.
55+
- This implementation is a minimal, modified version, as compared to the original
56+
implementation in [2], which has more available options.
57+
- This version has the following changes to speed up convergence:
58+
59+
1. Each iteration is similar to an epoch, randomly moving all windows in
60+
random order. The original implementation randomly selects windows and
61+
does not guarantee even resampling.
62+
2. New window acceptance is determined via increased correlation coefficients
63+
and reduced variance across windows.
64+
3. Phase optimization / realignment to escape local minima.
65+
5566
5667
Examples
5768
--------
5869
Search for recurring patterns using sliding window matching in a simulated beta signal:
5970
6071
>>> from neurodsp.sim import sim_combined
61-
>>> sig = sim_combined(n_seconds=10, fs=500,
62-
... components={'sim_powerlaw': {'f_range': (2, None)},
63-
... 'sim_bursty_oscillation': {'freq': 20,
64-
... 'enter_burst': .25,
65-
... 'leave_burst': .25}})
66-
>>> avg_window, window_starts, costs = sliding_window_matching(sig, fs=500, win_len=0.05,
67-
... win_spacing=0.20)
72+
>>> components = {'sim_bursty_oscillation' : {'freq' : 20, 'phase' : 'min'},
73+
... 'sim_powerlaw' : {'f_range' : (2, None)}}
74+
>>> sig = sim_combined(10, fs=500, components=components, component_variances=(1, .05))
75+
>>> windows, starts = sliding_window_matching(sig, fs=500, win_len=0.05,
76+
... win_spacing=0.05, var_thresh=.5)
6877
"""
6978

7079
# Compute window length and spacing in samples
71-
win_n_samps = int(win_len * fs)
72-
spacing_n_samps = int(win_spacing * fs)
80+
win_len = int(win_len * fs)
81+
win_spacing = int(win_spacing * fs)
7382

7483
# Initialize window positions
7584
if window_starts_custom is None:
76-
window_starts = np.arange(0, len(sig) - win_n_samps, 2 * spacing_n_samps)
85+
window_starts = np.arange(0, len(sig) - win_len, win_spacing).astype(int)
7786
else:
7887
window_starts = window_starts_custom
79-
n_windows = len(window_starts)
8088

81-
# Randomly sample windows with replacement
82-
random_window_idx = np.random.choice(range(n_windows), size=max_iterations)
89+
windows = np.array([sig[start:start + win_len] for start in window_starts])
90+
91+
# Compute new window bounds
92+
lower_bounds, upper_bounds = _compute_bounds(window_starts, win_spacing, 0, len(sig) - win_len)
93+
94+
# Remove low variance windows to speed up runtime
95+
if var_thresh is not None:
96+
97+
thresh = np.array([np.var(sig[loc:loc + win_len]) > var_thresh for loc in window_starts])
98+
99+
windows = windows[thresh]
100+
window_starts = window_starts[thresh]
101+
lower_bounds = lower_bounds[thresh]
102+
upper_bounds = upper_bounds[thresh]
103+
104+
# Modified SWM procedure
105+
window_idxs = np.arange(len(windows)).astype(int)
106+
107+
corrs, variance = _compute_cost(sig, window_starts, win_len)
108+
mae = np.mean(np.abs(windows - windows.mean(axis=0)))
83109

84-
# Calculate initial cost
85-
costs = np.zeros(max_iterations)
86-
costs[0] = _compute_cost(sig, window_starts, win_n_samps)
110+
for _ in range(max_iterations):
87111

88-
for iter_num in range(1, max_iterations):
112+
# Randomly shuffle order of windows
113+
np.random.shuffle(window_idxs)
89114

90-
# Pick a random window position to randomly replace with a
91-
# new window to improve cross-window similarity
92-
window_idx_replace = random_window_idx[iter_num]
115+
for win_idx in window_idxs:
93116

94-
# Find a new allowed position for the window
95-
window_starts_temp = np.copy(window_starts)
96-
window_starts_temp[window_idx_replace] = _find_new_window_idx(
97-
window_starts, spacing_n_samps, len(sig) - win_n_samps)
117+
# Find a new, random window start
118+
_window_starts = window_starts.copy()
119+
_window_starts[win_idx] = np.random.choice(np.arange(lower_bounds[win_idx],
120+
upper_bounds[win_idx] + 1))
98121

99-
# Calculate the cost & the change in the cost function
100-
cost_temp = _compute_cost(sig, window_starts_temp, win_n_samps)
101-
delta_cost = cost_temp - costs[iter_num - 1]
122+
# Accept new window if correlation increases and variance decreases
123+
_corrs, _variance = _compute_cost(sig, _window_starts, win_len)
102124

103-
# Calculate the acceptance probability
104-
p_accept = np.exp(-delta_cost / float(temperature))
125+
if _corrs[win_idx].sum() > corrs[win_idx].sum() and _variance < variance:
105126

106-
# Accept update to J with a certain probability
107-
if np.random.rand() < p_accept:
127+
corrs = _corrs.copy()
128+
variance = _variance
129+
window_starts = _window_starts.copy()
130+
lower_bounds, upper_bounds = _compute_bounds(\
131+
window_starts, win_spacing, 0, len(sig) - win_len)
108132

109-
# Update costs & windows
110-
costs[iter_num] = cost_temp
111-
window_starts = window_starts_temp
133+
# Phase optimization
134+
_window_starts = window_starts.copy()
112135

113-
else:
136+
for shift in np.arange(-int(win_len/2), int(win_len/2)):
114137

115-
# Update costs
116-
costs[iter_num] = costs[iter_num - 1]
138+
_starts = _window_starts + shift
117139

118-
# Calculate average window
119-
avg_window = np.zeros(win_n_samps)
120-
for w_ind in range(n_windows):
121-
avg_window = avg_window + sig[window_starts[w_ind]:window_starts[w_ind] + win_n_samps]
122-
avg_window = avg_window / float(n_windows)
140+
# Skip windows shifts that are out-of-bounds
141+
if (_starts[0] < 0) or (_starts[-1] > len(sig) - win_len):
142+
continue
123143

124-
return avg_window, window_starts, costs
144+
_windows = np.array([sig[start:start + win_len] for start in _starts])
145+
146+
_mae = np.mean(np.abs(_windows - _windows.mean(axis=0)))
147+
148+
if _mae < mae:
149+
window_starts = _starts.copy()
150+
windows = _windows.copy()
151+
mae = _mae
152+
153+
lower_bounds, upper_bounds = _compute_bounds(\
154+
window_starts, win_spacing, 0, len(sig) - win_len)
155+
156+
return windows, window_starts
125157

126158

127159
def _compute_cost(sig, window_starts, win_n_samps):
128-
"""Compute the cost, which is proportional to the difference between pairs of windows."""
160+
"""Compute the cost, as correlation coefficients and variance across windows.
161+
162+
Parameters
163+
----------
164+
sig : 1d array
165+
Time series.
166+
window_starts : list of int
167+
The list of window start definitions.
168+
win_n_samps : int
169+
The length of each window, in samples.
129170
130-
# Get all windows and z-score them
131-
n_windows = len(window_starts)
132-
windows = np.zeros((n_windows, win_n_samps))
171+
Returns
172+
-------
173+
corrs: 2d array
174+
Window correlation matrix.
175+
variance: float
176+
Sum of the variance across windows.
177+
"""
133178

134-
for ind, window in enumerate(window_starts):
135-
temp = sig[window:window_starts[ind] + win_n_samps]
136-
windows[ind] = (temp - np.mean(temp)) / np.std(temp)
179+
windows = np.array([sig[start:start + win_n_samps] for start in window_starts])
137180

138-
# Calculate distances for all pairs of windows
139-
dists = []
140-
for ind1 in range(n_windows):
141-
for ind2 in range(ind1 + 1, n_windows):
142-
window_diff = windows[ind1] - windows[ind2]
143-
dist_temp = np.sum(window_diff**2) / float(win_n_samps)
144-
dists.append(dist_temp)
181+
corrs = np.corrcoef(windows)
145182

146-
# Calculate cost function, which is the average difference, roughly
147-
cost = np.sum(dists) / float(2 * (n_windows - 1))
183+
variance = windows.var(axis=1).sum()
148184

149-
return cost
185+
return corrs, variance
150186

151187

152-
def _find_new_window_idx(window_starts, spacing_n_samps, n_samp, tries_limit=1000):
153-
"""Find a new sample for the starting window."""
188+
def _compute_bounds(window_starts, win_spacing, start, end):
189+
"""Compute bounds on a new window.
154190
155-
for n_try in range(tries_limit):
191+
Parameters
192+
----------
193+
window_starts : list of int
194+
The list of window start definitions.
195+
win_spacing : float
196+
Minimum spacing between windows, in seconds.
197+
start, end : int
198+
Start and end edges for the possible window.
156199
157-
# Generate a random sample & check how close it is to other window starts
158-
new_samp = np.random.randint(n_samp)
159-
dists = np.abs(window_starts - new_samp)
200+
Returns
201+
-------
202+
lower_bounds, upper_bounds : 1d array
203+
Computed lower and upper bounds for the window position.
204+
"""
160205

161-
if np.min(dists) > spacing_n_samps:
162-
break
206+
lower_bounds = window_starts[:-1] + win_spacing
207+
lower_bounds = np.insert(lower_bounds, 0, start)
163208

164-
else:
165-
raise RuntimeError('SWM algorithm has difficulty finding a new window. \
166-
Try increasing the spacing parameter.')
209+
upper_bounds = window_starts[1:] - win_spacing
210+
upper_bounds = np.insert(upper_bounds, len(upper_bounds), end)
167211

168-
return new_samp
212+
return lower_bounds, upper_bounds

0 commit comments

Comments
 (0)