Skip to content

Commit 4eb5fc1

Browse files
committed
Load spike data only on demand and when not already cached
1 parent 2bc830b commit 4eb5fc1

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

neo/rawio/plexon2rawio/plexon2rawio.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,6 @@ def _parse_header(self):
161161
spike_channels.append((unit_name, unit_id, wf_units, wf_gain,
162162
wf_offset, wf_left_sweep, wf_sampling_rate))
163163

164-
# pre-loading spiking data
165-
schannel_name = schannel_info.m_Name.decode()
166-
self._spike_channel_cache[schannel_name] = self.pl2reader.pl2_get_spike_channel_data_by_name(schannel_name)
167-
168164
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
169165

170166
# creating event/epoch channel
@@ -366,6 +362,10 @@ def _spike_count(self, block_index, seg_index, spike_channel_index):
366362
channel_name, channel_unit_id = channel_header['name'].split('.')
367363
channel_unit_id = int(channel_unit_id)
368364

365+
# loading spike channel data on demand when not already cached
366+
if channel_name not in self._spike_channel_cache:
367+
self._spike_channel_cache[channel_name] = self.pl2reader.pl2_get_spike_channel_data_by_name(channel_name)
368+
369369
spike_timestamps, unit_ids, waveforms = self._spike_channel_cache[channel_name]
370370
nb_spikes = np.count_nonzero(unit_ids == channel_unit_id)
371371

@@ -376,6 +376,10 @@ def _get_spike_timestamps(self, block_index, seg_index, spike_channel_index, t_s
376376
channel_name, channel_unit_id = channel_header['name'].split('.')
377377
channel_unit_id = int(channel_unit_id)
378378

379+
# loading spike channel data on demand when not already cached
380+
if channel_name not in self._spike_channel_cache:
381+
self._spike_channel_cache[channel_name] = self.pl2reader.pl2_get_spike_channel_data_by_name(channel_name)
382+
379383
spike_timestamps, unit_ids, waveforms = self._spike_channel_cache[channel_name]
380384

381385
if t_start is not None or t_stop is not None:
@@ -417,6 +421,10 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, spike_channel_index,
417421
channel_header = self.header['spike_channels'][spike_channel_index]
418422
channel_name, channel_unit_id = channel_header['name'].split('.')
419423

424+
# loading spike channel data on demand when not already cached
425+
if channel_name not in self._spike_channel_cache:
426+
self._spike_channel_cache[channel_name] = self.pl2reader.pl2_get_spike_channel_data_by_name(channel_name)
427+
420428
spike_timestamps, unit_ids, waveforms = self._spike_channel_cache[channel_name]
421429

422430
if t_start is not None or t_stop is not None:

0 commit comments

Comments
 (0)