55import re
66import inspect
77import os
8+ import scipy .io
9+ import numpy as np
810from datetime import datetime
911
1012from ..import dict_to_uuid
1416try :
1517 from ecephys_spike_sorting .scripts .create_input_json import createInputJson
1618 from ecephys_spike_sorting .scripts .helpers import SpikeGLX_utils , log_from_json
19+ from ecephys_spike_sorting .modules .kilosort_helper .__main__ import get_noise_channels
1720except Exception as e :
18- print (f'Error in loading "ecephys_spike_sorting" - { str (e )} ' )
21+ print (f'Error in loading "ecephys_spike_sorting" package - { str (e )} ' )
1922
2023
2124class SGLXKilosortPipeline :
@@ -66,6 +69,7 @@ def __init__(self, npx_input_dir: str, ks_output_dir: str,
6669 self ._json_directory .mkdir (parents = True , exist_ok = True )
6770
6871 self ._CatGT_finished = False
72+ self .ks_input_params = None
6973 self ._modules_input_hash = None
7074 self ._modules_input_hash_fp = None
7175
@@ -147,7 +151,7 @@ def generate_modules_input_json(self):
147151 if k in self ._input_json_args :
148152 params [k ] = value
149153
150- input_params = createInputJson (
154+ self . ks_input_params = createInputJson (
151155 self ._module_input_json .as_posix (),
152156 KS2ver = self ._KS2ver ,
153157 npx_directory = self ._npx_input_dir .as_posix (),
@@ -156,6 +160,7 @@ def generate_modules_input_json(self):
156160 input_meta_path = input_meta_fullpath .as_posix (),
157161 extracted_data_directory = self ._ks_output_dir .parent .as_posix (),
158162 kilosort_output_directory = self ._ks_output_dir .as_posix (),
163+ kilosort_output_tmp = self ._ks_output_dir .as_posix (),
159164 ks_make_copy = True ,
160165 noise_template_use_rf = self ._params .get ('noise_template_use_rf' , False ),
161166 c_Waves_snr_um = self ._params .get ('c_Waves_snr_um' , 160 ),
@@ -164,7 +169,7 @@ def generate_modules_input_json(self):
164169 ** params
165170 )
166171
167- self ._modules_input_hash = dict_to_uuid (input_params )
172+ self ._modules_input_hash = dict_to_uuid (self . ks_input_params )
168173
169174 def run_modules (self ):
170175 if self ._run_CatGT and not self ._CatGT_finished :
@@ -275,10 +280,29 @@ def __init__(self, npx_input_dir: str, ks_output_dir: str,
275280 self ._json_directory = self ._ks_output_dir / 'json_configs'
276281 self ._json_directory .mkdir (parents = True , exist_ok = True )
277282
283+ self .ks_input_params = None
278284 self ._modules_input_hash = None
279285 self ._modules_input_hash_fp = None
280286
287+ def make_chanmap_file (self ):
288+ continuous_file = self ._npx_input_dir / 'continuous.dat'
289+ self ._chanmap_filepath = self ._ks_output_dir / 'chanMap.mat'
290+
291+ _write_channel_map_file (channel_ind = self ._params ['channel_ind' ],
292+ x_coords = self ._params ['x_coords' ],
293+ y_coords = self ._params ['y_coords' ],
294+ shank_ind = self ._params ['shank_ind' ],
295+ connected = self ._params ['connected' ],
296+ probe_name = self ._params ['probe_type' ],
297+ ap_band_file = continuous_file .as_posix (),
298+ bit_volts = self ._params ['uVPerBit' ],
299+ sample_rate = self ._params ['sample_rate' ],
300+ save_path = self ._chanmap_filepath .as_posix (),
301+ is_0_based = True )
302+
281303 def generate_modules_input_json (self ):
304+ self .make_chanmap_file ()
305+
282306 self ._module_input_json = self ._json_directory / f'{ self ._npx_input_dir .name } -input.json'
283307
284308 continuous_file = self ._npx_input_dir / 'continuous.dat'
@@ -291,23 +315,25 @@ def generate_modules_input_json(self):
291315 if k in self ._input_json_args :
292316 params [k ] = value
293317
294- input_params = createInputJson (
318+ self . ks_input_params = createInputJson (
295319 self ._module_input_json .as_posix (),
296320 KS2ver = self ._KS2ver ,
297321 npx_directory = self ._npx_input_dir .as_posix (),
298322 spikeGLX_data = False ,
299323 continuous_file = continuous_file .as_posix (),
300324 extracted_data_directory = self ._ks_output_dir .parent .as_posix (),
301325 kilosort_output_directory = self ._ks_output_dir .as_posix (),
326+ kilosort_output_tmp = self ._ks_output_dir .as_posix (),
302327 ks_make_copy = True ,
303328 noise_template_use_rf = self ._params .get ('noise_template_use_rf' , False ),
304329 c_Waves_snr_um = self ._params .get ('c_Waves_snr_um' , 160 ),
305330 qm_isi_thresh = self ._params .get ('refPerMS' , 2.0 ) / 1000 ,
306331 kilosort_repository = _get_kilosort_repository (self ._KS2ver ),
332+ chanMap_path = self ._chanmap_filepath .as_posix (),
307333 ** params
308334 )
309335
310- self ._modules_input_hash = dict_to_uuid (input_params )
336+ self ._modules_input_hash = dict_to_uuid (self . ks_input_params )
311337
312338 def run_modules (self ):
313339 print ('---- Running Modules ----' )
@@ -379,3 +405,44 @@ def _get_kilosort_repository(KS2ver):
379405 assert ks_repo .exists ()
380406
381407 return ks_repo .as_posix ()
408+
409+
410+ def _write_channel_map_file (* , channel_ind , x_coords , y_coords , shank_ind , connected ,
411+ probe_name , ap_band_file , bit_volts , sample_rate ,
412+ save_path , is_0_based = True ):
413+ """
414+ Write channel map into .mat file in 1-based indexing format (MATLAB style)
415+ """
416+
417+ assert len (channel_ind ) == len (x_coords ) == len (y_coords ) == len (shank_ind ) == len (connected )
418+
419+ if is_0_based :
420+ channel_ind += 1
421+ shank_ind += 1
422+
423+ channel_count = len (channel_ind )
424+ chanMap0ind = np .arange (0 , channel_count , dtype = 'float64' )
425+ chanMap0ind = chanMap0ind .reshape ((channel_count , 1 ))
426+ chanMap = chanMap0ind + 1
427+
428+ # channels to exclude
429+ mask = get_noise_channels (ap_band_file ,
430+ channel_count ,
431+ sample_rate ,
432+ bit_volts )
433+ bad_channel_ind = np .where (mask is False )[0 ]
434+ connected [bad_channel_ind ] = 0
435+
436+ mdict = {
437+ 'chanMap' : chanMap ,
438+ 'chanMap0ind' : chanMap0ind ,
439+ 'connected' : connected ,
440+ 'name' : probe_name ,
441+ 'xcoords' : x_coords ,
442+ 'ycoords' : y_coords ,
443+ 'shankInd' : shank_ind ,
444+ 'kcoords' : shank_ind ,
445+ 'fs' : sample_rate
446+ }
447+
448+ scipy .io .savemat (save_path , mdict )
0 commit comments