Skip to content

Commit a0e87ac

Browse files
authored
Merge pull request #240 from funkelab/213-improve-user-experience-for-making-a-new-track-id-when-labeling-with-segs
213 improve user experience for making a new track id when labeling with segs
2 parents 0977fe6 + 520832a commit a0e87ac

File tree

9 files changed

+492
-123
lines changed

9 files changed

+492
-123
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ dependencies =[
4444
"dask[array]>=2021.10.0",
4545
"fonticon-fontawesome6>=6,<7",
4646
"pyqtgraph>=0.13,<1",
47-
"napari-orthogonal-views @ git+https://github.com/AnniekStok/napari-orthogonal-views.git@v0.0.4",
47+
"napari-orthogonal-views==0.0.6",
4848
"lxml_html_clean>=0.4,<1",
4949
"zarr>=2.10,<3",
5050
]

src/motile_tracker/application_menus/editing_menu.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import napari
22
from qtpy.QtWidgets import (
33
QGroupBox,
4+
QHBoxLayout,
5+
QLabel,
46
QPushButton,
57
QVBoxLayout,
68
QWidget,
@@ -17,6 +19,15 @@ def __init__(self, viewer: napari.Viewer):
1719
self.tracks_viewer.selected_nodes.list_updated.connect(self.update_buttons)
1820
layout = QVBoxLayout()
1921

22+
self.label = QLabel(f"Current Track ID: {self.tracks_viewer.selected_track}")
23+
self.tracks_viewer.update_track_id.connect(self.update_track_id_color)
24+
new_track_btn = QPushButton("Start new")
25+
new_track_btn.clicked.connect(self.tracks_viewer.request_new_track)
26+
track_layout = QHBoxLayout()
27+
track_layout.addWidget(self.label)
28+
track_layout.addWidget(new_track_btn)
29+
layout.addLayout(track_layout)
30+
2031
node_box = QGroupBox("Edit Node(s)")
2132
node_box.setMaximumHeight(60)
2233
node_box_layout = QVBoxLayout()
@@ -69,7 +80,22 @@ def __init__(self, viewer: napari.Viewer):
6980
layout.addWidget(self.redo_btn)
7081

7182
self.setLayout(layout)
72-
self.setMaximumHeight(300)
83+
self.setMaximumHeight(400)
84+
85+
def update_track_id_color(self):
86+
"""Display track ID value and color"""
87+
88+
color = self.tracks_viewer.track_id_color
89+
r, g, b, a = [int(c * 255) if i < 3 else c for i, c in enumerate(color)]
90+
css_color = f"rgba({r}, {g}, {b}, {a})"
91+
self.label.setText(f"Current Track ID: {self.tracks_viewer.selected_track}")
92+
self.label.setStyleSheet(
93+
f"""
94+
color: white;
95+
border: 2px solid {css_color};
96+
padding: 5px;
97+
"""
98+
)
7399

74100
def update_buttons(self):
75101
"""Set the buttons to enabled/disabled depending on the selected nodes"""
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import networkx as nx
2+
import numpy as np
3+
import pytest
4+
from funtracks.data_model import NodeAttr
5+
6+
from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer
7+
8+
9+
@pytest.fixture(autouse=True)
10+
def reset_tracks_viewer():
11+
# clear the singleton before test
12+
if hasattr(TracksViewer, "_instance"):
13+
del TracksViewer._instance
14+
15+
# after test, close all viewers and clear again
16+
yield
17+
if hasattr(TracksViewer, "_instance"):
18+
del TracksViewer._instance
19+
20+
21+
@pytest.fixture
22+
def graph_3d():
23+
graph = nx.DiGraph()
24+
nodes = [
25+
(
26+
1,
27+
{
28+
NodeAttr.POS.value: [50, 50, 50],
29+
NodeAttr.TIME.value: 0,
30+
NodeAttr.AREA.value: 1000,
31+
},
32+
),
33+
(
34+
2,
35+
{
36+
NodeAttr.POS.value: [20, 50, 80],
37+
NodeAttr.TIME.value: 1,
38+
NodeAttr.AREA.value: 1000,
39+
},
40+
),
41+
(
42+
3,
43+
{
44+
NodeAttr.POS.value: [60, 50, 45],
45+
NodeAttr.TIME.value: 2,
46+
NodeAttr.AREA.value: 1000,
47+
},
48+
),
49+
(
50+
4,
51+
{
52+
NodeAttr.POS.value: [40, 70, 60],
53+
NodeAttr.TIME.value: 2,
54+
NodeAttr.AREA.value: 1000,
55+
},
56+
),
57+
]
58+
edges = [(1, 2), (2, 3), (2, 4)]
59+
graph.add_nodes_from(nodes)
60+
graph.add_edges_from(edges)
61+
return graph
62+
63+
64+
@pytest.fixture
65+
def segmentation_3d():
66+
frame_shape = (100, 100, 100)
67+
total_shape = (5, *frame_shape)
68+
segmentation = np.zeros(total_shape, dtype="int32")
69+
segmentation[0, 45:55, 45:55, 45:55] = 1
70+
segmentation[1, 15:25, 45:55, 75:85] = 2
71+
segmentation[2, 55:65, 45:55, 40:50] = 3
72+
segmentation[2, 35:45, 65:75, 55:65] = 4
73+
return segmentation

src/motile_tracker/data_views/_tests/test_ortho_views.py

Lines changed: 39 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
import networkx as nx
21
import numpy as np
3-
import pytest
4-
from funtracks.data_model import NodeAttr, SolutionTracks
2+
from funtracks.data_model import SolutionTracks
53
from napari.layers import Labels, Points
64
from napari_orthogonal_views.ortho_view_widget import OrthoViewWidget
75

@@ -13,52 +11,6 @@
1311
from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer
1412

1513

16-
@pytest.fixture
17-
def graph_3d():
18-
graph = nx.DiGraph()
19-
nodes = [
20-
(
21-
1,
22-
{
23-
NodeAttr.POS.value: [50, 50, 50],
24-
NodeAttr.TIME.value: 0,
25-
},
26-
),
27-
(
28-
2,
29-
{
30-
NodeAttr.POS.value: [20, 50, 80],
31-
NodeAttr.TIME.value: 1,
32-
},
33-
),
34-
(
35-
3,
36-
{
37-
NodeAttr.POS.value: [60, 50, 45],
38-
NodeAttr.TIME.value: 1,
39-
},
40-
),
41-
]
42-
edges = [
43-
(1, 2),
44-
(1, 3),
45-
]
46-
graph.add_nodes_from(nodes)
47-
graph.add_edges_from(edges)
48-
return graph
49-
50-
51-
@pytest.fixture
52-
def segmentation_3d():
53-
frame_shape = (100, 100, 100)
54-
total_shape = (5, *frame_shape)
55-
segmentation = np.zeros(total_shape, dtype="int32")
56-
segmentation[0, 45:55, 45:55, 45:55] = 1
57-
segmentation[1, 15:25, 45:55, 75:85] = 2
58-
segmentation[1, 55:65, 45:55, 40:50] = 3
59-
return segmentation
60-
61-
6214
class MockEvent:
6315
def __init__(self, value):
6416
self.value = value
@@ -71,52 +23,74 @@ def test_ortho_views(make_napari_viewer, qtbot, graph_3d, segmentation_3d):
7123
viewer = make_napari_viewer()
7224
m = initialize_ortho_views(viewer)
7325

74-
m.show()
75-
qtbot.waitUntil(lambda: m.is_shown(), timeout=1000)
76-
assert isinstance(m.right_widget, OrthoViewWidget)
77-
7826
# Create example tracks
7927
tracks = SolutionTracks(graph=graph_3d, segmentation=segmentation_3d, ndim=4)
8028
tracks_viewer = TracksViewer.get_instance(viewer)
8129
tracks_viewer.update_tracks(tracks=tracks, name="test")
8230

8331
assert isinstance(viewer.layers[-1], TrackPoints)
8432
assert isinstance(viewer.layers[-2], TrackLabels)
33+
34+
# change attributes on the TrackLabels layer to check that they are correctly copied
35+
viewer.layers[-2].contour = 1
36+
viewer.layers[-2].mode = "erase"
37+
38+
# show orthogonal views and check attributes
39+
m.show()
40+
qtbot.waitUntil(lambda: m.is_shown(), timeout=1000)
41+
assert isinstance(m.right_widget, OrthoViewWidget)
8542
assert isinstance(m.right_widget.vm_container.viewer_model.layers[-1], Points)
8643
assert isinstance(m.bottom_widget.vm_container.viewer_model.layers[-1], Points)
8744
assert isinstance(m.right_widget.vm_container.viewer_model.layers[-2], Labels)
8845
assert isinstance(m.bottom_widget.vm_container.viewer_model.layers[-2], Labels)
46+
assert (
47+
m.right_widget.vm_container.viewer_model.layers[-2].contour
48+
== viewer.layers[-2].contour
49+
)
50+
assert (
51+
m.right_widget.vm_container.viewer_model.layers[-2].mode
52+
== viewer.layers[-2].mode
53+
)
54+
55+
# set to paint mode and test syncing
56+
viewer.layers[-2].mode = "paint"
57+
assert (
58+
viewer.layers[-2].mode
59+
== m.right_widget.vm_container.viewer_model.layers[-2].mode
60+
== m.bottom_widget.vm_container.viewer_model.layers[-2].mode
61+
)
8962

9063
# Test paint event on main viewer (indices, orig value, target_value)
9164
event_val = [
9265
(
9366
(np.array([1]), np.array([15]), np.array([45]), np.array([75])),
9467
np.array([2], dtype=np.uint16),
95-
np.uint16(4),
68+
np.uint16(5),
9669
)
9770
]
9871
event = MockEvent(event_val)
72+
step = list(viewer.dims.current_step)
73+
step[0] = 1
74+
viewer.dims.current_step = step
9975
viewer.layers[-2]._on_paint(event)
10076

101-
assert viewer.layers[-2].data[1, 15, 45, 75] == 4
77+
assert viewer.layers[-2].data[1, 15, 45, 75] == 5
10278
assert np.array_equal(
10379
viewer.layers[-2].data, m.right_widget.vm_container.viewer_model.layers[-2].data
10480
)
10581

106-
# test paint even on one of the ortho views and see if a new node is added
82+
# test paint event on one of the ortho views and see if a new node is added
83+
assert len(tracks_viewer.tracks.graph.nodes) == 5
84+
step = list(viewer.dims.current_step)
85+
step[0] = 2
86+
viewer.dims.current_step = step
10787
m.right_widget.vm_container.viewer_model.layers[-2].paint(
108-
coord=(2, 63, 20, 30), new_label=5, refresh=True
88+
coord=(2, 63, 20, 30), new_label=6, refresh=True
10989
)
110-
assert len(tracks_viewer.tracks.graph.nodes) == 5
90+
assert len(tracks_viewer.tracks.graph.nodes) == 6
11191

11292
# test syncing of properties
113-
viewer.layers[-2].mode = "paint"
114-
assert (
115-
viewer.layers[-2].mode
116-
== m.right_widget.vm_container.viewer_model.layers[-2].mode
117-
== m.bottom_widget.vm_container.viewer_model.layers[-2].mode
118-
)
119-
viewer.layers[-2].selected_label = 6
93+
viewer.layers[-2].selected_label = 7 # forward sync only
12094
assert (
12195
viewer.layers[-2].selected_label
12296
== m.right_widget.vm_container.viewer_model.layers[-2].selected_label

0 commit comments

Comments
 (0)