Skip to content

Commit e8d8038

Browse files
authored
Merge pull request #82 from MattScicluna/add_disconnection_warning
Add `component_labels` and `n_components` properties for disconnected graphs. Bump to version 2.0.0
2 parents b28338c + f64c8c5 commit e8d8038

File tree

3 files changed

+227
-1
lines changed

3 files changed

+227
-1
lines changed

graphtools/base.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,81 @@ def kernel(self):
723723
"""Synonym for K"""
724724
return self.K
725725

726+
@property
727+
def n_connected_components(self):
728+
"""Number of connected components in the graph (cached)
729+
730+
A connected component is a maximal set of nodes where there is a path
731+
between every pair of nodes in the set. This property uses scipy's
732+
connected_components function on the kernel matrix to count components.
733+
734+
Returns
735+
-------
736+
n_components : int
737+
Number of connected components in the graph
738+
739+
Examples
740+
--------
741+
>>> G = graphtools.Graph(data)
742+
>>> print(G.n_connected_components)
743+
1
744+
"""
745+
try:
746+
return self._n_connected_components
747+
except AttributeError:
748+
from scipy.sparse.csgraph import connected_components
749+
750+
self._n_connected_components, self._component_labels = connected_components(
751+
csgraph=self.kernel, directed=False, return_labels=True
752+
)
753+
return self._n_connected_components
754+
755+
@property
756+
def component_labels(self):
757+
"""Component label for each node (cached)
758+
759+
Returns the connected component index for each node in the graph.
760+
Nodes with the same label belong to the same connected component.
761+
762+
Returns
763+
-------
764+
labels : np.ndarray, shape=[n_samples]
765+
Component index for each node (0 to n_connected_components - 1)
766+
767+
Examples
768+
--------
769+
>>> G = graphtools.Graph(data)
770+
>>> labels = G.component_labels
771+
>>> # Find nodes in component 0
772+
>>> component_0_nodes = np.where(labels == 0)[0]
773+
"""
774+
try:
775+
return self._component_labels
776+
except AttributeError:
777+
# Trigger computation via n_connected_components
778+
_ = self.n_connected_components
779+
return self._component_labels
780+
781+
@property
782+
def is_connected(self):
783+
"""Check if the graph is connected (cached)
784+
785+
A graph is connected if there is a path between every pair of nodes,
786+
i.e., if it has exactly one connected component.
787+
788+
Returns
789+
-------
790+
connected : bool
791+
True if graph has exactly 1 connected component, False otherwise
792+
793+
Examples
794+
--------
795+
>>> G = graphtools.Graph(data)
796+
>>> if not G.is_connected:
797+
... print(f"Warning: Graph has {G.n_connected_components} components")
798+
"""
799+
return self.n_connected_components == 1
800+
726801
@property
727802
def weighted(self):
728803
return self.decay is not None

graphtools/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.5.3"
1+
__version__ = "2.0.0"

