Skip to content

Commit 43c57c7

Browse files
committed
Refactor Open Ephys and add mux table contact annotation
1 parent 783dba9 commit 43c57c7

File tree

2 files changed

+68
-167
lines changed

2 files changed

+68
-167
lines changed

resources/generate_neuropixels_library.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import matplotlib.pyplot as plt
88

9-
from probeinterface.neuropixels_tools import make_npx_description, _make_npx_probe_from_description
9+
from probeinterface.neuropixels_tools import _make_npx_probe_from_description, get_probe_metadata_from_probe_features
1010
from probeinterface.plotting import plot_probe
1111
from probeinterface import write_probeinterface
1212

@@ -37,18 +37,18 @@ def generate_all_npx():
3737
probe_folder = base_folder / probe_number
3838
probe_folder.mkdir(exist_ok=True)
3939

40-
probe_description = make_npx_description(probe_number)
40+
pt_metadata, _, _ = get_probe_metadata_from_probe_features(probe_features, probe_number)
4141

42-
num_shank = probe_description["shank_number"]
43-
contact_per_shank = probe_description["ncols_per_shank"] * probe_description["nrows_per_shank"]
42+
num_shank = pt_metadata["num_shanks"]
43+
contact_per_shank = pt_metadata["cols_per_shank"] * pt_metadata["rows_per_shank"]
4444
if num_shank == 1:
4545
elec_ids = np.arange(contact_per_shank)
4646
shank_ids = None
4747
else:
4848
elec_ids = np.concatenate([np.arange(contact_per_shank) for i in range(num_shank)])
4949
shank_ids = np.concatenate([np.zeros(contact_per_shank) + i for i in range(num_shank)])
5050

51-
probe = _make_npx_probe_from_description(probe_description, elec_ids, shank_ids)
51+
probe = _make_npx_probe_from_description(pt_metadata, elec_ids, shank_ids)
5252

5353
# ploting
5454
fig, axs = plt.subplots(ncols=2)
@@ -69,7 +69,7 @@ def generate_all_npx():
6969
plot_probe(probe, ax=ax)
7070
ax.set_title("")
7171

72-
yp = probe_description["y_pitch"]
72+
yp = pt_metadata["electrode_pitch_vert_um"]
7373
ax.set_ylim(-yp*8, yp*13)
7474
ax.yaxis.set_visible(False)
7575
ax.spines["top"].set_visible(False)

src/probeinterface/neuropixels_tools.py

Lines changed: 62 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -113,93 +113,6 @@ def get_probe_length(probe_part_number: str) -> int:
113113
return probe_length
114114

115115

