1010
1111import numpy as np
1212from matplotlib .backend_bases import MouseButton
13+ from matplotlib .cbook import CallbackRegistry
1314
1415
1516class clicker :
@@ -100,6 +101,7 @@ class in *classes*
100101 self ._fig .canvas .mpl_connect ("pick_event" , self ._on_pick )
101102 self ._positions = {c : [] for c in self ._classes }
102103 self ._update_legend_alpha ()
104+ self ._observers = CallbackRegistry ()
103105
104106 def get_positions (self , copy = True ):
105107 return {k : np .asarray (v ) for k , v in self ._positions .items ()}
@@ -114,6 +116,7 @@ def _on_pick(self, event):
114116 return
115117 self ._current_class = klass
116118 self ._update_legend_alpha ()
119+ self ._observers .process ('class-changed' , klass )
117120
118121 def _update_legend_alpha (self ):
119122 for c in self ._classes :
@@ -122,13 +125,25 @@ def _update_legend_alpha(self):
122125 a .set_alpha (alpha )
123126 self ._fig .canvas .draw ()
124127
128+ def _has_cbs (self , name ):
129+ """return whether there are callbacks registered for the current class"""
130+ try :
131+ return len (self ._observers .callbacks [name ]) > 0
132+ except KeyError :
133+ return False
134+
125135 def _clicked (self , event ):
126136 if not self ._fig .canvas .widgetlock .available (self ):
127137 return
128138 if event .inaxes is self .ax :
129139 if event .button is MouseButton .LEFT :
130140 self ._positions [self ._current_class ].append ((event .xdata , event .ydata ))
131141 self ._update_points (self ._current_class )
142+ self ._observers .process (
143+ 'point-added' ,
144+ (event .xdata , event .ydata ),
145+ self ._current_class ,
146+ )
132147 elif event .button is MouseButton .RIGHT :
133148 pos = self ._positions [self ._current_class ]
134149 if len (pos ) == 0 :
@@ -139,8 +154,14 @@ def _clicked(self, event):
139154 axis = - 1 ,
140155 )
141156 idx = np .argmin (dists [0 ])
142- pos .pop (idx )
157+ removed = pos .pop (idx )
143158 self ._update_points (self ._current_class )
159+ self ._observers .process (
160+ 'point-removed' ,
161+ removed ,
162+ self ._current_class ,
163+ idx ,
164+ )
144165
145166 def _update_points (self , klass = None ):
146167 if klass is None :
@@ -154,3 +175,57 @@ def _update_points(self, klass=None):
154175 new_off = np .zeros ([0 , 2 ])
155176 self ._lines [c ].set_data (new_off .T )
156177 self ._fig .canvas .draw ()
178+
179+ def on_point_added (self , func ):
180+ """
181+ Connect *func* as a callback function to new points being added.
182+ *func* will receive the the position of the new point as a tuple (x, y), and
183+ the class of the new point.
184+
185+ Parameters
186+ ----------
187+ func : callable
188+ Function to call when a point is added.
189+
190+ Returns
191+ -------
192+ int
193+ Connection id (which can be used to disconnect *func*).
194+ """
195+ return self ._observers .connect ('point-added' , lambda * args : func (* args ))
196+
197+ def on_point_removed (self , func ):
198+ """
199+ Connect *func* as a callback function when points are removed.
200+ *func* will receive the the position of the new point, the class of the removed point,
201+ the point's index in the old list of points of that class, and the updated dictionary of
202+ all points.
203+
204+ Parameters
205+ ----------
206+ func : callable
207+ Function to call when a point is removed
208+
209+ Returns
210+ -------
211+ int
212+ Connection id (which can be used to disconnect *func*).
213+ """
214+ return self ._observers .connect ('point-removed' , lambda * args : func (* args ))
215+
216+ def on_class_changed (self , func ):
217+ """
218+ Connect *func* as a callback function when the current class is changed.
219+ *func* will receive the new class.
220+
221+ Parameters
222+ ----------
223+ func : callable
224+ Function to call when *set_positions* is called.
225+
226+ Returns
227+ -------
228+ int
229+ Connection id (which can be used to disconnect *func*).
230+ """
231+ self ._observers .connect ('class-changed' , lambda klass : func (klass ))
0 commit comments