Skip to content

Commit 303d512

Browse files
committed
Add tests for merge cost
1 parent e1f24c7 commit 303d512

File tree

1 file changed

+118
-0
lines changed

1 file changed

+118
-0
lines changed

tests/test_costs.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import motile
2+
import networkx as nx
23
from motile.costs import (
34
Appear,
45
Disappear,
56
EdgeSelection,
7+
Merge,
68
NodeSelection,
79
)
810

@@ -81,3 +83,119 @@ def test_disappear_cost(arlo_graph):
8183
solution_graph = solver.get_selected_subgraph()
8284
assert list(solution_graph.nodes.keys()) == [2, 3, 4, 5, 6]
8385
assert len(solution_graph.edges) == 0
86+
87+
88+
def test_constant_merge_cost() -> None:
89+
"""Test that merge cost prevents merges when applied.
90+
91+
Graph structure:
92+
t=0: node 0, node 1
93+
t=1: node 2
94+
edges: 0->2, 1->2
95+
96+
With only negative edge selection cost, both edges should be selected
97+
(resulting in a merge). Adding a merge cost should prevent the merge,
98+
resulting in only one edge being selected.
99+
"""
100+
# Create nodes
101+
cells = [
102+
{"id": 0, "t": 0},
103+
{"id": 1, "t": 0},
104+
{"id": 2, "t": 1},
105+
]
106+
107+
# Create edges (both leading to node 2, creating potential merge)
108+
edges = [
109+
{"source": 0, "target": 2},
110+
{"source": 1, "target": 2},
111+
]
112+
113+
nx_graph = nx.DiGraph()
114+
nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells])
115+
nx_graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges])
116+
117+
graph = motile.TrackGraph(nx_graph)
118+
119+
# First test: without merge cost, both edges should be selected
120+
solver = motile.Solver(graph)
121+
solver.add_cost(EdgeSelection(constant=-1.0))
122+
solver.solve()
123+
solution_graph = solver.get_selected_subgraph().to_nx_graph()
124+
125+
# Should select all nodes and both edges (merge occurs)
126+
assert set(solution_graph.nodes.keys()) == {0, 1, 2}
127+
assert len(solution_graph.edges) == 2
128+
assert solution_graph.has_edge(0, 2)
129+
assert solution_graph.has_edge(1, 2)
130+
131+
# Second test: with merge cost, only one edge should be selected
132+
solver = motile.Solver(graph)
133+
solver.add_cost(EdgeSelection(constant=-1.0))
134+
solver.add_cost(Merge(constant=10.0)) # High cost to prevent merge
135+
solver.solve()
136+
solution_graph = solver.get_selected_subgraph().to_nx_graph()
137+
138+
# Should select all nodes but only one edge (no merge)
139+
assert set(solution_graph.nodes.keys()) == {0, 1, 2}
140+
assert len(solution_graph.edges) == 1
141+
# Either edge 0->2 or 1->2 should be selected, but not both
142+
assert solution_graph.has_edge(0, 2) or solution_graph.has_edge(1, 2)
143+
assert not (solution_graph.has_edge(0, 2) and solution_graph.has_edge(1, 2))
144+
145+
146+
def test_variable_merge_cost() -> None:
147+
"""Test that merge cost can use node attributes to selectively allow merges.
148+
149+
Graph structure:
150+
t=0: node 0, node 1, node 2
151+
t=1: node 3 (merge_cost=-1.0), node 4 (merge_cost=5.0)
152+
edges: 0->3, 1->3, 1->4, 2->4
153+
154+
With negative edge selection cost and attribute-based merge cost,
155+
only the node with negative merge_cost should have a merge (node 3).
156+
Node 4 with positive merge_cost should not have a merge.
157+
"""
158+
# Create nodes - all nodes need merge_cost attribute
159+
cells = [
160+
{"id": 0, "t": 0, "merge_cost": 0.0},
161+
{"id": 1, "t": 0, "merge_cost": 0.0},
162+
{"id": 2, "t": 0, "merge_cost": 0.0},
163+
{"id": 3, "t": 1, "merge_cost": -1.0}, # negative cost = good merge
164+
{"id": 4, "t": 1, "merge_cost": 5.0}, # positive cost = bad merge
165+
]
166+
167+
# Create edges - two edges to each node in time 1 (potential merges)
168+
edges = [
169+
{"source": 0, "target": 3},
170+
{"source": 1, "target": 3},
171+
{"source": 1, "target": 4},
172+
{"source": 2, "target": 4},
173+
]
174+
175+
nx_graph = nx.DiGraph()
176+
nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells])
177+
nx_graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges])
178+
179+
graph = motile.TrackGraph(nx_graph)
180+
181+
solver = motile.Solver(graph)
182+
solver.add_cost(EdgeSelection(constant=-1.0))
183+
solver.add_cost(Merge(attribute="merge_cost", weight=1.0, constant=0.0))
184+
solver.solve()
185+
solution_graph = solver.get_selected_subgraph().to_nx_graph()
186+
187+
# Should select all nodes
188+
assert set(solution_graph.nodes.keys()) == {0, 1, 2, 3, 4}
189+
190+
# Node 3 should have a merge (both edges 0->3 and 1->3 selected)
191+
# because its merge_cost is negative (-1.0), making the total cost attractive
192+
assert solution_graph.has_edge(0, 3)
193+
assert solution_graph.has_edge(1, 3)
194+
195+
# Node 4 should NOT have a merge (only one edge selected)
196+
# because its merge_cost is positive (5.0), making the merge too expensive
197+
edges_to_4 = [
198+
solution_graph.has_edge(1, 4),
199+
solution_graph.has_edge(2, 4),
200+
]
201+
assert sum(edges_to_4) == 1 # exactly one edge to node 4

0 commit comments

Comments
 (0)