116-
def make_npx_description(probe_part_number):
117-
"""
118-
Extracts probe metadata from the `probeinterface/resources/probe_features.json` file and converts
119-
to probeinterface syntax. File is maintained by Bill Karsh in ProbeTable
120-
(https://github.com/billkarsh/ProbeTable/tree/main).
121-
122-
Parameters
123-
----------
124-
probe_part_number : str
125-
The part number of the probe e.g. 'NP2013'.
126-
127-
Returns
128-
-------
129-
pi_metadata : dict
130-
Dictionary containing metadata about NeuroPixels probes using ProbeInterface syntax.
131-
"""
132-
133-
probe_features_filepath = Path(__file__).absolute().parent / Path("resources/probe_features.json")
134-
probe_features = json.load(open(probe_features_filepath, "r"))
135-
136-
# We use `pt` and `pi` as shorthand for `ProbeTable` and `ProbeInterface` throughout this function
137-
pt_metadata = probe_features["neuropixels_probes"].get(probe_part_number)
138-
pi_metadata = {}
139-
140-
if pt_metadata is None:
141-
warnings.warn(f"Probe part number {probe_part_number} not known. Assume a NP1.0 probe.")
142-
pt_metadata = probe_features["neuropixels_probes"].get("NP1010")
143-
144-
# Extract most of the metadata
145-
for pi_name, pt_name in pi_to_pt_names.items():
146-
if pt_name in ["num_shanks", "cols_per_shank", "rows_per_shank", "adc_bit_depth", "num_readout_channels"]:
147-
pi_metadata[pi_name] = int(pt_metadata[pt_name])
148-
elif pt_name in [
149-
"electrode_pitch_horz_um",
150-
"electrode_pitch_vert_um",
151-
"electrode_size_horz_direction_um",
152-
"shank_pitch_um",
153-
]:
154-
pi_metadata[pi_name] = float(pt_metadata[pt_name])
155-
else:
156-
pi_metadata[pi_name] = pt_metadata[pt_name]
157-
158-
# Use offsets to compute stagger and contour shift
159-
odd_row_horz_offset_left_edge_to_leftmost_electrode_center_um = float(
160-
pt_metadata["odd_row_horz_offset_left_edge_to_leftmost_electrode_center_um"]
161-
)
162-
even_row_horz_offset_left_edge_to_leftmost_electrode_center_um = float(
163-
pt_metadata["even_row_horz_offset_left_edge_to_leftmost_electrode_center_um"]
164-
)
165-
middle_of_bottommost_electrode_to_top_of_shank_tip = 11
166-
pi_metadata["contour_shift"] = [
167-
-odd_row_horz_offset_left_edge_to_leftmost_electrode_center_um,
168-
-middle_of_bottommost_electrode_to_top_of_shank_tip,
169-
]
170-
pi_metadata["stagger"] = (
171-
even_row_horz_offset_left_edge_to_leftmost_electrode_center_um
172-
- odd_row_horz_offset_left_edge_to_leftmost_electrode_center_um
173-
)
174-
175-
# Read the imro table formats to find out which fields the imro tables contain
176-
imro_table_format_type = pt_metadata["imro_table_format_type"]
177-
imro_table_fields = probe_features["z_imro_formats"][imro_table_format_type + "_elm_flds"]
178-
179-
# parse the imro_table_fields, which look like (value value value ...)
180-
list_of_imro_fields = imro_table_fields.replace("(", "").replace(")", "").split(" ")
181-
182-
pi_imro_fields = []
183-
for imro_field in list_of_imro_fields:
184-
pi_imro_fields.append(imro_field_to_pi_field[imro_field])
185-
pi_metadata["fields_in_imro_table"] = tuple(pi_imro_fields)
186-
187-
# Construct probe contour, for styling the probe
188-
shank_width = float(pt_metadata["shank_width_um"])
189-
tip_length = float(pt_metadata["tip_length_um"])
190-
191-
probe_length = get_probe_length(probe_part_number)
192-
pi_metadata["contour_description"] = get_probe_contour_vertices(shank_width, tip_length, probe_length)
193-
194-
# Get the mux table. This describes which electrodes are multiplexed together, meaning
195-
# which electrodes are sampled at the same time.
196-
mux_table_format_type = pt_metadata["mux_table_format_type"]
197-
mux_information = probe_features["z_mux_tables"].get(mux_table_format_type)
198-
pi_metadata["mux_table_array"] = make_mux_table_array(mux_information)
199-
200-
return pi_metadata
201-
202-
203116
def make_mux_table_array(mux_information) -> np.array:
204117
"""
205118
Function to parse the mux_table from ProbeTable.
@@ -314,7 +227,7 @@ def read_imro(file_path: Union[str, Path]) -> Probe:
314227
return _read_imro_string(imro_str, imDatPrb_pn)
315228

316229

317-
def _make_npx_probe_from_description(probe_description, elec_ids, shank_ids):
230+
def _make_npx_probe_from_description(probe_description, elec_ids, shank_ids, mux_table=None) -> Probe:
318231
# used by _read_imro_string and for generating the NP library
319232

320233
model_name = probe_description["description"]
@@ -392,6 +305,17 @@ def _make_npx_probe_from_description(probe_description, elec_ids, shank_ids):
392305
# wire it
393306
probe.set_device_channel_indices(np.arange(positions.shape[0]))
394307

308+
# annotate with MUX table
309+
if mux_table is not None:
310+
print("Adding MUX table to probe")
311+
# annotate each contact with its mux channel
312+
num_contacts = positions.shape[0]
313+
mux_channels = np.zeros(num_contacts, dtype="int64")
314+
for adc_idx, mux_channels_per_adc in enumerate(mux_table):
315+
mux_channels_per_adc = mux_channels_per_adc[mux_channels_per_adc < num_contacts]
316+
mux_channels[mux_channels_per_adc] = adc_idx
317+
probe.annotate_contacts(mux_channels=mux_channels)
318+
395319
return probe
396320

397321

@@ -424,7 +348,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
424348

425349
probe_features_filepath = Path(__file__).absolute().parent / Path("resources/probe_features.json")
426350
probe_features = json.load(open(probe_features_filepath, "r"))
427-
pt_metadata, fields = get_probe_metadata_from_probe_features(probe_features, imDatPrb_pn)
351+
pt_metadata, fields, mux_table = get_probe_metadata_from_probe_features(probe_features, imDatPrb_pn)
428352

429353
# fields = probe_description["fields_in_imro_table"]
430354
contact_info = {k: [] for k in fields}
@@ -450,7 +374,7 @@ def _read_imro_string(imro_str: str, imDatPrb_pn: Optional[str] = None) -> Probe
450374
else:
451375
shank_ids = None
452376

453-
probe = _make_npx_probe_from_description(pt_metadata, elec_ids, shank_ids)
377+
probe = _make_npx_probe_from_description(pt_metadata, elec_ids, shank_ids, mux_table)
454378

455379
# scalar annotations
456380
probe.annotate(
@@ -478,14 +402,14 @@ def get_probe_metadata_from_probe_features(probe_features: dict, imDatPrb_pn: st
478402
to construct a probe with part number `imDatPrb_pn`.
479403
480404
Parameters
481-
==========
405+
----------
482406
probe_features : dict
483407
Dictionary obtained when reading in the `probe_features.json` file.
484408
imDatPrb_pn : str
485409
Probe part number.
486410
487411
Returns
488-
=======
412+
-------
489413
probe_metadata, imro_field
490414
Dictionary of probe metadata.
491415
Tuple of fields included in the `imro_table_fields`.
@@ -520,7 +444,16 @@ def get_probe_metadata_from_probe_features(probe_features: dict, imDatPrb_pn: st
520444

521445
imro_fields = tuple(imro_fields_list)
522446

523-
return probe_metadata, imro_fields
447+
# Read MUX table information
448+
mux_table = None
449+
450+
if "z_mux_tables" in probe_features:
451+
mux_table_format_type = probe_metadata.get("mux_table_format_type", None)
452+
mux_information = probe_features["z_mux_tables"].get(mux_table_format_type, None)
453+
if mux_information is not None:
454+
mux_table = make_mux_table_array(mux_information)
455+
456+
return probe_metadata, imro_fields, mux_table
524457

525458

526459
def write_imro(file: str | Path, probe: Probe):
@@ -894,6 +827,9 @@ def read_openephys(
894827
)
895828
return None
896829

830+
probe_features_filepath = Path(__file__).absolute().parent / Path("resources/probe_features.json")
831+
probe_features = json.load(open(probe_features_filepath, "r"))
832+
897833
# now load probe info from NP_PROBE fields
898834
np_probes_info = []
899835
for probe_idx, np_probe in enumerate(np_probes):
@@ -932,9 +868,9 @@ def read_openephys(
932868
positions = np.array([xpos, ypos]).T
933869

934870
probe_part_number = np_probe.get("probe_part_number", None)
871+
pt_metadata, _, mux_table = get_probe_metadata_from_probe_features(probe_features, probe_part_number)
935872

936-
probe_dict = make_npx_description(probe_part_number)
937-
shank_pitch = probe_dict["shank_pitch"]
873+
shank_pitch = pt_metadata["shank_pitch_um"]
938874

939875
if fix_x_position_for_oe_5 and oe_version < parse("0.6.0") and shank_ids is not None:
940876
positions[:, 1] = positions[:, 1] - shank_pitch * shank_ids
@@ -946,21 +882,27 @@ def read_openephys(
946882
offset -= np.min(shank_ids) * shank_pitch
947883
positions[:, 0] -= offset
948884

949-
contact_ids = []
950-
y_pitch = probe_dict["y_pitch"] # Vertical spacing between the centers of adjacent contacts
951-
x_pitch = probe_dict["x_pitch"] # Horizontal spacing between the centers of contacts within the same row
952-
number_of_columns = probe_dict["ncols_per_shank"]
953-
probe_stagger = probe_dict["stagger"]
954-
shank_number = probe_dict["shank_number"]
885+
#
886+
y_pitch = pt_metadata["electrode_pitch_vert_um"] # Vertical spacing between the centers of adjacent contacts
887+
x_pitch = pt_metadata[
888+
"electrode_pitch_horz_um"
889+
] # Horizontal spacing between the centers of contacts within the same row
890+
number_of_columns = pt_metadata["cols_per_shank"]
891+
probe_stagger = (
892+
pt_metadata["even_row_horz_offset_left_edge_to_leftmost_electrode_center_um"]
893+
- pt_metadata["odd_row_horz_offset_left_edge_to_leftmost_electrode_center_um"]
894+
)
895+
num_shanks = pt_metadata["num_shanks"]
955896

956-
model_name = probe_dict.get("model_name")
897+
model_name = pt_metadata.get("description")
957898
if model_name is None:
958899
model_name = "Unknown"
959900

901+
elec_ids = []
960902
for i, pos in enumerate(positions):
961903
# Do not calculate contact ids if the model name is not known
962904
if model_name == "Unknown":
963-
contact_ids = None
905+
elec_ids = None
964906
break
965907

966908
x_pos = pos[0]
@@ -971,34 +913,28 @@ def read_openephys(
971913
row_stagger = probe_stagger if is_row_staggered else 0
972914

973915
# Map the positions to the contacts ids
974-
shank_id = shank_ids[i] if shank_number > 1 else 0
916+
shank_id = shank_ids[i] if num_shanks > 1 else 0
975917

976-
# Contact ids are computed from the positions of the electrodes. The computation
918+
# Electrode ids are computed from the positions of the electrodes. The computation
977919
# is different for probes with one row of electrodes, or more than one.
978920
if x_pitch == 0:
979-
contact_id = int(number_of_columns * y_pos / y_pitch)
921+
elec_id = int(number_of_columns * y_pos / y_pitch)
980922
else:
981-
contact_id = int(
923+
elec_id = int(
982924
(x_pos - row_stagger - shank_pitch * shank_id) / x_pitch + number_of_columns * y_pos / y_pitch
983925
)
984-
if shank_number > 1:
985-
contact_ids.append(f"s{shank_id}e{contact_id}")
986-
else:
987-
contact_ids.append(f"e{contact_id}")
988-
989-
mux_table_array = probe_dict["mux_table_array"]
926+
elec_ids.append(elec_id)
990927

991928
np_probe_dict = {
992-
"model_name": model_name,
993929
"shank_ids": shank_ids,
994-
"contact_ids": contact_ids,
995-
"positions": positions,
930+
"elec_ids": elec_ids,
931+
"pt_metadata": pt_metadata,
996932
"slot": slot,
997933
"port": port,
998934
"dock": dock,
999935
"serial_number": probe_serial_number,
1000936
"part_number": probe_part_number,
1001-
"mux_table_array": mux_table_array,
937+
"mux_table": mux_table,
1002938
}
1003939
# Sequentially assign probe names
1004940
if "custom_probe_name" in np_probe.attrib and np_probe.attrib["custom_probe_name"] != probe_serial_number:
@@ -1093,14 +1029,10 @@ def read_openephys(
10931029

10941030
np_probe_info = np_probes_info[probe_idx]
10951031
np_probe = np_probes[probe_idx]
1096-
positions = np_probe_info["positions"]
10971032
shank_ids = np_probe_info["shank_ids"]
1098-
1099-
contact_width = probe_dict["contact_width"]
1100-
num_shanks = probe_dict["shank_number"]
1101-
contour_description = probe_dict["contour_description"]
1102-
1103-
contact_ids = np_probe_info["contact_ids"] if np_probe_info["contact_ids"] is not None else None
1033+
elec_ids = np_probe_info["elec_ids"]
1034+
pt_metadata = np_probe_info["pt_metadata"]
1035+
mux_table = np_probe_info["mux_table"]
11041036

11051037
# check if subset of channels
11061038
chans_saved = get_saved_channel_indices_from_openephys_settings(settings_file, stream_name=stream_name)
@@ -1110,50 +1042,19 @@ def read_openephys(
11101042
positions = positions[chans_saved]
11111043
if shank_ids is not None:
11121044
shank_ids = np.array(shank_ids)[chans_saved]
1113-
if contact_ids is not None:
1114-
contact_ids = np.array(contact_ids)[chans_saved]
1115-
1116-
probe = Probe(
1117-
ndim=2,
1118-
si_units="um",
1119-
name=np_probe_info["name"],
1120-
serial_number=np_probe_info["serial_number"],
1121-
manufacturer="IMEC",
1122-
model_name=np_probe_info["model_name"],
1123-
)
1124-
probe.set_contacts(
1125-
positions=positions,
1126-
shapes="square",
1127-
shank_ids=shank_ids,
1128-
shape_params={"width": contact_width},
1129-
)
1045+
if elec_ids is not None:
1046+
elec_ids = np.array(elec_ids)[chans_saved]
1047+
1048+
probe = _make_npx_probe_from_description(pt_metadata, elec_ids, shank_ids=shank_ids, mux_table=mux_table)
1049+
probe.serial_number = np_probe_info["serial_number"]
1050+
11301051
probe.annotate(
11311052
part_number=np_probe_info["part_number"],
11321053
slot=np_probe_info["slot"],
11331054
dock=np_probe_info["dock"],
11341055
port=np_probe_info["port"],
1135-
mux_table_array=np_probe_info["mux_table_array"],
11361056
)
11371057

1138-
if contact_ids is not None:
1139-
probe.set_contact_ids(contact_ids)
1140-
1141-
polygon = contour_description
1142-
contour_shift = np.array(probe_dict["contour_shift"])
1143-
if shank_ids is None:
1144-
contour = polygon
1145-
else:
1146-
contour = []
1147-
for i in range(num_shanks):
1148-
contour += list(np.array(polygon) + [shank_pitch * i, 0])
1149-
1150-
# shift
1151-
contour = np.array(contour) + contour_shift
1152-
probe.set_planar_contour(contour)
1153-
1154-
# wire it
1155-
probe.set_device_channel_indices(np.arange(positions.shape[0]))
1156-
11571058
return probe
11581059

11591060

0 commit comments

Comments
 (0)