Skip to content

Commit 5a58dd9

Browse files
committed
initial feature scatter
1 parent 43dd094 commit 5a58dd9

File tree

2 files changed

+173
-16
lines changed

2 files changed

+173
-16
lines changed

src/napari_matplotlib/base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,14 @@ def setup_callbacks(self) -> None:
7979
# z-step changed in viewer
8080
self.viewer.dims.events.current_step.connect(self._draw)
8181
# Layer selection changed in viewer
82-
self.viewer.layers.selection.events.active.connect(self.update_layers)
82+
self.viewer.layers.selection.events.changed.connect(self.update_layers)
8383

8484
def update_layers(self, event: napari.utils.events.Event) -> None:
8585
"""
8686
Update the currently selected layers and re-draw.
8787
"""
8888
self.layers = list(self.viewer.layers.selection)
89+
self._on_update_layers()
8990
self._draw()
9091

9192
def _draw(self) -> None:
@@ -95,6 +96,7 @@ def _draw(self) -> None:
9596
"""
9697
self.clear()
9798
if self.n_selected_layers != self.n_layers_input:
99+
self.canvas.draw()
98100
return
99101
self.draw()
100102
self.canvas.draw()
@@ -112,3 +114,9 @@ def draw(self) -> None:
112114
113115
This is a no-op, and is intended for derived classes to override.
114116
"""
117+
118+
def _on_update_layers(self) -> None:
119+
"""This function is called when self.layers is updated via self.update_layer()
120+
121+
This is a no-op, and is intended for derived classes to override.
122+
"""

src/napari_matplotlib/scatter.py

Lines changed: 164 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,188 @@
1+
from typing import List, Tuple, Union
2+
13
import matplotlib.colors as mcolor
24
import napari
5+
import numpy as np
6+
from magicgui import magicgui
37

48
from .base import NapariMPLWidget
59

610
__all__ = ["ScatterWidget"]
711

812

9-
class ScatterWidget(NapariMPLWidget):
10-
"""
11-
Widget to display scatter plot of two similarly shaped layers.
13+
class ScatterBaseWidget(NapariMPLWidget):
14+
def __init__(
15+
self,
16+
napari_viewer: napari.viewer.Viewer,
17+
marker_alpha: float = 0.5,
18+
histogram_for_large_data: bool = True,
19+
):
20+
super().__init__(napari_viewer)
1221

13-
If there are more than 500 data points, a 2D histogram is displayed instead
14-
of a scatter plot, to avoid too many scatter points.
15-
"""
22+
# flag set to True if histogram should be used
23+
# for plotting large points
24+
self.histogram_for_large_data = histogram_for_large_data
1625

17-
n_layers_input = 2
26+
# set plotting visualization attributes
27+
self._marker_alpha = 0.5
1828

19-
def __init__(self, napari_viewer: napari.viewer.Viewer):
20-
super().__init__(napari_viewer)
2129
self.axes = self.canvas.figure.subplots()
2230
self.update_layers(None)
2331

