Skip to content

Commit cbb8653

Browse files
authored
Merge pull request #352 from h-mayorquin/assertion_for_unique_positions
Check for unique positions within probe
2 parents 27137e9 + bfd71ce commit cbb8653

File tree

3 files changed

+80
-0
lines changed

3 files changed

+80
-0
lines changed

src/probeinterface/probe.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,40 @@
88
_possible_contact_shapes = ["circle", "square", "rect"]
99

1010

11+
def _raise_non_unique_positions_error(positions):
12+
"""
13+
Check for duplicate positions and raise ValueError with detailed information.
14+
15+
Parameters
16+
----------
17+
positions : array
18+
Array of positions to check for duplicates.
19+
20+
Raises
21+
------
22+
ValueError
23+
If duplicate positions are found, with detailed information about duplicates.
24+
"""
25+
duplicates = {}
26+
for index, pos in enumerate(positions):
27+
pos_key = tuple(pos)
28+
if pos_key in duplicates:
29+
duplicates[pos_key].append(index)
30+
else:
31+
duplicates[pos_key] = [index]
32+
33+
duplicate_groups = {pos: indices for pos, indices in duplicates.items() if len(indices) > 1}
34+
duplicate_info = []
35+
for pos, indices in duplicate_groups.items():
36+
pos_str = f"({', '.join(map(str, pos))})"
37+
indices_str = f"[{', '.join(map(str, indices))}]"
38+
duplicate_info.append(f"Position {pos_str} appears at indices {indices_str}")
39+
40+
raise ValueError(
41+
f"Contact positions must be unique within a probe. Found {len(duplicate_groups)} duplicate(s): {'; '.join(duplicate_info)}"
42+
)
43+
44+
1145
class Probe:
1246
"""
1347
Class to handle the geometry of one probe.
@@ -279,6 +313,12 @@ def set_contacts(
279313
if positions.shape[1] != self.ndim:
280314
raise ValueError(f"positions.shape[1]: {positions.shape[1]} and ndim: {self.ndim} do not match!")
281315

316+
# Check for duplicate positions
317+
unique_positions = np.unique(positions, axis=0)
318+
positions_are_not_unique = unique_positions.shape[0] != positions.shape[0]
319+
if positions_are_not_unique:
320+
_raise_non_unique_positions_error(positions)
321+
282322
self._contact_positions = positions
283323
n = positions.shape[0]
284324

tests/test_probe.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,21 @@ def test_save_to_zarr(tmp_path):
182182
assert probe == reloaded_probe, "Reloaded Probe object does not match the original"
183183

184184

185+
def test_position_uniqueness():
186+
"""Test that the error message matches the full expected string for three duplicates using pytest's match regex."""
187+
import re
188+
189+
positions_with_dups = np.array([[0, 0], [10, 10], [0, 0], [20, 20], [0, 0], [10, 10]])
190+
probe = Probe(ndim=2, si_units="um")
191+
expected_error = (
192+
"Contact positions must be unique within a probe. "
193+
"Found 2 duplicate(s): Position (0, 0) appears at indices [0, 2, 4]; Position (10, 10) appears at indices [1, 5]"
194+
)
195+
196+
with pytest.raises(ValueError, match=re.escape(expected_error)):
197+
probe.set_contacts(positions=positions_with_dups, shapes="circle", shape_params={"radius": 5})
198+
199+
185200
if __name__ == "__main__":
186201
test_probe()
187202

tests/test_probegroup.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,31 @@ def test_probegroup_3d():
6767
assert probegroup.ndim == 3
6868

6969

70+
def test_probegroup_allows_duplicate_positions_across_probes():
71+
"""Test that ProbeGroup allows duplicate contact positions if they are in different probes."""
72+
from probeinterface import ProbeGroup, Probe
73+
import numpy as np
74+
75+
# Probes have the same internal relative positions
76+
positions = np.array([[0, 0], [10, 10]])
77+
probe1 = Probe(ndim=2, si_units="um")
78+
probe1.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5})
79+
probe2 = Probe(ndim=2, si_units="um")
80+
probe2.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5})
81+
82+
group = ProbeGroup()
83+
group.add_probe(probe1)
84+
group.add_probe(probe2)
85+
86+
# Should not raise any error
87+
all_positions = np.vstack([p.contact_positions for p in group.probes])
88+
# There are duplicates across probes, but this is allowed
89+
assert (all_positions == [0, 0]).any()
90+
assert (all_positions == [10, 10]).any()
91+
# The group should have both probes
92+
assert len(group.probes) == 2
93+
94+
7095
if __name__ == "__main__":
7196
test_probegroup()
7297
# ~ test_probegroup_3d()

0 commit comments

Comments
 (0)