Skip to content

Commit ad325f3

Browse files
authored
Merge pull request #145 from funkelab/86-merge-cost
Add merge indicator and cost
2 parents 6c77b48 + 303d512 commit ad325f3

File tree

6 files changed

+243
-0
lines changed

6 files changed

+243
-0
lines changed

docs/source/api.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ Split
109109
^^^^^
110110
.. autoclass:: Split
111111

112+
Merge
113+
^^^^^
114+
.. autoclass:: Merge
115+
112116
EdgeDistance
113117
^^^^^^^^^^^^
114118
.. autoclass:: EdgeDistance

motile/costs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .edge_distance import EdgeDistance
55
from .edge_selection import EdgeSelection
66
from .features import Features
7+
from .merge import Merge
78
from .node_selection import NodeSelection
89
from .split import Split
910
from .weight import Weight
@@ -16,6 +17,7 @@
1617
"EdgeDistance",
1718
"EdgeSelection",
1819
"Features",
20+
"Merge",
1921
"NodeSelection",
2022
"Split",
2123
"Weight",

motile/costs/merge.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from ..variables import NodeMerge
6+
from .cost import Cost
7+
from .weight import Weight
8+
9+
if TYPE_CHECKING:
10+
from motile.solver import Solver
11+
12+
13+
class Merge(Cost):
14+
"""Cost for :class:`~motile.variables.NodeMerge` variables.
15+
16+
Args:
17+
weight:
18+
The weight to apply to the cost of each split. Default is ``1``.
19+
20+
attribute:
21+
The name of the attribute to use to look up the cost. Default is
22+
``None``, which means that a constant cost is used.
23+
24+
constant:
25+
A constant cost for each node that has more than one selected
26+
parent. Default is ``0``.
27+
"""
28+
29+
def __init__(
30+
self, weight: float = 1, attribute: str | None = None, constant: float = 0
31+
) -> None:
32+
self.weight = Weight(weight)
33+
self.constant = Weight(constant)
34+
self.attribute = attribute
35+
36+
def apply(self, solver: Solver) -> None:
37+
merge_indicators = solver.get_variables(NodeMerge)
38+
39+
for node, index in merge_indicators.items():
40+
if self.attribute is not None:
41+
solver.add_variable_cost(
42+
index, solver.graph.nodes[node][self.attribute], self.weight
43+
)
44+
solver.add_variable_cost(index, 1.0, self.constant)

motile/variables/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .edge_selected import EdgeSelected
22
from .node_appear import NodeAppear
33
from .node_disappear import NodeDisappear
4+
from .node_merge import NodeMerge
45
from .node_selected import NodeSelected
56
from .node_split import NodeSplit
67
from .variable import Variable
@@ -9,6 +10,7 @@
910
"EdgeSelected",
1011
"NodeAppear",
1112
"NodeDisappear",
13+
"NodeMerge",
1214
"NodeSelected",
1315
"NodeSplit",
1416
"Variable",

motile/variables/node_merge.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Collection, Iterable
4+
5+
import ilpy
6+
7+
from .edge_selected import EdgeSelected
8+
from .variable import Variable
9+
10+
if TYPE_CHECKING:
11+
from motile._types import Node
12+
from motile.solver import Solver
13+
14+
15+
class NodeMerge(Variable):
16+
r"""Binary variable indicating whether a node has more than one parent.
17+
18+
(i.e., the node is selected and has more than one selected incoming edge).
19+
20+
This variable is coupled to the edge selection variables through the
21+
following linear constraints:
22+
23+
.. math::
24+
25+
2 m_v\; - &\sum_{e\in\text{in_edges}(v)} x_e &\leq&\;\; 0
26+
27+
(|\text{in_edges}(v)| - 1) m_v\; - &\sum_{e\in\text{in_edges}(v)}
28+
x_e &\geq&\;\; -1
29+
30+
where :math:`x_e` are selection indicators for edge :math:`e`, and
31+
:math:`m_v` is the merge indicator for node :math:`v`.
32+
"""
33+
34+
@staticmethod
35+
def instantiate(solver: Solver) -> Collection[Node]:
36+
return solver.graph.nodes
37+
38+
@staticmethod
39+
def instantiate_constraints(solver: Solver) -> Iterable[ilpy.Constraint]:
40+
merge_indicators = solver.get_variables(NodeMerge)
41+
edge_indicators = solver.get_variables(EdgeSelected)
42+
43+
for node in solver.graph.nodes:
44+
prev_edges = solver.graph.prev_edges[node]
45+
46+
# Ensure that the following holds:
47+
#
48+
# merge = 0 <=> sum(prev_selected) <= 1
49+
# merge = 1 <=> sum(prev_selected) > 1
50+
#
51+
# Two linear constraints are needed for that:
52+
#
53+
# (1) 2 * merge - sum(prev_selected) <= 0
54+
# (2) (num_prev - 1) * merge - sum(prev_selected) >= -1
55+
56+
constraint1 = ilpy.Constraint()
57+
constraint2 = ilpy.Constraint()
58+
59+
constraint1.set_coefficient(merge_indicators[node], 2.0)
60+
constraint2.set_coefficient(merge_indicators[node], len(prev_edges) - 1.0)
61+
62+
for prev_edge in prev_edges:
63+
constraint1.set_coefficient(edge_indicators[prev_edge], -1.0)
64+
constraint2.set_coefficient(edge_indicators[prev_edge], -1.0)
65+
66+
constraint1.set_relation(ilpy.Relation.LessEqual)
67+
constraint2.set_relation(ilpy.Relation.GreaterEqual)
68+
69+
constraint1.set_value(0.0)
70+
constraint2.set_value(-1.0)
71+
72+
yield constraint1
73+
yield constraint2