32+
@property
33+
def marker_alpha(self) -> float:
34+
"""Alpha (opacity) for the scatter markers"""
35+
return self._marker_alpha
36+
37+
@marker_alpha.setter
38+
def marker_alpha(self, alpha: float):
39+
self._marker_alpha = alpha
40+
self._draw()
41+
42+
def clear(self) -> None:
43+
self.axes.clear()
44+
2445
def draw(self) -> None:
2546
"""
2647
Clear the axes and scatter the currently selected layers.
2748
"""
28-
data = [layer.data[self.current_z] for layer in self.layers]
29-
if data[0].size < 500:
30-
self.axes.scatter(data[0], data[1], alpha=0.5)
31-
else:
49+
data, x_axis_name, y_axis_name = self._get_data()
50+
51+
if len(data) == 0:
52+
# don't plot if there isn't data
53+
return
54+
55+
if self.histogram_for_large_data and (data[0].size > 500):
3256
self.axes.hist2d(
3357
data[0].ravel(),
3458
data[1].ravel(),
3559
bins=100,
3660
norm=mcolor.LogNorm(),
3761
)
38-
self.axes.set_xlabel(self.layers[0].name)
39-
self.axes.set_ylabel(self.layers[1].name)
62+
else:
63+
self.axes.scatter(data[0], data[1], alpha=self.marker_alpha)
64+
65+
self.axes.set_xlabel(x_axis_name)
66+
self.axes.set_ylabel(y_axis_name)
67+
68+
def _get_data(self) -> Tuple[np.ndarray, str, str]:
69+
raise NotImplementedError
70+
71+
72+
class ScatterWidget(ScatterBaseWidget):
73+
"""
74+
Widget to display scatter plot of two similarly shaped layers.
75+
76+
If there are more than 500 data points, a 2D histogram is displayed instead
77+
of a scatter plot, to avoid too many scatter points.
78+
"""
79+
80+
n_layers_input = 2
81+
82+
def __init__(
83+
self,
84+
napari_viewer: napari.viewer.Viewer,
85+
marker_alpha: float = 0.5,
86+
histogram_for_large_data: bool = True,
87+
):
88+
super().__init__(
89+
napari_viewer,
90+
marker_alpha=marker_alpha,
91+
histogram_for_large_data=histogram_for_large_data,
92+
)
93+
94+
def _get_data(self) -> Tuple[np.ndarray, str, str]:
95+
data = [layer.data[self.current_z] for layer in self.layers]
96+
x_axis_name = self.layers[0].name
97+
y_axis_name = self.layers[1].name
98+
99+
return data, x_axis_name, y_axis_name
100+
101+
102+
class FeaturesScatterWidget(ScatterBaseWidget):
103+
n_layers_input = 1
104+
105+
def __init__(
106+
self,
107+
napari_viewer: napari.viewer.Viewer,
108+
marker_alpha: float = 0.5,
109+
histogram_for_large_data: bool = True,
110+
key_selection_gui: bool = True,
111+
):
112+
self._key_selection_widget = None
113+
super().__init__(
114+
napari_viewer,
115+
marker_alpha=marker_alpha,
116+
histogram_for_large_data=histogram_for_large_data,
117+
)
118+
119+
if key_selection_gui is True:
120+
self._key_selection_widget = magicgui(
121+
self._set_axis_keys,
122+
x_axis_key={"choices": self._get_valid_axis_keys},
123+
y_axis_key={"choices": self._get_valid_axis_keys},
124+
call_button="plot",
125+
)
126+
self.layout().addWidget(self._key_selection_widget.native)
127+
128+
@property
129+
def x_axis_key(self) -> Union[None, str]:
130+
"""Key to access x axis data from the FeaturesTable"""
131+
return self._x_axis_key
132+
133+
@x_axis_key.setter
134+
def x_axis_key(self, key: Union[None, str]):
135+
self._x_axis_key = key
136+
self._draw()
137+
138+
@property
139+
def y_axis_key(self) -> Union[None, str]:
140+
"""Key to access y axis data from the FeaturesTable"""
141+
return self._y_axis_key
142+
143+
@y_axis_key.setter
144+
def y_axis_key(self, key: Union[None, str]):
145+
self._y_axis_key = key
146+
self._draw()
147+
148+
def _set_axis_keys(self, x_axis_key: str, y_axis_key: str):
149+
"""Set both axis keys and then redraw the plot"""
150+
self._x_axis_key = x_axis_key
151+
self._y_axis_key = y_axis_key
152+
self._draw()
153+
154+
def _get_valid_axis_keys(self, combo_widget=None) -> List[str]:
155+
if len(self.layers) == 0:
156+
return []
157+
else:
158+
return self.layers[0].features.keys()
159+
160+
def _get_data(self) -> Tuple[np.ndarray, str, str]:
161+
feature_table = self.layers[0].features
162+
163+
if (
164+
(len(feature_table) == 0)
165+
or (self.x_axis_key is None)
166+
or (self.y_axis_key is None)
167+
):
168+
return np.array([]), "", ""
169+
170+
data_x = feature_table[self.x_axis_key]
171+
data_y = feature_table[self.y_axis_key]
172+
data = np.stack((data_x, data_y))
173+
174+
x_axis_name = self.x_axis_key.replace("_", " ")
175+
y_axis_name = self.y_axis_key.replace("_", " ")
176+
177+
return data, x_axis_name, y_axis_name
178+
179+
def _on_update_layers(self) -> None:
180+
"""This is called when the layer selection changes
181+
by self.update_layers().
182+
"""
183+
if self._key_selection_widget is not None:
184+
self._key_selection_widget.reset_choices()
185+
186+
# reset the axis keys
187+
self._x_axis_key = None
188+
self._y_axis_key = None

0 commit comments

Comments
 (0)