Skip to content

Commit 1b58231

Browse files
committed
Merge branch 'develop' into bb_passive_fix
2 parents 3dc3b30 + b3a86bf commit 1b58231

File tree

13 files changed

+100
-77
lines changed

13 files changed

+100
-77
lines changed

brainbox/io/one.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ def load_channel_locations(eid, one=None, probe=None, aligned=False):
7070
counts = [0]
7171
else:
7272
tracing = [(insertions.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}).
73-
get('tracing_exists', False))]
73+
get('tracing_exists', False))]
7474
resolved = [(insertions.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}).
75-
get('alignment_resolved', False))]
75+
get('alignment_resolved', False))]
7676
counts = [(insertions.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}).
77-
get('alignment_count', 0))]
77+
get('alignment_count', 0))]
7878
probe_id = [insertions['id']]
7979
# No specific probe specified, load any that is available
8080
# Need to catch for the case where we have two of the same probe insertions
@@ -420,7 +420,7 @@ def load_wheel_reaction_times(eid, one=None):
420420

421421

422422
def load_trials_df(eid, one=None, maxlen=None, t_before=0., t_after=0., ret_wheel=False,
423-
ret_abswheel=False, wheel_binsize=0.02, addtl_types=()):
423+
ret_abswheel=False, ext_DLC=False, wheel_binsize=0.02, addtl_types=[]):
424424
"""
425425
TODO Test this with new ONE
426426
Generate a pandas dataframe of per-trial timing information about a given session.
@@ -451,6 +451,8 @@ def load_trials_df(eid, one=None, maxlen=None, t_before=0., t_after=0., ret_whee
451451
Whether to return the time-resampled wheel velocity trace, by default False
452452
ret_abswheel : bool, optional
453453
Whether to return the time-resampled absolute wheel velocity trace, by default False
454+
ext_DLC : bool, optional
455+
Whether to extract DLC data, by default False
454456
wheel_binsize : float, optional
455457
Time bins to resample wheel velocity to, by default 0.02
456458
addtl_types : list, optional

brainbox/modeling/linear.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,38 @@ def __init__(self, design_matrix, spk_times, spk_clu,
5959
def _fit(self, dm, binned, cells=None):
6060
"""
6161
Fitting primitive that brainbox.NeuralModel.fit method will call
62+
63+
Parameters
64+
----------
65+
dm : np.ndarray
66+
Design matrix to use for fitting
67+
binned : np.ndarray
68+
Array of binned spike times. Must share first dimension with dm
69+
cells : iterable with .shape attribute, optional
70+
List of cells which are being fit. Use to generate index for output
71+
coefficients and intercepts, must share shape with second dimension
72+
of binned. When None will default to a list of all cells in the model object,
73+
by default None
74+
75+
Returns
76+
-------
77+
coefs, pd.Series
78+
Series containing fit coefficients for cells
79+
intercepts, pd.Series
80+
Series containing intercepts for fits.
6281
"""
6382
if cells is None:
6483
cells = self.clu_ids.flatten()
84+
if cells.shape[0] != binned.shape[1]:
85+
raise ValueError('Length of cells does not match shape of binned')
86+
6587
coefs = pd.Series(index=cells, name='coefficients', dtype=object)
6688
intercepts = pd.Series(index=cells, name='intercepts')
6789

6890
lm = self.estimator.fit(dm, binned)
6991
weight, intercept = lm.coef_, lm.intercept_
7092
for cell in cells:
71-
cell_idx = np.argwhere(self.clu_ids == cell)[0, 0]
93+
cell_idx = np.argwhere(cells == cell)[0, 0]
7294
coefs.at[cell] = weight[cell_idx, :]
7395
intercepts.at[cell] = intercept[cell_idx]
7496
return coefs, intercepts
@@ -84,7 +106,6 @@ def score(self):
84106
"""
85107
if not hasattr(self, 'coefs'):
86108
raise AttributeError('Model has not been fit yet.')
87-
88109
testmask = np.isin(self.design.trlabels, self.testinds).flatten()
89110
dm, binned = self.design[testmask, :], self.binnedspikes[testmask]
90111

