Skip to content

Commit 2c14948

Browse files
committed
add callback system
1 parent 81ebee5 commit 2c14948

File tree

3 files changed

+178
-1
lines changed

3 files changed

+178
-1
lines changed

docs/examples/callbacks.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
---------
3+
Callbacks
4+
---------
5+
6+
Demonstration of how to set up callback functions.
7+
8+
"""
9+
10+
import matplotlib.pyplot as plt
11+
import numpy as np
12+
13+
from typing import Tuple
14+
from mpl_point_clicker import clicker
15+
16+
from pathlib import Path
17+
18+
image = np.load(Path(__file__).parent / "example_image.npy")
19+
20+
fig, ax = plt.subplots()
21+
ax.imshow(image, cmap='gray')
22+
klicker = clicker(ax, ['cells', 'pdms', 'media'], markers=['o', 'x', '*'])
23+
24+
25+
def class_changed_cb(new_class: str):
26+
print(f'The newly selected class is {new_class}')
27+
28+
29+
def point_added_cb(position: Tuple[float, float], klass: str):
30+
x, y = position
31+
print(f"New point of class {klass} added at {x=}, {y=}")
32+
33+
34+
def point_removed_cb(position: Tuple[float, float], klass: str, idx):
35+
x, y = position
36+
37+
suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(idx, 'th')
38+
print(
39+
f"The {idx}{suffix} point of class {klass} with position {x=:.2f}, {y=:.2f} was removed"
40+
)
41+
42+
43+
klicker.on_class_changed(class_changed_cb)
44+
klicker.on_point_added(point_added_cb)
45+
klicker.on_point_removed(point_removed_cb)
46+
47+
48+
plt.show()
49+
50+
51+
print(klicker.get_positions())

examples/callbacks.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
---------
3+
Callbacks
4+
---------
5+
6+
Demonstration of how to set up callback functions.
7+
8+
"""
9+
10+
import matplotlib.pyplot as plt
11+
import numpy as np
12+
13+
from typing import Tuple
14+
from mpl_point_clicker import clicker
15+
16+
from pathlib import Path
17+
18+
image = np.load(Path(__file__).parent / "example_image.npy")
19+
20+
fig, ax = plt.subplots()
21+
ax.imshow(image, cmap='gray')
22+
klicker = clicker(ax, ['cells', 'pdms', 'media'], markers=['o', 'x', '*'])
23+
24+
25+
def class_changed_cb(new_class: str):
26+
print(f'The newly selected class is {new_class}')
27+
28+
29+
def point_added_cb(position: Tuple[float, float], klass: str):
30+
x, y = position
31+
print(f"New point of class {klass} added at {x=}, {y=}")
32+
33+
34+
def point_removed_cb(position: Tuple[float, float], klass: str, idx):
35+
x, y = position
36+
37+
suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(idx, 'th')
38+
print(
39+
f"The {idx}{suffix} point of class {klass} with position {x=:.2f}, {y=:.2f} was removed"
40+
)
41+
42+
43+
klicker.on_class_changed(class_changed_cb)
44+
klicker.on_point_added(point_added_cb)
45+
klicker.on_point_removed(point_removed_cb)
46+
47+
48+
plt.show()
49+
50+
51+
print(klicker.get_positions())

mpl_point_clicker/_clicker.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212
from matplotlib.backend_bases import MouseButton
13+
from matplotlib.cbook import CallbackRegistry
1314

1415

1516
class clicker:
@@ -100,6 +101,7 @@ class in *classes*
100101
self._fig.canvas.mpl_connect("pick_event", self._on_pick)
101102
self._positions = {c: [] for c in self._classes}
102103
self._update_legend_alpha()
104+
self._observers = CallbackRegistry()
103105

104106
def get_positions(self, copy=True):
105107
return {k: np.asarray(v) for k, v in self._positions.items()}
@@ -114,6 +116,7 @@ def _on_pick(self, event):
114116
return
115117
self._current_class = klass
116118
self._update_legend_alpha()
119+
self._observers.process('class-changed', klass)
117120

118121
def _update_legend_alpha(self):
119122
for c in self._classes:
@@ -122,13 +125,25 @@ def _update_legend_alpha(self):
122125
a.set_alpha(alpha)
123126
self._fig.canvas.draw()
124127

128+
def _has_cbs(self, name):
129+
"""return whether there are callbacks registered for the current class"""
130+
try:
131+
return len(self._observers.callbacks[name]) > 0
132+
except KeyError:
133+
return False
134+
125135
def _clicked(self, event):
126136
if not self._fig.canvas.widgetlock.available(self):
127137
return
128138
if event.inaxes is self.ax:
129139
if event.button is MouseButton.LEFT:
130140
self._positions[self._current_class].append((event.xdata, event.ydata))
131141
self._update_points(self._current_class)
142+
self._observers.process(
143+
'point-added',
144+
(event.xdata, event.ydata),
145+
self._current_class,
146+
)
132147
elif event.button is MouseButton.RIGHT:
133148
pos = self._positions[self._current_class]
134149
if len(pos) == 0:
@@ -139,8 +154,14 @@ def _clicked(self, event):
139154
axis=-1,
140155
)
141156
idx = np.argmin(dists[0])
142-
pos.pop(idx)
157+
removed = pos.pop(idx)
143158
self._update_points(self._current_class)
159+
self._observers.process(
160+
'point-removed',
161+
removed,
162+
self._current_class,
163+
idx,
164+
)
144165

145166
def _update_points(self, klass=None):
146167
if klass is None:
@@ -154,3 +175,57 @@ def _update_points(self, klass=None):
154175
new_off = np.zeros([0, 2])
155176
self._lines[c].set_data(new_off.T)
156177
self._fig.canvas.draw()
178+
179+
def on_point_added(self, func):
180+
"""
181+
Connect *func* as a callback function to new points being added.
182+
*func* will receive the the position of the new point as a tuple (x, y), and
183+
the class of the new point.
184+
185+
Parameters
186+
----------
187+
func : callable
188+
Function to call when a point is added.
189+
190+
Returns
191+
-------
192+
int
193+
Connection id (which can be used to disconnect *func*).
194+
"""
195+
return self._observers.connect('point-added', lambda *args: func(*args))
196+
197+
def on_point_removed(self, func):
198+
"""
199+
Connect *func* as a callback function when points are removed.
200+
*func* will receive the the position of the new point, the class of the removed point,
201+
the point's index in the old list of points of that class, and the updated dictionary of
202+
all points.
203+
204+
Parameters
205+
----------
206+
func : callable
207+
Function to call when a point is removed
208+
209+
Returns
210+
-------
211+
int
212+
Connection id (which can be used to disconnect *func*).
213+
"""
214+
return self._observers.connect('point-removed', lambda *args: func(*args))
215+
216+
def on_class_changed(self, func):
217+
"""
218+
Connect *func* as a callback function when the current class is changed.
219+
*func* will receive the new class.
220+
221+
Parameters
222+
----------
223+
func : callable
224+
Function to call when *set_positions* is called.
225+
226+
Returns
227+
-------
228+
int
229+
Connection id (which can be used to disconnect *func*).
230+
"""
231+
self._observers.connect('class-changed', lambda klass: func(klass))

0 commit comments

Comments
 (0)