Skip to content

Commit 15696d8

Browse files
committed
Merge branch 'release/2.10.0'
2 parents 11cd0ed + d0fb153 commit 15696d8

32 files changed

+2315
-121
lines changed

brainbox/behavior/training.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def compute_performance_easy(trials):
408408
return np.sum(trials['feedbackType'][easy_trials] == 1) / easy_trials.shape[0]
409409

410410

411-
def compute_performance(trials, signed_contrast=None, block=None):
411+
def compute_performance(trials, signed_contrast=None, block=None, prob_right=False):
412412
"""
413413
Compute performance on all trials at each contrast level from trials object
414414
@@ -429,12 +429,16 @@ def compute_performance(trials, signed_contrast=None, block=None):
429429
return np.nan * np.zeros(3)
430430

431431
contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
432-
rightward = trials.choice == -1
433-
# Calculate the proportion rightward for each contrast type
434-
prob_choose_right = np.vectorize(lambda x: np.mean(rightward[(x == signed_contrast) &
435-
block_idx]))(contrasts)
436432

437-
return prob_choose_right, contrasts, n_contrasts
433+
if not prob_right:
434+
correct = trials.feedbackType == 1
435+
performance = np.vectorize(lambda x: np.mean(correct[(x == signed_contrast) & block_idx]))(contrasts)
436+
else:
437+
rightward = trials.choice == -1
438+
# Calculate the proportion rightward for each contrast type
439+
performance = np.vectorize(lambda x: np.mean(rightward[(x == signed_contrast) & block_idx]))(contrasts)
440+
441+
return performance, contrasts, n_contrasts
438442

439443

440444
def compute_n_trials(trials):
@@ -472,7 +476,8 @@ def compute_psychometric(trials, signed_contrast=None, block=None):
472476
if not np.any(block_idx):
473477
return np.nan * np.zeros(4)
474478

475-
prob_choose_right, contrasts, n_contrasts = compute_performance(trials, signed_contrast=signed_contrast, block=block)
479+
prob_choose_right, contrasts, n_contrasts = compute_performance(trials, signed_contrast=signed_contrast, block=block,
480+
prob_right=True)
476481

477482
psych, _ = psy.mle_fit_psycho(
478483
np.vstack([contrasts, n_contrasts, prob_choose_right]),
@@ -584,15 +589,15 @@ def plot_psychometric(trials, ax=None, title=None, **kwargs):
584589
signed_contrast = get_signed_contrast(trials)
585590
contrasts_fit = np.arange(-100, 100)
586591

587-
prob_right_50, contrasts_50, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.5)
592+
prob_right_50, contrasts_50, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.5, prob_right=True)
588593
pars_50 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.5)
589594
prob_right_fit_50 = psy.erf_psycho_2gammas(pars_50, contrasts_fit)
590595

591-
prob_right_20, contrasts_20, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.2)
596+
prob_right_20, contrasts_20, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.2, prob_right=True)
592597
pars_20 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.2)
593598
prob_right_fit_20 = psy.erf_psycho_2gammas(pars_20, contrasts_fit)
594599

595-
prob_right_80, contrasts_80, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.8)
600+
prob_right_80, contrasts_80, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.8, prob_right=True)
596601
pars_80 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.8)
597602
prob_right_fit_80 = psy.erf_psycho_2gammas(pars_80, contrasts_fit)
598603

brainbox/ephys_plots.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ def image_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords=None, chn_inds=None,
2525

2626
ylabel = 'Channel index' if chn_coords is None else 'Distance from probe tip (um)'
2727
title = title or 'LFP Power Spectrum'
28-
chn_inds = chn_inds or np.arange(lfp_power.shape[1])
28+
2929
y = np.arange(lfp_power.shape[1]) if chn_coords is None else chn_coords[:, 1]
30+
chn_inds = np.arange(lfp_power.shape[1]) if chn_inds is None else chn_inds
3031

3132
freq_idx = np.where((lfp_freq >= freq_range[0]) & (lfp_freq < freq_range[1]))[0]
3233
freqs = lfp_freq[freq_idx]
@@ -80,7 +81,7 @@ def image_rms_plot(rms_amps, rms_times, chn_coords=None, chn_inds=None, avg_acro
8081

8182
ylabel = 'Channel index' if chn_coords is None else 'Distance from probe tip (um)'
8283
title = title or f'{band} RMS'
83-
chn_inds = chn_inds or np.arange(rms_amps.shape[1])
84+
chn_inds = np.arange(rms_amps.shape[1]) if chn_inds is None else chn_inds
8485
y = np.arange(rms_amps.shape[1]) if chn_coords is None else chn_coords[:, 1]
8586

8687
rms = rms_amps[:, chn_inds]

brainbox/examples/docs_load_spike_sorting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
insertions = one.alyx.rest('insertions', 'list')
2525
pid = insertions[0]['id']
26-
sl = SpikeSortingLoader(pid, one=one, atlas=ba)
26+
sl = SpikeSortingLoader(pid=pid, one=one, atlas=ba)
2727
spikes, clusters, channels = sl.load_spike_sorting()
2828
clusters_labeled = SpikeSortingLoader.merge_clusters(spikes, clusters, channels)
2929

brainbox/examples/docs_wheel_moves.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -577,9 +577,9 @@
577577
"name": "python",
578578
"nbconvert_exporter": "python",
579579
"pygments_lexer": "ipython3",
580-
"version": "3.7.7"
580+
"version": "3.9.7"
581581
}
582582
},
583583
"nbformat": 4,
584584
"nbformat_minor": 2
585-
}
585+
}

brainbox/task/closed_loop.py

Lines changed: 63 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
def responsive_units(spike_times, spike_clusters, event_times,
14-
pre_time=[0.5, 0], post_time=[0, 0.5], alpha=0.05):
14+
pre_time=[0.5, 0], post_time=[0, 0.5], alpha=0.05, use_fr=False):
1515
"""
1616
Determine responsive neurons by doing a Wilcoxon Signed-Rank test between a baseline period
1717
before a certain task event (e.g. stimulus onset) and a period after the task event.
@@ -31,6 +31,8 @@ def responsive_units(spike_times, spike_clusters, event_times,
3131
time (in seconds) to follow the event times
3232
alpha : float
3333
alpha to use for statistical significance
34+
use_fr : bool
35+
whether to use the firing rate instead of total spike count
3436
3537
Returns
3638
-------
@@ -51,18 +53,12 @@ def responsive_units(spike_times, spike_clusters, event_times,
5153
times = np.column_stack(((event_times + post_time[0]), (event_times + post_time[1])))
5254
spike_counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times)
5355

54-
# Do statistics
55-
p_values = np.empty(spike_counts.shape[0])
56-
stats = np.empty(spike_counts.shape[0])
57-
for i in range(spike_counts.shape[0]):
58-
if np.sum(baseline_counts[i, :] - spike_counts[i, :]) == 0:
59-
p_values[i] = 1
60-
stats[i] = 0
61-
else:
62-
stats[i], p_values[i] = wilcoxon(baseline_counts[i, :], spike_counts[i, :])
56+
if use_fr:
57+
baseline_counts = baseline_counts / (pre_time[0] - pre_time[1])
58+
spike_counts = spike_counts / (post_time[1] - post_time[0])
6359

64-
# Perform FDR correction for multiple testing
65-
sig_units, p_values, _, _ = multipletests(p_values, alpha, method='fdr_bh')
60+
# Do statistics
61+
sig_units, stats, p_values = compute_comparison_statistics(baseline_counts, spike_counts, test='signrank', alpha=alpha)
6662
significant_units = cluster_ids[sig_units]
6763

6864
return significant_units, stats, p_values, cluster_ids
@@ -125,27 +121,66 @@ def differentiate_units(spike_times, spike_clusters, event_times, event_groups,
125121
counts_2, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times_2)
126122

127123
# Do statistics
128-
p_values = np.empty(len(cluster_ids))
129-
stats = np.empty(len(cluster_ids))
130-
for i in range(len(cluster_ids)):
131-
if (np.sum(counts_1[i, :]) == 0) and (np.sum(counts_2[i, :]) == 0):
132-
p_values[i] = 1
133-
stats[i] = 0
124+
sig_units, stats, p_values = compute_comparison_statistics(counts_1, counts_2, test=test, alpha=alpha)
125+
significant_units = cluster_ids[sig_units]
126+
127+
return significant_units, stats, p_values, cluster_ids
128+
129+
130+
def compute_comparison_statistics(value1, value2, test='ranksums', alpha=0.05):
131+
"""
132+
Compute statistical test between two arrays
133+
134+
Parameters
135+
----------
136+
value1 : 1D array
137+
first array of values to compare
138+
value2 : 1D array
139+
second array of values to compare
140+
test : string
141+
which statistical test to use, options are:
142+
'ranksums' Wilcoxon Rank Sums test
143+
'signrank' Wilcoxon Signed Rank test (for paired observations)
144+
'ttest' independent samples t-test
145+
'paired_ttest' paired t-test
146+
alpha : float
147+
alpha to use for statistical significance
148+
149+
Returns
150+
-------
151+
significant_units : 1D array
152+
an array with the indices of values that are significatly modulated
153+
stats : 1D array
154+
the statistic of the test that was performed
155+
p_values : 1D array
156+
the p-values of all the values
157+
"""
158+
159+
p_values = np.empty(len(value1))
160+
stats = np.empty(len(value1))
161+
for i in range(len(value1)):
162+
if test == 'signrank':
163+
if np.sum(value1[i, :] - value2[i, :]) == 0:
164+
p_values[i] = 1
165+
stats[i] = 0
166+
else:
167+
stats[i], p_values[i] = wilcoxon(value1[i, :], value2[i, :])
134168
else:
135-
if test == 'ranksums':
136-
stats[i], p_values[i] = ranksums(counts_1[i, :], counts_2[i, :])
137-
elif test == 'signrank':
138-
stats[i], p_values[i] = wilcoxon(counts_1[i, :], counts_2[i, :])
139-
elif test == 'ttest':
140-
stats[i], p_values[i] = ttest_ind(counts_1[i, :], counts_2[i, :])
141-
elif test == 'paired_ttest':
142-
stats[i], p_values[i] = ttest_rel(counts_1[i, :], counts_2[i, :])
169+
if (np.sum(value1[i, :]) == 0) and (np.sum(value2[i, :]) == 0):
170+
p_values[i] = 1
171+
stats[i] = 0
172+
else:
173+
if test == 'ranksums':
174+
stats[i], p_values[i] = ranksums(value1[i, :], value2[i, :])
175+
elif test == 'ttest':
176+
stats[i], p_values[i] = ttest_ind(value1[i, :], value2[i, :])
177+
elif test == 'paired_ttest':
178+
stats[i], p_values[i] = ttest_rel(value1[i, :], value2[i, :])
143179

144180
# Perform FDR correction for multiple testing
145181
sig_units, p_values, _, _ = multipletests(p_values, alpha, method='fdr_bh')
146-
significant_units = cluster_ids[sig_units]
147182

148-
return significant_units, stats, p_values, cluster_ids
183+
return sig_units, stats, p_values
149184

150185

151186
def roc_single_event(spike_times, spike_clusters, event_times,

brainbox/task/passive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ def get_on_off_times_and_positions(rf_map):
1717
Returns
1818
-------
1919
rf_map_times: time of each receptive field map frame np.array(len(stim_frames)
20-
rf_map_pos: unique position of each pixel on scree np.array(len(x_pos), len(y_pos))
20+
rf_map_pos: unique position of each pixel on screen np.array(len(x_pos), len(y_pos))
2121
rf_stim_frames: for each pixel on screen stores array of stimulus frames where stim onset
22-
occured. For both white squares 'on' and black squares 'off'
22+
occurred. For both white squares 'on' and black squares 'off'
2323
2424
"""
2525

brainbox/task/trials.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def find_trial_ids(trials, side='all', choice='all', order='trial num', sort='id
2121
"""
2222
if event:
2323
idx = ~np.isnan(trials[event])
24+
nan_idx = np.where(idx)[0]
2425
else:
2526
idx = np.ones_like(trials['feedbackType'], dtype=bool)
2627

@@ -118,6 +119,9 @@ def _order_by(_trials, order):
118119
if side == 'right' and choice == 'incorrect':
119120
trial_id = _order_by(incor_r, order)
120121

122+
if event:
123+
trial_id = nan_idx[trial_id]
124+
121125
return trial_id, dividers
122126

123127

@@ -146,8 +150,9 @@ def get_event_aligned_raster(times, events, tbin=0.02, values=None, epoch=[-0.4,
146150
nbin = t.shape[0]
147151

148152
# remove nan trials
149-
events = events[~np.isnan(events)]
150-
intervals = np.c_[events + epoch[0], events + epoch[1]]
153+
non_nan_events = events[~np.isnan(events)]
154+
nan_idx = np.where(~np.isnan(events))
155+
intervals = np.c_[non_nan_events + epoch[0], non_nan_events + epoch[1]]
151156

152157
# Remove any trials that are later than the last value in bin_times
153158
out_intervals = intervals[:, 1] > bin_times[-1]
@@ -170,7 +175,11 @@ def get_event_aligned_raster(times, events, tbin=0.02, values=None, epoch=[-0.4,
170175
event_raster.shape[1]), np.nan)]
171176
assert(event_raster.shape[0] == intervals.shape[0])
172177

173-
return event_raster, t
178+
# Reindex if we have removed any nan values
179+
all_event_raster = np.full((events.shape[0], event_raster.shape[1]), np.nan)
180+
all_event_raster[nan_idx, :] = event_raster
181+
182+
return all_event_raster, t
174183

175184

176185
def get_psth(raster, trial_ids=None):
@@ -298,7 +307,7 @@ def filter_left_right(trials, event_raster, event, contrast, order='trial num'):
298307
return raster, psth
299308

300309

301-
def filter_trials(trials, event_raster, event, contrast, order='trial num', sort='choice'):
310+
def filter_trials(trials, event_raster, event, contrast=(1, 0.5, 0.25, 0.125, 0.0625, 0), order='trial num', sort='choice'):
302311
"""
303312
Wrapper to get out psth and raster for trial choice
304313
:param trials: trials object

brainbox/tests/test_trials.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,12 @@ def test_get_event_aligned_rasters(self):
159159
# Test when nans have trials - these are removed from the raster
160160
use_trials[10:12] = np.nan
161161
raster, t = get_event_aligned_raster(spikes, use_trials)
162-
assert(raster.shape[0] == len(use_trials) - 2)
163-
assert(all(np.isnan(raster[0:5, :]).ravel()))
164-
assert(all(np.isnan(raster[-5:, :]).ravel()))
162+
assert (raster.shape[0] == len(use_trials))
163+
assert (all(np.isnan(raster[10:12, :]).ravel()))
164+
assert (all(~np.isnan(raster[12:15, :]).ravel()))
165165

166166
use_trials[0:2] = np.nan
167167
raster, t = get_event_aligned_raster(spikes, use_trials)
168-
assert(raster.shape[0] == len(use_trials) - 4)
169-
assert(all(np.isnan(raster[0:3, :]).ravel()))
170-
assert(all(~np.isnan(raster[4, :]).ravel()))
171-
assert(all(np.isnan(raster[-5:, :]).ravel()))
168+
assert (raster.shape[0] == len(use_trials))
169+
assert (all(np.isnan(raster[0:2, :]).ravel()))
170+
assert (all(np.isnan(raster[-5:, :]).ravel()))

0 commit comments

Comments
 (0)