|
| 1 | +from __future__ import print_function |
| 2 | + |
| 3 | +from load_tests import build_graph |
| 4 | +from load_tests import data |
| 5 | +from load_tests import graphtools |
| 6 | +from load_tests import np |
| 7 | + |
| 8 | +import pytest |
| 9 | + |
| 10 | + |
| 11 | +def test_connected_graph(): |
| 12 | + """Test that a normal graph is connected""" |
| 13 | + G = build_graph(data, n_pca=20, decay=10, knn=5) |
| 14 | + |
| 15 | + # Check that the graph is connected |
| 16 | + assert G.is_connected, "Expected graph to be connected" |
| 17 | + assert G.n_connected_components == 1, "Expected exactly 1 connected component" |
| 18 | + |
| 19 | + # Check component labels |
| 20 | + labels = G.component_labels |
| 21 | + assert labels.shape[0] == data.shape[0], "Component labels should match data size" |
| 22 | + assert np.all(labels == 0), "All nodes should be in component 0" |
| 23 | + |
| 24 | + |
| 25 | +def test_disconnected_graph(): |
| 26 | + """Test a graph that is intentionally disconnected""" |
| 27 | + # Create two separate clusters of data that won't connect |
| 28 | + cluster1 = np.random.randn(50, 10) |
| 29 | + cluster2 = np.random.randn(50, 10) + 100 # Far away from cluster1 |
| 30 | + disconnected_data = np.vstack([cluster1, cluster2]) |
| 31 | + |
| 32 | + # Build graph with small knn to ensure disconnection |
| 33 | + G = build_graph(disconnected_data, n_pca=None, decay=10, knn=3, thresh=1e-4) |
| 34 | + |
| 35 | + # Check that the graph is disconnected |
| 36 | + assert not G.is_connected, "Expected graph to be disconnected" |
| 37 | + assert G.n_connected_components >= 2, "Expected at least 2 connected components" |
| 38 | + |
| 39 | + # Check component labels |
| 40 | + labels = G.component_labels |
| 41 | + assert labels.shape[0] == disconnected_data.shape[0], "Component labels should match data size" |
| 42 | + assert len(np.unique(labels)) >= 2, "Should have at least 2 unique component labels" |
| 43 | + |
| 44 | + |
| 45 | +def test_component_labels_consistency(): |
| 46 | + """Test that component labels are consistent across calls""" |
| 47 | + # Create disconnected data |
| 48 | + cluster1 = np.random.randn(30, 5) |
| 49 | + cluster2 = np.random.randn(30, 5) + 50 |
| 50 | + disconnected_data = np.vstack([cluster1, cluster2]) |
| 51 | + |
| 52 | + G = build_graph(disconnected_data, n_pca=None, decay=10, knn=2) |
| 53 | + |
| 54 | + # Get labels multiple times - should be cached and identical |
| 55 | + labels1 = G.component_labels |
| 56 | + labels2 = G.component_labels |
| 57 | + n_comp1 = G.n_connected_components |
| 58 | + n_comp2 = G.n_connected_components |
| 59 | + |
| 60 | + assert np.array_equal(labels1, labels2), "Component labels should be cached" |
| 61 | + assert n_comp1 == n_comp2, "n_connected_components should be cached" |
| 62 | + |
| 63 | + |
| 64 | +def test_precomputed_graph_connectivity(): |
| 65 | + """Test connectivity with precomputed distance matrix""" |
| 66 | + from scipy.spatial.distance import pdist, squareform |
| 67 | + |
| 68 | + # Create small disconnected dataset |
| 69 | + cluster1 = np.array([[0, 0], [0, 1], [1, 0]]) |
| 70 | + cluster2 = np.array([[100, 100], [100, 101], [101, 100]]) |
| 71 | + disconnected_data = np.vstack([cluster1, cluster2]) |
| 72 | + |
| 73 | + # Compute distance matrix |
| 74 | + dist_matrix = squareform(pdist(disconnected_data)) |
| 75 | + |
| 76 | + # For precomputed graphs, n_pca must be None |
| 77 | + G = build_graph(dist_matrix, n_pca=None, precomputed="distance", decay=10, knn=2) |
| 78 | + |
| 79 | + # Should be disconnected |
| 80 | + assert not G.is_connected, "Precomputed disconnected graph should be disconnected" |
| 81 | + assert G.n_connected_components == 2, "Should have exactly 2 components" |
| 82 | + |
| 83 | + |
| 84 | +def test_landmark_graph_connectivity(): |
| 85 | + """Test connectivity with landmark graphs""" |
| 86 | + G = build_graph(data, n_pca=20, decay=10, knn=5, n_landmark=100) |
| 87 | + |
| 88 | + # Landmark graphs should still support connectivity checks |
| 89 | + assert hasattr(G, 'is_connected'), "Landmark graph should have is_connected property" |
| 90 | + assert hasattr(G, 'n_connected_components'), "Landmark graph should have n_connected_components" |
| 91 | + assert hasattr(G, 'component_labels'), "Landmark graph should have component_labels" |
| 92 | + |
| 93 | + # Check that properties work |
| 94 | + is_conn = G.is_connected |
| 95 | + n_comp = G.n_connected_components |
| 96 | + labels = G.component_labels |
| 97 | + |
| 98 | + assert isinstance(is_conn, (bool, np.bool_)), "is_connected should return boolean" |
| 99 | + assert isinstance(n_comp, (int, np.integer)), "n_connected_components should return int" |
| 100 | + assert labels.shape[0] == data.shape[0], "component_labels should match data size" |
| 101 | + |
| 102 | + |
| 103 | +def test_knn_graph_connectivity(): |
| 104 | + """Test connectivity with different knn values""" |
| 105 | + # With high knn, should be connected |
| 106 | + G_high_knn = build_graph(data, n_pca=20, knn=10, decay=10) |
| 107 | + assert G_high_knn.is_connected, "Graph with high knn should be connected" |
| 108 | + |
| 109 | + # Create data that might disconnect with very low knn |
| 110 | + sparse_data = np.random.randn(100, 10) * 2 |
| 111 | + G_low_knn = build_graph(sparse_data, n_pca=None, knn=2, decay=10) |
| 112 | + |
| 113 | + # Just check that the properties work (connectivity depends on data) |
| 114 | + assert isinstance(G_low_knn.is_connected, (bool, np.bool_)) |
| 115 | + assert G_low_knn.n_connected_components >= 1 |
| 116 | + |
| 117 | + |
| 118 | +def test_component_labels_range(): |
| 119 | + """Test that component labels are in the correct range""" |
| 120 | + cluster1 = np.random.randn(20, 5) |
| 121 | + cluster2 = np.random.randn(20, 5) + 100 |
| 122 | + cluster3 = np.random.randn(20, 5) - 100 |
| 123 | + disconnected_data = np.vstack([cluster1, cluster2, cluster3]) |
| 124 | + |
| 125 | + G = build_graph(disconnected_data, n_pca=None, decay=10, knn=2) |
| 126 | + |
| 127 | + labels = G.component_labels |
| 128 | + n_comp = G.n_connected_components |
| 129 | + |
| 130 | + # Labels should be in range [0, n_components) |
| 131 | + assert labels.min() >= 0, "Minimum label should be >= 0" |
| 132 | + assert labels.max() < n_comp, "Maximum label should be < n_connected_components" |
| 133 | + assert len(np.unique(labels)) == n_comp, "Number of unique labels should equal n_components" |
| 134 | + |
| 135 | + |
| 136 | +def test_connectivity_caching(): |
| 137 | + """Test that connectivity properties are properly cached""" |
| 138 | + G = build_graph(data, n_pca=20, decay=10, knn=5) |
| 139 | + |
| 140 | + # Access properties to trigger caching |
| 141 | + _ = G.is_connected |
| 142 | + |
| 143 | + # Check that internal cache exists |
| 144 | + assert hasattr(G, '_n_connected_components'), "Should cache n_connected_components" |
| 145 | + assert hasattr(G, '_component_labels'), "Should cache component_labels" |
| 146 | + |
| 147 | + # Verify cached values are used |
| 148 | + n_comp_cached = G._n_connected_components |
| 149 | + n_comp_property = G.n_connected_components |
| 150 | + |
| 151 | + assert n_comp_cached == n_comp_property, "Cached value should match property value" |
0 commit comments