Skip to content

Commit b9fb712

Browse files
authored
Merge pull request #380 from BiAPoL/make-layer-coloring-more-robust
Make layer coloring more robust by adding more test cases
2 parents 3e295ea + 6cfad71 commit b9fb712

File tree

2 files changed

+198
-39
lines changed

2 files changed

+198
-39
lines changed

src/napari_clusters_plotter/_new_plotter_widget.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -449,33 +449,29 @@ def _apply_layer_color(layer, colors):
449449
"""
450450
from napari.utils import DirectLabelColormap
451451

452-
color_mapping = {
453-
napari.layers.Points: lambda _layer, _color: setattr(
454-
_layer, "face_color", _color
455-
),
456-
napari.layers.Vectors: lambda _layer, _color: setattr(
457-
_layer, "edge_color", _color
458-
),
459-
napari.layers.Surface: lambda _layer, _color: setattr(
460-
_layer, "vertex_colors", _color
461-
),
462-
napari.layers.Shapes: lambda _layer, _color: setattr(
463-
_layer, "face_color", _color
464-
),
465-
napari.layers.Labels: lambda _layer, _color: setattr(
466-
_layer,
467-
"colormap",
468-
DirectLabelColormap(
469-
color_dict={
470-
label: _color[label] for label in np.unique(_layer.data)
471-
}
472-
),
473-
),
474-
}
475-
476-
if type(layer) in color_mapping:
477-
if type(layer) is napari.layers.Labels:
478-
# add a color for the background at the first index
479-
colors = np.insert(colors, 0, [0, 0, 0, 0], axis=0)
480-
color_mapping[type(layer)](layer, colors)
481-
layer.refresh()
452+
if isinstance(layer, napari.layers.Points):
453+
layer.face_color = colors
454+
455+
elif isinstance(layer, napari.layers.Vectors):
456+
layer.edge_color = colors
457+
458+
elif isinstance(layer, napari.layers.Surface):
459+
layer.vertex_colors = colors
460+
461+
elif isinstance(layer, napari.layers.Shapes):
462+
layer.face_color = colors
463+
464+
elif isinstance(layer, napari.layers.Labels):
465+
466+
colors = np.insert(colors, 0, [0, 0, 0, 0], axis=0)
467+
color_dict = dict(zip(np.unique(layer.data), colors))
468+
469+
# Insert default colors for labels that are not in the color_dict
470+
# Relevant for non-sequential label images
471+
if max(color_dict.keys()) > len(colors):
472+
for i in range(1, max(color_dict.keys()) - 1):
473+
color_dict[i] = [0, 0, 0, 0]
474+
# Add a color for the background at the first index
475+
layer.colormap = DirectLabelColormap(color_dict=color_dict)
476+
477+
layer.refresh()

src/napari_clusters_plotter/_tests/test_plotter.py

Lines changed: 172 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23

34

45
def create_multi_point_layer(n_samples: int = 100):
@@ -48,6 +49,151 @@ def create_multi_point_layer(n_samples: int = 100):
4849
return layer, layer2
4950

5051

52+
def create_multi_vectors_layer(n_samples: int = 100):
53+
from napari.layers import Vectors
54+
55+
points1, points2 = create_multi_point_layer(n_samples=n_samples)
56+
57+
points_direction1 = np.random.normal(size=points1.data.shape)
58+
points_direction2 = np.random.normal(size=points2.data.shape)
59+
60+
# set time index correctly
61+
points_direction1[:, 0] = points1.data[:, 0]
62+
points_direction2[:, 0] = points2.data[:, 1]
63+
64+
vectors1 = np.stack([points1.data, points_direction1], axis=1)
65+
vectors2 = np.stack([points2.data, points_direction2], axis=1)
66+
67+
vectors1 = Vectors(vectors1, features=points1.features, name="vectors1")
68+
vectors2 = Vectors(vectors2, features=points2.features, name="vectors2")
69+
70+
return vectors1, vectors2
71+
72+
73+
def create_multi_surface_layer(n_samples: int = 100):
74+
from napari.layers import Surface
75+
76+
vertices1, vertices2 = create_multi_point_layer(n_samples=n_samples)
77+
78+
faces1 = []
79+
faces2 = []
80+
for t in range(int(vertices1.data[:, 0].max())):
81+
vertex_indeces_t = np.argwhere(vertices1.data[:, 0] == t).flatten()
82+
83+
# draw some random triangles from the indeces
84+
_faces = np.random.randint(
85+
low=vertex_indeces_t.min(),
86+
high=vertex_indeces_t.max(),
87+
size=(10, 3),
88+
)
89+
faces1.append(_faces)
90+
91+
vertex_indeces_t = np.argwhere(vertices2.data[:, 0] == t).flatten()
92+
93+
# draw some random triangles from the indeces
94+
_faces = np.random.randint(
95+
low=vertex_indeces_t.min(),
96+
high=vertex_indeces_t.max(),
97+
size=(10, 3),
98+
)
99+
faces2.append(_faces)
100+
101+
faces1 = np.concatenate(faces1, axis=0)
102+
faces2 = np.concatenate(faces2, axis=0)
103+
104+
surface1 = Surface(
105+
(vertices1.data, faces1),
106+
features=vertices1.features,
107+
name="surface1",
108+
)
109+
110+
surface2 = Surface(
111+
(vertices2.data, faces2),
112+
features=vertices2.features,
113+
name="surface2",
114+
translate=(0, 0, 2),
115+
)
116+
return surface1, surface2
117+
118+
119+
def create_multi_shapes_layers(n_samples: int = 100):
120+
from napari.layers import Shapes
121+
122+
points1, points2 = create_multi_point_layer(n_samples=n_samples)
123+
124+
shapes1, shapes2 = [], []
125+
for i in range(len(points1.data)):
126+
# create a random shape around the point, whereas the shape consists of the coordinates
127+
# of the four corner of the rectangle
128+
y, x = points1.data[i, 2], points1.data[i, 3]
129+
w, h = np.random.randint(1, 5), np.random.randint(1, 5)
130+
131+
shape1 = np.array(
132+
[
133+
[y - h, x - w],
134+
[y - h, x + w],
135+
[y + h, x + w],
136+
[y + h, x - w],
137+
]
138+
)
139+
shapes1.append(shape1)
140+
141+
for i in range(len(points2.data)):
142+
# create a random shape around the point, whereas the shape consists of the coordinates
143+
# of the four corner of the rectangle
144+
y, x = points2.data[i, 2], points2.data[i, 3]
145+
w, h = np.random.randint(1, 5), np.random.randint(1, 5)
146+
147+
shape2 = np.array(
148+
[
149+
[y - h, x - w],
150+
[y - h, x + w],
151+
[y + h, x + w],
152+
[y + h, x - w],
153+
]
154+
)
155+
shapes2.append(shape2)
156+
157+
shape1 = Shapes(shapes1, features=points1.features, name="shapes1")
158+
shape2 = Shapes(
159+
shapes2, features=points2.features, name="shapes2", translate=(0, 2)
160+
)
161+
162+
return shape1, shape2
163+
164+
165+
def create_multi_labels_layer():
166+
import pandas as pd
167+
from napari.layers import Labels
168+
from skimage import data, measure
169+
170+
labels1 = measure.label(data.binary_blobs(length=64, n_dim=2))
171+
labels2 = measure.label(data.binary_blobs(length=64, n_dim=2))
172+
173+
features1 = pd.DataFrame(
174+
{
175+
"feature1": np.random.normal(size=labels1.max() + 1),
176+
"feature2": np.random.normal(size=labels1.max() + 1),
177+
"feature3": np.random.normal(size=labels1.max() + 1),
178+
}
179+
)
180+
181+
features2 = pd.DataFrame(
182+
{
183+
"feature1": np.random.normal(size=labels2.max() + 1),
184+
"feature2": np.random.normal(size=labels2.max() + 1),
185+
"feature3": np.random.normal(size=labels2.max() + 1),
186+
}
187+
)
188+
189+
labels1 = Labels(labels1, name="labels1", features=features1)
190+
labels2 = Labels(
191+
labels2, name="labels2", features=features2, translate=(0, 128)
192+
)
193+
194+
return labels1, labels2
195+
196+
51197
def test_mixed_layers(make_napari_viewer):
52198
from napari_clusters_plotter import PlotterWidget
53199

@@ -71,14 +217,22 @@ def test_mixed_layers(make_napari_viewer):
71217
viewer.add_image(random_image)
72218
viewer.add_labels(sample_labels)
73219

74-
#
75-
76220

77-
def test_cluster_memorization(make_napari_viewer, n_samples: int = 100):
221+
@pytest.mark.parametrize(
222+
"create_sample_layers",
223+
[
224+
create_multi_point_layer,
225+
create_multi_labels_layer,
226+
create_multi_vectors_layer,
227+
create_multi_surface_layer,
228+
create_multi_shapes_layers,
229+
],
230+
)
231+
def test_cluster_memorization(make_napari_viewer, create_sample_layers):
78232
from napari_clusters_plotter import PlotterWidget
79233

80234
viewer = make_napari_viewer()
81-
layer, layer2 = create_multi_point_layer(n_samples=n_samples)
235+
layer, layer2 = create_sample_layers()
82236

83237
# add layers to viewer
84238
viewer.add_layer(layer)
@@ -91,9 +245,8 @@ def test_cluster_memorization(make_napari_viewer, n_samples: int = 100):
91245
assert "MANUAL_CLUSTER_ID" in layer2.features.columns
92246

93247
plotter_widget._selectors["x"].setCurrentText("feature3")
94-
cluster_indeces = np.random.randint(0, 2, len(layer2.data))
95-
layer2.features["MANUAL_CLUSTER_ID"] = cluster_indeces
96-
plotter_widget._selectors["hue"].setCurrentText("MANUAL_CLUSTER_ID")
248+
cluster_indeces = np.random.randint(0, 2, len(layer2.features))
249+
plotter_widget._on_finish_draw(cluster_indeces)
97250

98251
# select first layer and make sure that no clusters are selected
99252
viewer.layers.selection.active = layer
@@ -110,11 +263,21 @@ def test_cluster_memorization(make_napari_viewer, n_samples: int = 100):
110263
)
111264

112265

113-
def test_categorical_handling(make_napari_viewer, n_samples: int = 100):
266+
@pytest.mark.parametrize(
267+
"create_sample_layers",
268+
[
269+
create_multi_point_layer,
270+
create_multi_labels_layer,
271+
create_multi_vectors_layer,
272+
create_multi_surface_layer,
273+
create_multi_shapes_layers,
274+
],
275+
)
276+
def test_categorical_handling(make_napari_viewer, create_sample_layers):
114277
from napari_clusters_plotter import PlotterWidget
115278

116279
viewer = make_napari_viewer()
117-
layer, layer2 = create_multi_point_layer(n_samples=n_samples)
280+
layer, layer2 = create_sample_layers()
118281

119282
# add layers to viewer
120283
viewer.add_layer(layer)

0 commit comments

Comments
 (0)