tests/test_costs.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import motile
2+
import networkx as nx
23
from motile.costs import (
34
Appear,
45
Disappear,
56
EdgeSelection,
7+
Merge,
68
NodeSelection,
79
)
810

@@ -81,3 +83,119 @@ def test_disappear_cost(arlo_graph):
8183
solution_graph = solver.get_selected_subgraph()
8284
assert list(solution_graph.nodes.keys()) == [2, 3, 4, 5, 6]
8385
assert len(solution_graph.edges) == 0
86+
87+
88+
def test_constant_merge_cost() -> None:
89+
"""Test that merge cost prevents merges when applied.
90+
91+
Graph structure:
92+
t=0: node 0, node 1
93+
t=1: node 2
94+
edges: 0->2, 1->2
95+
96+
With only negative edge selection cost, both edges should be selected
97+
(resulting in a merge). Adding a merge cost should prevent the merge,
98+
resulting in only one edge being selected.
99+
"""
100+
# Create nodes
101+
cells = [
102+
{"id": 0, "t": 0},
103+
{"id": 1, "t": 0},
104+
{"id": 2, "t": 1},
105+
]
106+
107+
# Create edges (both leading to node 2, creating potential merge)
108+
edges = [
109+
{"source": 0, "target": 2},
110+
{"source": 1, "target": 2},
111+
]
112+
113+
nx_graph = nx.DiGraph()
114+
nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells])
115+
nx_graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges])
116+
117+
graph = motile.TrackGraph(nx_graph)
118+
119+
# First test: without merge cost, both edges should be selected
120+
solver = motile.Solver(graph)
121+
solver.add_cost(EdgeSelection(constant=-1.0))
122+
solver.solve()
123+
solution_graph = solver.get_selected_subgraph().to_nx_graph()
124+
125+
# Should select all nodes and both edges (merge occurs)
126+
assert set(solution_graph.nodes.keys()) == {0, 1, 2}
127+
assert len(solution_graph.edges) == 2
128+
assert solution_graph.has_edge(0, 2)
129+
assert solution_graph.has_edge(1, 2)
130+
131+
# Second test: with merge cost, only one edge should be selected
132+
solver = motile.Solver(graph)
133+
solver.add_cost(EdgeSelection(constant=-1.0))
134+
solver.add_cost(Merge(constant=10.0)) # High cost to prevent merge
135+
solver.solve()
136+
solution_graph = solver.get_selected_subgraph().to_nx_graph()
137+
138+
# Should select all nodes but only one edge (no merge)
139+
assert set(solution_graph.nodes.keys()) == {0, 1, 2}
140+
assert len(solution_graph.edges) == 1
141+
# Either edge 0->2 or 1->2 should be selected, but not both
142+
assert solution_graph.has_edge(0, 2) or solution_graph.has_edge(1, 2)
143+
assert not (solution_graph.has_edge(0, 2) and solution_graph.has_edge(1, 2))
144+
145+
146+
def test_variable_merge_cost() -> None:
147+
"""Test that merge cost can use node attributes to selectively allow merges.
148+
149+
Graph structure:
150+
t=0: node 0, node 1, node 2
151+
t=1: node 3 (merge_cost=-1.0), node 4 (merge_cost=5.0)
152+
edges: 0->3, 1->3, 1->4, 2->4
153+
154+
With negative edge selection cost and attribute-based merge cost,
155+
only the node with negative merge_cost should have a merge (node 3).
156+
Node 4 with positive merge_cost should not have a merge.
157+
"""
158+
# Create nodes - all nodes need merge_cost attribute
159+
cells = [
160+
{"id": 0, "t": 0, "merge_cost": 0.0},
161+
{"id": 1, "t": 0, "merge_cost": 0.0},
162+
{"id": 2, "t": 0, "merge_cost": 0.0},
163+
{"id": 3, "t": 1, "merge_cost": -1.0}, # negative cost = good merge
164+
{"id": 4, "t": 1, "merge_cost": 5.0}, # positive cost = bad merge
165+
]
166+
167+
# Create edges - two edges to each node in time 1 (potential merges)
168+
edges = [
169+
{"source": 0, "target": 3},
170+
{"source": 1, "target": 3},
171+
{"source": 1, "target": 4},
172+
{"source": 2, "target": 4},
173+
]
174+
175+
nx_graph = nx.DiGraph()
176+
nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells])
177+
nx_graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges])
178+
179+
graph = motile.TrackGraph(nx_graph)
180+
181+
solver = motile.Solver(graph)
182+
solver.add_cost(EdgeSelection(constant=-1.0))
183+
solver.add_cost(Merge(attribute="merge_cost", weight=1.0, constant=0.0))
184+
solver.solve()
185+
solution_graph = solver.get_selected_subgraph().to_nx_graph()
186+
187+
# Should select all nodes
188+
assert set(solution_graph.nodes.keys()) == {0, 1, 2, 3, 4}
189+
190+
# Node 3 should have a merge (both edges 0->3 and 1->3 selected)
191+
# because its merge_cost is negative (-1.0), making the total cost attractive
192+
assert solution_graph.has_edge(0, 3)
193+
assert solution_graph.has_edge(1, 3)
194+
195+
# Node 4 should NOT have a merge (only one edge selected)
196+
# because its merge_cost is positive (5.0), making the merge too expensive
197+
edges_to_4 = [
198+
solution_graph.has_edge(1, 4),
199+
solution_graph.has_edge(2, 4),
200+
]
201+
assert sum(edges_to_4) == 1 # exactly one edge to node 4

0 commit comments

Comments
 (0)