Skip to content

Commit cc2d7f1

Browse files
committed
add "unit_class" to spike_channels
1 parent 6a0c010 commit cc2d7f1

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

neo/rawio/blackrockrawio.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@
7373
_signal_channel_dtype,
7474
_signal_stream_dtype,
7575
_signal_buffer_dtype,
76-
_spike_channel_dtype,
7776
_event_channel_dtype,
7877
)
7978

@@ -133,6 +132,19 @@ class BlackrockRawIO(BaseRawIO):
133132
# We need to document the origin of this value
134133
main_sampling_rate = 30000.0
135134

135+
# Override spike channel dtype to include unit_class field specific to Blackrock
136+
_spike_channel_dtype = [
137+
("name", "U64"),
138+
("id", "U64"),
139+
# for waveform
140+
("wf_units", "U64"),
141+
("wf_gain", "float64"),
142+
("wf_offset", "float64"),
143+
("wf_left_sweep", "int64"),
144+
("wf_sampling_rate", "float64"),
145+
("unit_class", "U64"),
146+
]
147+
136148
def __init__(
137149
self, filename=None, nsx_override=None, nev_override=None, nsx_to_load=None, load_nev=True, verbose=False
138150
):
@@ -300,7 +312,21 @@ def _parse_header(self):
300312
# default value: threshold crossing after 10 samples of waveform
301313
wf_left_sweep = 10
302314
wf_sampling_rate = self.main_sampling_rate
303-
spike_channels.append((name, _id, wf_units, wf_gain, wf_offset, wf_left_sweep, wf_sampling_rate))
315+
316+
# Map unit_class_nb to unit classification string
317+
if unit_id == 0:
318+
unit_class = "unclassified"
319+
warnings.warn(f"Unit {unit_id} for channel {channel_id} is unclassified events!")
320+
elif 1 <= unit_id <= 16:
321+
unit_class = "sorted"
322+
elif unit_id == 255:
323+
unit_class = "noise"
324+
warnings.warn(f"Unit {unit_id} for channel {channel_id} is noisy events!")
325+
else: # 17-254 are reserved but treated as "non-spike-events"
326+
unit_class = "non-spike-events"
327+
warnings.warn(f"Unit {unit_id} for channel {channel_id} is non-spike events!")
328+
329+
spike_channels.append((name, _id, wf_units, wf_gain, wf_offset, wf_left_sweep, wf_sampling_rate, unit_class))
304330

305331
# scan events
306332
# NonNeural: serial and digital input
@@ -520,7 +546,7 @@ def _parse_header(self):
520546
self._sigs_t_starts = [None] * self._nb_segment
521547

522548
# finalize header
523-
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
549+
spike_channels = np.array(spike_channels, dtype=self._spike_channel_dtype)
524550
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
525551
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
526552
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)

0 commit comments

Comments
 (0)