Skip to content

Commit df599fb

Browse files
author
Thinh Nguyen
committed
triggering kilosort analysis for open-ephys
1 parent ddc3b94 commit df599fb

File tree

4 files changed

+96
-12
lines changed

4 files changed

+96
-12
lines changed

element_array_ephys/ephys_no_curation.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -607,12 +607,23 @@ def make(self, key):
607607
probe_type = (ProbeInsertion * probe.Probe & key).fetch1('probe_type')
608608
params['probe_type'] = {'neuropixels 1.0 - 3A': '3A',
609609
'neuropixels 1.0 - 3B': 'NP1',
610+
'neuropixels UHD': 'NP1100',
610611
'neuropixels 2.0 - SS': 'NP21',
611612
'neuropixels 2.0 - MS': 'NP24'}[probe_type]
612613
params['sample_rate'] = oe_probe.ap_meta['sample_rate']
613614
params['num_channels'] = oe_probe.ap_meta['num_channels']
614615
params['uVPerBit'] = oe_probe.ap_meta['channels_gains'][0]
615616

617+
# add additional electrodes information into `params`
618+
electrode_config_key = (probe.ElectrodeConfig * EphysRecording & key).fetch1('KEY')
619+
params['channel_ind'], params['x_coords'], params['y_coords'], params['shank_ind'] = (
620+
probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode
621+
& electrode_config_key).fetch('electrode', 'x_coord', 'y_coord', 'shank')
622+
params['connected'] = np.array([int(v == 1)
623+
for c, v in oe_probe.channel_status.items()
624+
if c in params['channel_ind']])
625+
626+
# run kilosort
616627
run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline(
617628
npx_input_dir=oe_probe.recording_info['recording_files'][0],
618629
ks_output_dir=kilosort_dir,
@@ -873,11 +884,7 @@ def get_neuropixels_channel2electrode_map(ephys_recording_key, acq_software):
873884
for recorded_site, (shank, shank_col, shank_row, _) in enumerate(
874885
spikeglx_meta.shankmap['data'])}
875886
elif acq_software == 'Open Ephys':
876-
sess_dir = find_full_path(get_ephys_root_data_dir(),
877-
get_session_directory(ephys_recording_key))
878-
openephys_dataset = openephys.OpenEphys(sess_dir)
879-
probe_serial_number = (ProbeInsertion & ephys_recording_key).fetch1('probe')
880-
probe_dataset = openephys_dataset.probes[probe_serial_number]
887+
probe_dataset = get_openephys_probe_data(ephys_recording_key)
881888

