Skip to content

Commit 3962460

Browse files
authored
Merge pull request #148 from funkelab/empty-graph
Gracefully handle empty graph and graph with no edges
2 parents 6968f69 + 409fc1b commit 3962460

File tree

5 files changed

+73
-6
lines changed

5 files changed

+73
-6
lines changed

motile/costs/cost.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88

99

1010
class Cost(ABC):
11-
"""A base class for a cost that can be added to a solver."""
11+
"""A base class for a cost that can be added to a solver.
12+
13+
Weights should be initialized in the __init__ and added to a instance
14+
variable so that the Solver can discover them.
15+
"""
1216

1317
@abstractmethod
1418
def apply(self, solver: Solver) -> None:

motile/costs/features.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ def _increase_features(self, num_features: int) -> None:
5252
new_features = np.zeros(shape, dtype=self._values.dtype)
5353
self._values = np.hstack((self._values, new_features))
5454

55+
def register_feature(self, feature_index: int) -> None:
56+
num_variables, num_features = self._values.shape
57+
if feature_index >= num_features:
58+
self.resize(
59+
num_variables,
60+
max(feature_index + 1, num_features),
61+
)
62+
5563
def add_feature(
5664
self, variable_index: int | ilpy.Variable, feature_index: int, value: float
5765
) -> None:
@@ -68,10 +76,10 @@ def add_feature(
6876
num_variables, num_features = self._values.shape
6977

7078
variable_index = int(variable_index)
71-
if variable_index >= num_variables or feature_index >= num_features:
79+
if variable_index >= num_variables:
7280
self.resize(
7381
max(variable_index + 1, num_variables),
74-
max(feature_index + 1, num_features),
82+
num_features,
7583
)
7684

7785
self._values[variable_index, feature_index] += value

motile/costs/weights.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,28 @@ def __init__(self) -> None:
2424
self._weight_indices: dict[Weight, int] = {}
2525
self._modify_callbacks: list[Callback] = []
2626

27-
def add_weight(self, weight: Weight, name: Hashable) -> None:
27+
def add_weight(self, weight: Weight, name: Hashable) -> int:
2828
"""Add a weight to the container.
2929
3030
Args:
3131
weight:
3232
The :class:`~motile.costs.Weight` to add.
3333
name:
3434
The name of the weight.
35+
36+
Returns:
37+
int: the index of the weight
3538
"""
36-
self._weight_indices[weight] = len(self._weights)
39+
weight_index = len(self._weights)
40+
self._weight_indices[weight] = weight_index
3741
self._weights.append(weight)
3842
self._weights_by_name[name] = weight
3943

4044
for callback in self._modify_callbacks:
4145
weight.register_modify_callback(callback)
4246

4347
self._notify_modified(None, weight.value)
48+
return weight_index
4449

4550
def register_modify_callback(self, callback: Callback) -> None:
4651
"""Register ``callback`` to be called when a weight is modified.

motile/solver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def add_cost(self, cost: Cost, name: str | None = None) -> None:
9393
for var_name, var in cost.__dict__.items():
9494
if not isinstance(var, Weight):
9595
continue
96-
self.weights.add_weight(var, (name, var_name))
96+
weight_index = self.weights.add_weight(var, (name, var_name))
97+
self.features.register_feature(weight_index)
9798

9899
cost.apply(self)
99100

tests/test_solver.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import motile
2+
import networkx as nx
3+
from motile.costs import EdgeSelection, NodeSelection
4+
5+
6+
def test_empty_graph() -> None:
7+
"""Test that solving an empty graph does not error and returns empty solution."""
8+
nx_graph = nx.DiGraph()
9+
graph = motile.TrackGraph(nx_graph)
10+
11+
solver = motile.Solver(graph)
12+
solver.add_cost(NodeSelection(constant=-1))
13+
solver.add_cost(EdgeSelection(constant=-1))
14+
15+
# Should not error
16+
solver.solve()
17+
solution_graph = solver.get_selected_subgraph()
18+
19+
# Solution should be empty
20+
assert len(solution_graph.nodes) == 0
21+
assert len(solution_graph.edges) == 0
22+
23+
24+
def test_graph_with_no_edges() -> None:
25+
"""Test that solving a graph with nodes but no edges does not error."""
26+
cells = [
27+
{"id": 0, "t": 0},
28+
{"id": 1, "t": 0},
29+
{"id": 2, "t": 1},
30+
]
31+
32+
nx_graph = nx.DiGraph()
33+
nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells])
34+
# No edges added
35+
36+
graph = motile.TrackGraph(nx_graph)
37+
38+
solver = motile.Solver(graph)
39+
solver.add_cost(NodeSelection(constant=-1))
40+
solver.add_cost(EdgeSelection(constant=-1))
41+
42+
# Should not error
43+
solver.solve()
44+
solution_graph = solver.get_selected_subgraph()
45+
46+
# All nodes should be selected due to negative cost
47+
assert set(solution_graph.nodes.keys()) == {0, 1, 2}
48+
# No edges should be selected (none exist)
49+
assert len(solution_graph.edges) == 0

0 commit comments

Comments
 (0)