brainbox/modeling/neural_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(self, design_matrix, spk_times, spk_clu,
103103
self.design = design_matrix
104104
self.spikes = spks
105105
self.clu = clu
106-
self.clu_ids = np.argwhere(np.sum(trialspiking, axis=0) > mintrials)
106+
self.clu_ids = np.argwhere(np.sum(trialspiking, axis=0) > mintrials).flatten()
107107
self.traininds = traininds
108108
self.testinds = testinds
109109
self.stepwise = stepwise

brainbox/modeling/poisson.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,19 @@ def _fit(self, dm, binned, cells=None, noncovwarn=False):
3838
alpha : float
3939
Regularization strength, applied as multiplicative constant on ridge regularization.
4040
cells : list
41-
List of cells which should be fit. If None is passed, will default to fitting all cells
42-
in clu_ids
41+
List of cells labels for columns in binned. Will default to all cells in model if None
42+
is passed. Must be of the same length as columns in binned. By default None.
4343
"""
4444
if cells is None:
4545
cells = self.clu_ids.flatten()
46+
if cells.shape[0] != binned.shape[1]:
47+
raise ValueError('Length of cells does not match shape of binned')
48+
4649
coefs = pd.Series(index=cells, name='coefficients', dtype=object)
4750
intercepts = pd.Series(index=cells, name='intercepts')
4851
nonconverged = []
4952
for cell in tqdm(cells, 'Fitting units:', leave=False):
50-
cell_idx = np.argwhere(self.clu_ids == cell)[0, 0]
53+
cell_idx = np.argwhere(cells == cell)[0, 0]
5154
cellbinned = binned[:, cell_idx]
5255
with catch_warnings(record=True) as w:
5356
fitobj = PoissonRegressor(alpha=self.alpha,

brainbox/modeling/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def fit(self, progress=False):
101101
new_feature_idx, nf_score = self._get_best_new_feature(current_mask, cells)
102102
for cell in cells:
103103
maskdf.at[cell, self.features[new_feature_idx.loc[cell]]] = True
104-
seqdf.loc[cell, i] = self.features[new_feature_idx]
104+
seqdf.loc[cell, i] = self.features[new_feature_idx.loc[cell]]
105105
scoredf.loc[cell, i] = nf_score.loc[cell]
106106
self.support_ = maskdf
107107
self.sequences_ = seqdf
@@ -110,7 +110,8 @@ def fit(self, progress=False):
110110
def _get_best_new_feature(self, mask, cells):
111111
mask = np.array(mask)
112112
candidate_features = np.flatnonzero(~mask)
113-
my = self.model.binnedspikes[self.train]
113+
cell_idxs = np.argwhere(np.isin(self.model.clu_ids, cells)).flatten()
114+
my = self.model.binnedspikes[np.ix_(self.train, cell_idxs)]
114115
scores = pd.DataFrame(index=cells, columns=candidate_features, dtype=float)
115116
for feature_idx in candidate_features:
116117
candidate_mask = mask.copy()

examples/one/tutorial_script.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

ibllib/io/extractors/ephys_fpga.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,37 @@
6060
}
6161

6262

63+
def data_for_keys(keys, data):
64+
"""Check keys exist in 'data' dict and contain values other than None"""
65+
return data is not None and all(k in data and data.get(k, None) is not None for k in keys)
66+
67+
6368
def get_ibl_sync_map(ef, version):
6469
"""
6570
Gets default channel map for the version/binary file type combination
6671
:param ef: ibllib.io.spikeglx.glob_ephys_file dictionary with field 'ap' or 'nidq'
6772
:return: channel map dictionary
6873
"""
74+
# Determine default channel map
6975
if version == '3A':
7076
default_chmap = CHMAPS['3A']['ap']
7177
elif version == '3B':
7278
if ef.get('nidq', None):
7379
default_chmap = CHMAPS['3B']['nidq']
7480
elif ef.get('ap', None):
7581
default_chmap = CHMAPS['3B']['ap']
76-
return spikeglx.get_sync_map(ef['path']) or default_chmap
82+
# Try to load channel map from file
83+
chmap = spikeglx.get_sync_map(ef['path'])
84+
# If chmap provided but not with all keys, fill up with default values
85+
if not chmap:
86+
return default_chmap
87+
else:
88+
if data_for_keys(default_chmap.keys(), chmap):
89+
return chmap
90+
else:
91+
_logger.warning("Keys missing from provided channel map, "
92+
"setting missing keys from default channel map")
93+
return {**default_chmap, **chmap}
7794

7895

7996
def _sync_to_alf(raw_ephys_apfile, output_path=None, save=False, parts=''):
@@ -242,7 +259,8 @@ def _assign_events_audio(audio_t, audio_polarities, return_indices=False):
242259
# make sure that there are no 2 consecutive fall or consecutive rise events
243260
assert(np.all(np.abs(np.diff(audio_polarities)) == 2))
244261
# take only even time differences: ie. from rising to falling fronts
245-
dt = np.diff(audio_t)[::2]
262+
i0 = 0 if audio_polarities[0] == 1 else 1
263+
dt = np.diff(audio_t)[i0::2]
246264
# detect ready tone by length below 110 ms
247265
i_ready_tone_in = np.r_[np.where(dt <= 0.11)[0] * 2]
248266
t_ready_tone_in = audio_t[i_ready_tone_in]

ibllib/io/extractors/ephys_passive.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,19 @@ def _get_spacer_times(spacer_template, jitter, ttl_signal, t_quiet):
125125
# adjust indices for
126126
# - `np.where` call above
127127
# - length of spacer_model
128-
idxs_spacer_middle += 2 - int((np.floor(len(spacer_model) / 2)))
128+
spacer_around = int((np.floor(len(spacer_model) / 2)))
129+
idxs_spacer_middle += 2 - spacer_around
130+
131+
# for each spacer make sure the times are monotonically increasing before
132+
# and monotonically decreasing afterwards
133+
is_valid = np.zeros((idxs_spacer_middle.size), dtype=bool)
134+
for i, t in enumerate(idxs_spacer_middle):
135+
before = all(np.diff(dttl[t - spacer_around:t]) > 0)
136+
after = all(np.diff(dttl[t + 1:t + 1 + spacer_around]) < 0)
137+
is_valid[i] = np.bitwise_and(before, after)
138+
139+
idxs_spacer_middle = idxs_spacer_middle[is_valid]
140+
129141
# pull out spacer times (middle)
130142
ts_spacer_middle = ttl_signal[idxs_spacer_middle]
131143
# put beginning/end of spacer times into an array

ibllib/pipes/ephys_preprocessing.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class SpikeSorting(tasks.Task):
9393
"Documents/PYTHON/iblscripts/deploy/serverpc/kilosort2/run_pykilosort.sh"
9494
)
9595
SPIKE_SORTER_NAME = 'pykilosort'
96-
PYKILOSORT_REPO = '~/Documents/PYTHON/SPIKE_SORTING/pykilosort'
96+
PYKILOSORT_REPO = Path.home().joinpath('Documents/PYTHON/SPIKE_SORTING/pykilosort')
9797

9898
@staticmethod
9999
def _sample2v(ap_file):
@@ -113,7 +113,7 @@ def _fetch_pykilosort_version(repo_path):
113113
version = line.split('=')[-1].strip().replace('"', '').replace("'", '')
114114
except Exception:
115115
pass
116-
return version
116+
return f"pykilosort_{version}"
117117

118118
@staticmethod
119119
def _fetch_ks2_commit_hash(repo_path):
@@ -137,7 +137,7 @@ def _run_pykilosort(self, ap_file):
137137
session_path/spike_sorters/{self.SPIKE_SORTER_NAME}/probeXX folder
138138
:return: path of the folder containing ks2 spike sorting output
139139
"""
140-
140+
self.version = self._fetch_pykilosort_version(self.PYKILOSORT_REPO)
141141
label = ap_file.parts[-2] # this is usually the probe name
142142
if ap_file.parent.joinpath(f"spike_sorting_{self.SPIKE_SORTER_NAME}.log").exists():
143143
_logger.info(f"Already ran: spike_sorting_{self.SPIKE_SORTER_NAME}.log"
@@ -194,7 +194,6 @@ def _run_pykilosort(self, ap_file):
194194

195195
shutil.copytree(temp_dir.joinpath('output'), sorter_dir, dirs_exist_ok=True)
196196
shutil.rmtree(temp_dir, ignore_errors=True)
197-
self.version = self._fetch_ks2_commit_hash(self.PYKILOSORT_REPO)
198197
return sorter_dir
199198

200199
def _run(self, overwrite=False):
@@ -225,7 +224,7 @@ def _run(self, overwrite=False):
225224
logfile = ks2_dir.joinpath(f"spike_sorting_{self.SPIKE_SORTER_NAME}.log")
226225
if logfile.exists():
227226
shutil.copyfile(logfile, probe_out_path.joinpath(
228-
f"ibl_log.info_{self.SPIKE_SORTER_NAME}.log"))
227+
f"_ibl_log.info_{self.SPIKE_SORTER_NAME}.log"))
229228
out, _ = spikes.sync_spike_sorting(ap_file=ap_file, out_path=probe_out_path)
230229
out_files.extend(out)
231230
# convert ks2_output into tar file and also register

ibllib/qc/camera.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,9 @@ def get_active_wheel_period(wheel, duration_range=(3., 20.), display=False):
274274
edges = np.c_[on, off]
275275
indices, _ = np.where(np.logical_and(
276276
np.diff(edges) > duration_range[0], np.diff(edges) < duration_range[1]))
277+
if len(indices) == 0:
278+
_log.warning('No period of wheel movement found for motion alignment.')
279+
return None
277280
# Pick movement somewhere in the middle
278281
i = indices[int(indices.size / 2)]
279282
if display:
@@ -304,7 +307,7 @@ def ensure_required_data(self):
304307
# Assert 3A probe model; if so download all probe data
305308
det = self.one.get_details(self.eid, full=True)
306309
probe_model = next(x['model'] for x in det['probe_insertion'])
307-
assert probe_model == '3A', 'raw ephys data not missing'
310+
assert probe_model == '3A', 'raw ephys data missing'
308311
collections += ('raw_ephys_data/probe00', 'raw_ephys_data/probe01')
309312
assert_unique = False
310313
for dstype in dtypes:

0 commit comments

Comments
 (0)