Skip to content

Commit bf24cf5

Browse files
authored
Merge pull request #457 from jo-mueller/connect-napari-selection-with-clusters-plotter
Connect napari selection with clusters plotter
2 parents 04b30ed + dbcf74c commit bf24cf5

File tree

3 files changed

+128
-2
lines changed

3 files changed

+128
-2
lines changed

src/napari_clusters_plotter/_new_plotter_widget.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121
from qtpy.QtWidgets import QComboBox, QVBoxLayout, QWidget
2222

2323
from ._algorithm_widget import BaseWidget
24+
from ._utilities import (
25+
_get_selected_objects,
26+
_get_selection_event,
27+
_is_selectable_layer,
28+
)
2429

2530

2631
class PlottingType(Enum):
@@ -565,6 +570,31 @@ def _on_update_layer_selection(
565570
f"Layer {layer.name} does not have events.features or events.properties"
566571
)
567572

573+
# connect selection event
574+
if _is_selectable_layer(layer):
575+
selection_event = _get_selection_event(layer)
576+
selection_event.connect(self._update_selected_object_feature)
577+
578+
def _update_selected_object_feature(self) -> None:
579+
"""
580+
Get the selected object from the layer and updates the entry in MANUAL_CLUSTER_ID
581+
"""
582+
# do nothing if more than one layer is selected
583+
if len(self.layers) > 1:
584+
return
585+
layer = self.layers[0]
586+
selected_data = _get_selected_objects(layer)
587+
cluster = np.zeros(len(layer.features), dtype=np.uint64)
588+
cluster[list(selected_data)] = 1
589+
590+
if np.all(cluster == 0):
591+
return
592+
593+
# get copy of features table, modify and overwrite to trigger draw event
594+
features_table = layer.features
595+
features_table["SELECTED_LAYER_CLUSTER_ID"] = pd.Categorical(cluster)
596+
layer.features = features_table
597+
568598
def _clean_up(self):
569599
"""In case of empty layer selection"""
570600

@@ -777,7 +807,9 @@ def _set_layer_color(self, layer, colors):
777807
# Ensure the first color is transparent for the background
778808
colors = np.insert(colors, 0, [0, 0, 0, 0], axis=0)
779809
color_dict = dict(zip(_get_unique_values(layer), colors))
810+
layer.events.selected_label.block()
780811
layer.colormap = DirectLabelColormap(color_dict=color_dict)
812+
layer.events.selected_label.unblock()
781813
layer.refresh()
782814

783815
def _reset(self):

src/napari_clusters_plotter/_tests/test_plotter.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,3 +749,48 @@ def test_cluster_visibility_toggle(make_napari_viewer, create_sample_layers):
749749
plotter_widget._on_show_plot_overlay(state=False)
750750
plotter_widget._on_show_plot_overlay(state=True)
751751
plotter_widget._on_show_plot_overlay(state=False)
752+
753+
754+
@pytest.mark.parametrize(
755+
"create_sample_layers",
756+
[create_multi_point_layer, create_multi_labels_layer],
757+
)
758+
def test_selected_data_point_layer(make_napari_viewer, create_sample_layers):
759+
from napari_clusters_plotter import PlotterWidget
760+
from napari_clusters_plotter._utilities import _get_selection_event
761+
762+
viewer = make_napari_viewer()
763+
_, layer2 = create_sample_layers()
764+
765+
# add layers to viewer
766+
viewer.add_layer(layer2)
767+
plotter_widget = PlotterWidget(viewer)
768+
viewer.window.add_dock_widget(plotter_widget, area="right")
769+
770+
# select last layer and create a random selection on the layer
771+
viewer.layers.selection.active = layer2
772+
773+
event = _get_selection_event(layer2)
774+
if isinstance(layer2, Points):
775+
selection = [1, 2]
776+
layer2.selected_data = selection
777+
event.emit()
778+
elif isinstance(layer2, Labels):
779+
selection = 1
780+
layer2.selected_label = selection
781+
event()
782+
783+
assert "SELECTED_LAYER_CLUSTER_ID" in layer2.features.columns
784+
785+
# use SELECTED_DATA_LAYER_CLUSTER_ID as hue for layer2
786+
viewer.layers.selection.active = layer2
787+
plotter_widget._selectors["hue"].setCurrentText(
788+
"SELECTED_LAYER_CLUSTER_ID"
789+
)
790+
791+
assert np.array_equal(
792+
np.argwhere(
793+
plotter_widget.plotting_widget.active_artist.color_indices
794+
).flatten(),
795+
np.asarray([selection]).flatten(),
796+
)

src/napari_clusters_plotter/_utilities.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
from typing import Union
1+
from typing import List, Union
22

33
import dask.array as da
44
import numpy as np
5-
from napari.layers import Image, Labels
5+
from napari.layers import Image, Labels, Layer, Points, Shapes
6+
from napari.utils.events import Event
7+
8+
_selectable_layers = [
9+
Labels,
10+
Points,
11+
# TODO: Shapes support is planned for future implementation; currently excluded from selectable layers due to incomplete support.
12+
]
613

714

815
def _get_unique_values(layer: Union[Image, Labels]) -> np.ndarray:
@@ -29,3 +36,45 @@ def _get_unique_values(layer: Union[Image, Labels]) -> np.ndarray:
2936
unique_values = da.unique(data).compute()
3037

3138
return unique_values
39+
40+
41+
def _is_selectable_layer(layer: Layer) -> bool:
42+
"""
43+
Check if the layer is selectable.
44+
"""
45+
if type(layer) in _selectable_layers:
46+
return True
47+
return False
48+
49+
50+
def _get_selected_objects(layer: Layer) -> List[int]:
51+
"""
52+
Retrieve id of selected object on napari canvas
53+
"""
54+
if not _is_selectable_layer(layer):
55+
raise TypeError(
56+
f"Layer type {type(layer)} is not supported for selection."
57+
)
58+
59+
if isinstance(layer, Points):
60+
return list(layer.selected_data)
61+
elif isinstance(layer, Labels):
62+
return [layer.selected_label]
63+
elif isinstance(layer, Shapes):
64+
return list(layer.selected_data)
65+
66+
67+
def _get_selection_event(layer: Layer) -> Event:
68+
"""
69+
Get the selection event for the layer.
70+
"""
71+
if not _is_selectable_layer(layer):
72+
raise TypeError(
73+
f"Layer type {type(layer)} is not supported for selection events."
74+
)
75+
if isinstance(layer, Points):
76+
return layer.selected_data.events.items_changed
77+
elif isinstance(layer, Labels):
78+
return layer.events.selected_label
79+
elif isinstance(layer, Shapes):
80+
return layer.events.highlight

0 commit comments

Comments
 (0)