882889
electrode_query = (probe.ProbeType.Electrode
883890
* probe.ElectrodeConfig.Electrode

element_array_ephys/readers/kilosort_triggering.py

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import re
66
import inspect
77
import os
8+
import scipy.io
9+
import numpy as np
810
from datetime import datetime
911

1012
from ..import dict_to_uuid
@@ -14,8 +16,9 @@
1416
try:
1517
from ecephys_spike_sorting.scripts.create_input_json import createInputJson
1618
from ecephys_spike_sorting.scripts.helpers import SpikeGLX_utils, log_from_json
19+
from ecephys_spike_sorting.modules.kilosort_helper.__main__ import get_noise_channels
1720
except Exception as e:
18-
print(f'Error in loading "ecephys_spike_sorting" - {str(e)}')
21+
print(f'Error in loading "ecephys_spike_sorting" package - {str(e)}')
1922

2023

2124
class SGLXKilosortPipeline:
@@ -66,6 +69,7 @@ def __init__(self, npx_input_dir: str, ks_output_dir: str,
6669
self._json_directory.mkdir(parents=True, exist_ok=True)
6770

6871
self._CatGT_finished = False
72+
self.ks_input_params = None
6973
self._modules_input_hash = None
7074
self._modules_input_hash_fp = None
7175

@@ -147,7 +151,7 @@ def generate_modules_input_json(self):
147151
if k in self._input_json_args:
148152
params[k] = value
149153

150-
input_params = createInputJson(
154+
self.ks_input_params = createInputJson(
151155
self._module_input_json.as_posix(),
152156
KS2ver=self._KS2ver,
153157
npx_directory=self._npx_input_dir.as_posix(),
@@ -156,6 +160,7 @@ def generate_modules_input_json(self):
156160
input_meta_path=input_meta_fullpath.as_posix(),
157161
extracted_data_directory=self._ks_output_dir.parent.as_posix(),
158162
kilosort_output_directory=self._ks_output_dir.as_posix(),
163+
kilosort_output_tmp=self._ks_output_dir.as_posix(),
159164
ks_make_copy=True,
160165
noise_template_use_rf=self._params.get('noise_template_use_rf', False),
161166
c_Waves_snr_um=self._params.get('c_Waves_snr_um', 160),
@@ -164,7 +169,7 @@ def generate_modules_input_json(self):
164169
**params
165170
)
166171

167-
self._modules_input_hash = dict_to_uuid(input_params)
172+
self._modules_input_hash = dict_to_uuid(self.ks_input_params)
168173

169174
def run_modules(self):
170175
if self._run_CatGT and not self._CatGT_finished:
@@ -275,10 +280,29 @@ def __init__(self, npx_input_dir: str, ks_output_dir: str,
275280
self._json_directory = self._ks_output_dir / 'json_configs'
276281
self._json_directory.mkdir(parents=True, exist_ok=True)
277282

283+
self.ks_input_params = None
278284
self._modules_input_hash = None
279285
self._modules_input_hash_fp = None
280286

287+
def make_chanmap_file(self):
288+
continuous_file = self._npx_input_dir / 'continuous.dat'
289+
self._chanmap_filepath = self._ks_output_dir / 'chanMap.mat'
290+
291+
_write_channel_map_file(channel_ind=self._params['channel_ind'],
292+
x_coords=self._params['x_coords'],
293+
y_coords=self._params['y_coords'],
294+
shank_ind=self._params['shank_ind'],
295+
connected=self._params['connected'],
296+
probe_name=self._params['probe_type'],
297+
ap_band_file=continuous_file.as_posix(),
298+
bit_volts=self._params['uVPerBit'],
299+
sample_rate=self._params['sample_rate'],
300+
save_path=self._chanmap_filepath.as_posix(),
301+
is_0_based=True)
302+
281303
def generate_modules_input_json(self):
304+
self.make_chanmap_file()
305+
282306
self._module_input_json = self._json_directory / f'{self._npx_input_dir.name}-input.json'
283307

284308
continuous_file = self._npx_input_dir / 'continuous.dat'
@@ -291,23 +315,25 @@ def generate_modules_input_json(self):
291315
if k in self._input_json_args:
292316
params[k] = value
293317

294-
input_params = createInputJson(
318+
self.ks_input_params = createInputJson(
295319
self._module_input_json.as_posix(),
296320
KS2ver=self._KS2ver,
297321
npx_directory=self._npx_input_dir.as_posix(),
298322
spikeGLX_data=False,
299323
continuous_file=continuous_file.as_posix(),
300324
extracted_data_directory=self._ks_output_dir.parent.as_posix(),
301325
kilosort_output_directory=self._ks_output_dir.as_posix(),
326+
kilosort_output_tmp=self._ks_output_dir.as_posix(),
302327
ks_make_copy=True,
303328
noise_template_use_rf=self._params.get('noise_template_use_rf', False),
304329
c_Waves_snr_um=self._params.get('c_Waves_snr_um', 160),
305330
qm_isi_thresh=self._params.get('refPerMS', 2.0) / 1000,
306331
kilosort_repository=_get_kilosort_repository(self._KS2ver),
332+
chanMap_path=self._chanmap_filepath.as_posix(),
307333
**params
308334
)
309335

310-
self._modules_input_hash = dict_to_uuid(input_params)
336+
self._modules_input_hash = dict_to_uuid(self.ks_input_params)
311337

312338
def run_modules(self):
313339
print('---- Running Modules ----')
@@ -379,3 +405,44 @@ def _get_kilosort_repository(KS2ver):
379405
assert ks_repo.exists()
380406

381407
return ks_repo.as_posix()
408+
409+
410+
def _write_channel_map_file(*, channel_ind, x_coords, y_coords, shank_ind, connected,
411+
probe_name, ap_band_file, bit_volts, sample_rate,
412+
save_path, is_0_based=True):
413+
"""
414+
Write channel map into .mat file in 1-based indexing format (MATLAB style)
415+
"""
416+
417+
assert len(channel_ind) == len(x_coords) == len(y_coords) == len(shank_ind) == len(connected)
418+
419+
if is_0_based:
420+
channel_ind += 1
421+
shank_ind += 1
422+
423+
channel_count = len(channel_ind)
424+
chanMap0ind = np.arange(0, channel_count, dtype='float64')
425+
chanMap0ind = chanMap0ind.reshape((channel_count, 1))
426+
chanMap = chanMap0ind + 1
427+
428+
# channels to exclude
429+
mask = get_noise_channels(ap_band_file,
430+
channel_count,
431+
sample_rate,
432+
bit_volts)
433+
bad_channel_ind = np.where(mask is False)[0]
434+
connected[bad_channel_ind] = 0
435+
436+
mdict = {
437+
'chanMap': chanMap,
438+
'chanMap0ind': chanMap0ind,
439+
'connected': connected,
440+
'name': probe_name,
441+
'xcoords': x_coords,
442+
'ycoords': y_coords,
443+
'shankInd': shank_ind,
444+
'kcoords': shank_ind,
445+
'fs': sample_rate
446+
}
447+
448+
scipy.io.savemat(save_path, mdict)

element_array_ephys/readers/openephys.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def load_probe_data(self):
9696
meta = getattr(probe, continuous_type + '_meta')
9797
if not meta:
9898
meta.update(**continuous_info,
99-
channels_ids=analog_signal.channel_ids,
99+
channels_ids=[c['source_processor_index']
100+
for c in continuous_info['channels']],
100101
channels_names=analog_signal.channel_names,
101102
channels_gains=analog_signal.gains)
102103

@@ -120,7 +121,14 @@ def __init__(self, processor, probe_index=0):
120121
else:
121122
self.probe_info = processor['EDITOR']['NP_PROBE'][probe_index]
122123
self.probe_SN = self.probe_info['@probe_serial_number']
123-
self.probe_model = self.probe_info['@probe_name']
124+
self.probe_model = {
125+
"Neuropixels 1.0": "neuropixels 1.0 - 3B",
126+
"Neuropixels Ultra": "neuropixels UHD",
127+
"Neuropixels 21": "neuropixels 2.0 - SS",
128+
"Neuropixels 24": "neuropixels 2.0 - MS"}[self.probe_info['@probe_name']]
129+
130+
self.channel_status = {int(k.replace('@E', '')): int(v)
131+
for k, v in self.probe_info.pop('CHANNELSTATUS').items()}
124132

125133
self.ap_meta = {}
126134
self.lfp_meta = {}

element_array_ephys/readers/spikeglx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ def __init__(self, meta_filepath):
179179
self.probe_model = 'neuropixels 1.0 - 3A'
180180
elif 'typeImEnabled' in self.meta:
181181
self.probe_model = 'neuropixels 1.0 - 3B'
182+
elif probe_model == 1100:
183+
self.probe_model = 'neuropixels UHD'
182184
elif probe_model == 21:
183185
self.probe_model = 'neuropixels 2.0 - SS'
184186
elif probe_model == 24:

0 commit comments

Comments
 (0)