test/test_connectivity.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
from __future__ import print_function
2+
3+
from load_tests import build_graph
4+
from load_tests import data
5+
from load_tests import graphtools
6+
from load_tests import np
7+
8+
import pytest
9+
10+
11+
def test_connected_graph():
12+
"""Test that a normal graph is connected"""
13+
G = build_graph(data, n_pca=20, decay=10, knn=5)
14+
15+
# Check that the graph is connected
16+
assert G.is_connected, "Expected graph to be connected"
17+
assert G.n_connected_components == 1, "Expected exactly 1 connected component"
18+
19+
# Check component labels
20+
labels = G.component_labels
21+
assert labels.shape[0] == data.shape[0], "Component labels should match data size"
22+
assert np.all(labels == 0), "All nodes should be in component 0"
23+
24+
25+
def test_disconnected_graph():
26+
"""Test a graph that is intentionally disconnected"""
27+
# Create two separate clusters of data that won't connect
28+
cluster1 = np.random.randn(50, 10)
29+
cluster2 = np.random.randn(50, 10) + 100 # Far away from cluster1
30+
disconnected_data = np.vstack([cluster1, cluster2])
31+
32+
# Build graph with small knn to ensure disconnection
33+
G = build_graph(disconnected_data, n_pca=None, decay=10, knn=3, thresh=1e-4)
34+
35+
# Check that the graph is disconnected
36+
assert not G.is_connected, "Expected graph to be disconnected"
37+
assert G.n_connected_components >= 2, "Expected at least 2 connected components"
38+
39+
# Check component labels
40+
labels = G.component_labels
41+
assert labels.shape[0] == disconnected_data.shape[0], "Component labels should match data size"
42+
assert len(np.unique(labels)) >= 2, "Should have at least 2 unique component labels"
43+
44+
45+
def test_component_labels_consistency():
46+
"""Test that component labels are consistent across calls"""
47+
# Create disconnected data
48+
cluster1 = np.random.randn(30, 5)
49+
cluster2 = np.random.randn(30, 5) + 50
50+
disconnected_data = np.vstack([cluster1, cluster2])
51+
52+
G = build_graph(disconnected_data, n_pca=None, decay=10, knn=2)
53+
54+
# Get labels multiple times - should be cached and identical
55+
labels1 = G.component_labels
56+
labels2 = G.component_labels
57+
n_comp1 = G.n_connected_components
58+
n_comp2 = G.n_connected_components
59+
60+
assert np.array_equal(labels1, labels2), "Component labels should be cached"
61+
assert n_comp1 == n_comp2, "n_connected_components should be cached"
62+
63+
64+
def test_precomputed_graph_connectivity():
65+
"""Test connectivity with precomputed distance matrix"""
66+
from scipy.spatial.distance import pdist, squareform
67+
68+
# Create small disconnected dataset
69+
cluster1 = np.array([[0, 0], [0, 1], [1, 0]])
70+
cluster2 = np.array([[100, 100], [100, 101], [101, 100]])
71+
disconnected_data = np.vstack([cluster1, cluster2])
72+
73+
# Compute distance matrix
74+
dist_matrix = squareform(pdist(disconnected_data))
75+
76+
# For precomputed graphs, n_pca must be None
77+
G = build_graph(dist_matrix, n_pca=None, precomputed="distance", decay=10, knn=2)
78+
79+
# Should be disconnected
80+
assert not G.is_connected, "Precomputed disconnected graph should be disconnected"
81+
assert G.n_connected_components == 2, "Should have exactly 2 components"
82+
83+
84+
def test_landmark_graph_connectivity():
85+
"""Test connectivity with landmark graphs"""
86+
G = build_graph(data, n_pca=20, decay=10, knn=5, n_landmark=100)
87+
88+
# Landmark graphs should still support connectivity checks
89+
assert hasattr(G, 'is_connected'), "Landmark graph should have is_connected property"
90+
assert hasattr(G, 'n_connected_components'), "Landmark graph should have n_connected_components"
91+
assert hasattr(G, 'component_labels'), "Landmark graph should have component_labels"
92+
93+
# Check that properties work
94+
is_conn = G.is_connected
95+
n_comp = G.n_connected_components
96+
labels = G.component_labels
97+
98+
assert isinstance(is_conn, (bool, np.bool_)), "is_connected should return boolean"
99+
assert isinstance(n_comp, (int, np.integer)), "n_connected_components should return int"
100+
assert labels.shape[0] == data.shape[0], "component_labels should match data size"
101+
102+
103+
def test_knn_graph_connectivity():
104+
"""Test connectivity with different knn values"""
105+
# With high knn, should be connected
106+
G_high_knn = build_graph(data, n_pca=20, knn=10, decay=10)
107+
assert G_high_knn.is_connected, "Graph with high knn should be connected"
108+
109+
# Create data that might disconnect with very low knn
110+
sparse_data = np.random.randn(100, 10) * 2
111+
G_low_knn = build_graph(sparse_data, n_pca=None, knn=2, decay=10)
112+
113+
# Just check that the properties work (connectivity depends on data)
114+
assert isinstance(G_low_knn.is_connected, (bool, np.bool_))
115+
assert G_low_knn.n_connected_components >= 1
116+
117+
118+
def test_component_labels_range():
119+
"""Test that component labels are in the correct range"""
120+
cluster1 = np.random.randn(20, 5)
121+
cluster2 = np.random.randn(20, 5) + 100
122+
cluster3 = np.random.randn(20, 5) - 100
123+
disconnected_data = np.vstack([cluster1, cluster2, cluster3])
124+
125+
G = build_graph(disconnected_data, n_pca=None, decay=10, knn=2)
126+
127+
labels = G.component_labels
128+
n_comp = G.n_connected_components
129+
130+
# Labels should be in range [0, n_components)
131+
assert labels.min() >= 0, "Minimum label should be >= 0"
132+
assert labels.max() < n_comp, "Maximum label should be < n_connected_components"
133+
assert len(np.unique(labels)) == n_comp, "Number of unique labels should equal n_components"
134+
135+
136+
def test_connectivity_caching():
137+
"""Test that connectivity properties are properly cached"""
138+
G = build_graph(data, n_pca=20, decay=10, knn=5)
139+
140+
# Access properties to trigger caching
141+
_ = G.is_connected
142+
143+
# Check that internal cache exists
144+
assert hasattr(G, '_n_connected_components'), "Should cache n_connected_components"
145+
assert hasattr(G, '_component_labels'), "Should cache component_labels"
146+
147+
# Verify cached values are used
148+
n_comp_cached = G._n_connected_components
149+
n_comp_property = G.n_connected_components
150+
151+
assert n_comp_cached == n_comp_property, "Cached value should match property value"

0 commit comments

Comments
 (0)