|
| 1 | +#!/usr/bin/env python |
| 2 | +""" |
| 3 | +Comprehensive test suite for phate.tree module |
| 4 | +Tests DLA tree generation for synthetic test data |
| 5 | +""" |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import sys |
| 9 | +import os |
| 10 | + |
| 11 | +# Add parent directory to path |
| 12 | +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) |
| 13 | + |
| 14 | +import phate |
| 15 | +import phate.tree as tree |
| 16 | +import pytest |
| 17 | + |
| 18 | + |
| 19 | +##################################################### |
| 20 | +# Tests for gen_dla() |
| 21 | +##################################################### |
| 22 | + |
| 23 | + |
| 24 | +def test_gen_dla_basic(): |
| 25 | + """Test basic gen_dla functionality with default parameters""" |
| 26 | + print("\n" + "=" * 70) |
| 27 | + print("TEST 1: gen_dla() basic functionality") |
| 28 | + print("=" * 70) |
| 29 | + |
| 30 | + # Generate with default parameters |
| 31 | + M, C = tree.gen_dla() |
| 32 | + |
| 33 | + # Check return types |
| 34 | + assert isinstance(M, np.ndarray), f"Expected M to be ndarray, got {type(M)}" |
| 35 | + assert isinstance(C, np.ndarray), f"Expected C to be ndarray, got {type(C)}" |
| 36 | + print(f"✓ Returns numpy arrays") |
| 37 | + |
| 38 | + # Check shapes with defaults: n_branch=20, branch_length=100 |
| 39 | + expected_n_points = 20 * 100 # n_branch * branch_length |
| 40 | + assert M.shape[0] == expected_n_points, \ |
| 41 | + f"Expected {expected_n_points} points, got {M.shape[0]}" |
| 42 | + print(f"✓ Correct number of points: {M.shape[0]}") |
| 43 | + |
| 44 | + # Default n_dim=100 |
| 45 | + assert M.shape[1] == 100, f"Expected 100 dimensions, got {M.shape[1]}" |
| 46 | + print(f"✓ Correct dimensionality: {M.shape[1]}") |
| 47 | + |
| 48 | + # Cluster labels should match data |
| 49 | + assert C.shape[0] == M.shape[0], \ |
| 50 | + f"Mismatched shapes: M has {M.shape[0]} points, C has {C.shape[0]} labels" |
| 51 | + print(f"✓ Cluster labels match data points") |
| 52 | + |
| 53 | + # Cluster labels should be integers from 0 to n_branch-1 |
| 54 | + assert np.issubdtype(C.dtype, np.integer), f"Expected integer labels, got {C.dtype}" |
| 55 | + assert np.min(C) == 0, f"Expected min label 0, got {np.min(C)}" |
| 56 | + assert np.max(C) == 19, f"Expected max label 19 (n_branch-1), got {np.max(C)}" |
| 57 | + print(f"✓ Cluster labels in correct range [0, 19]") |
| 58 | + |
| 59 | + # All values should be finite |
| 60 | + assert np.all(np.isfinite(M)), "M contains non-finite values" |
| 61 | + print(f"✓ All data values are finite") |
| 62 | + |
| 63 | + print("✓ Test 1 PASSED\n") |
| 64 | + |
| 65 | + |
| 66 | +def test_gen_dla_custom_parameters(): |
| 67 | + """Test gen_dla with custom parameters""" |
| 68 | + print("=" * 70) |
| 69 | + print("TEST 2: gen_dla() with custom parameters") |
| 70 | + print("=" * 70) |
| 71 | + |
| 72 | + # Test with custom n_dim |
| 73 | + M, C = tree.gen_dla(n_dim=50, n_branch=3, branch_length=20, seed=42) |
| 74 | + |
| 75 | + assert M.shape == (60, 50), f"Expected shape (60, 50), got {M.shape}" |
| 76 | + assert C.shape == (60,), f"Expected C shape (60,), got {C.shape}" |
| 77 | + assert len(np.unique(C)) == 3, f"Expected 3 branches, got {len(np.unique(C))}" |
| 78 | + print(f"✓ n_dim=50, n_branch=3, branch_length=20 works correctly") |
| 79 | + |
| 80 | + # Test with different branch_length |
| 81 | + M, C = tree.gen_dla(n_dim=30, n_branch=5, branch_length=50, seed=42) |
| 82 | + |
| 83 | + assert M.shape == (250, 30), f"Expected shape (250, 30), got {M.shape}" |
| 84 | + assert C.shape == (250,), f"Expected C shape (250,), got {C.shape}" |
| 85 | + assert len(np.unique(C)) == 5, f"Expected 5 branches, got {len(np.unique(C))}" |
| 86 | + print(f"✓ n_dim=30, n_branch=5, branch_length=50 works correctly") |
| 87 | + |
| 88 | + print("✓ Test 2 PASSED\n") |
| 89 | + |
| 90 | + |
| 91 | +def test_gen_dla_single_branch(): |
| 92 | + """Test gen_dla with single branch (n_branch=1)""" |
| 93 | + print("=" * 70) |
| 94 | + print("TEST 3: gen_dla() with single branch") |
| 95 | + print("=" * 70) |
| 96 | + |
| 97 | + M, C = tree.gen_dla(n_dim=10, n_branch=1, branch_length=50, seed=42) |
| 98 | + |
| 99 | + assert M.shape == (50, 10), f"Expected shape (50, 10), got {M.shape}" |
| 100 | + assert C.shape == (50,), f"Expected C shape (50,), got {C.shape}" |
| 101 | + |
| 102 | + # With single branch, all labels should be 0 |
| 103 | + assert np.all(C == 0), f"Expected all labels to be 0, got {np.unique(C)}" |
| 104 | + print(f"✓ Single branch: all labels are 0") |
| 105 | + |
| 106 | + assert np.all(np.isfinite(M)), "M contains non-finite values" |
| 107 | + print(f"✓ Data is finite") |
| 108 | + |
| 109 | + print("✓ Test 3 PASSED\n") |
| 110 | + |
| 111 | + |
| 112 | +def test_gen_dla_reproducibility(): |
| 113 | + """Test gen_dla reproducibility with same seed""" |
| 114 | + print("=" * 70) |
| 115 | + print("TEST 4: gen_dla() reproducibility with seed") |
| 116 | + print("=" * 70) |
| 117 | + |
| 118 | + # Generate twice with same seed |
| 119 | + M1, C1 = tree.gen_dla(n_dim=20, n_branch=3, branch_length=30, seed=42) |
| 120 | + M2, C2 = tree.gen_dla(n_dim=20, n_branch=3, branch_length=30, seed=42) |
| 121 | + |
| 122 | + # Should be identical |
| 123 | + assert np.array_equal(M1, M2), "Same seed should produce identical data" |
| 124 | + assert np.array_equal(C1, C2), "Same seed should produce identical labels" |
| 125 | + print(f"✓ Same seed produces identical results") |
| 126 | + |
| 127 | + # Different seed should produce different results |
| 128 | + M3, C3 = tree.gen_dla(n_dim=20, n_branch=3, branch_length=30, seed=999) |
| 129 | + |
| 130 | + # Should be different (very unlikely to be identical by chance) |
| 131 | + assert not np.array_equal(M1, M3), "Different seeds should produce different data" |
| 132 | + print(f"✓ Different seed produces different results") |
| 133 | + |
| 134 | + # But should have same shape and label structure |
| 135 | + assert M1.shape == M3.shape, "Should have same shape" |
| 136 | + assert C1.shape == C3.shape, "Should have same label shape" |
| 137 | + print(f"✓ Different seeds maintain consistent structure") |
| 138 | + |
| 139 | + print("✓ Test 4 PASSED\n") |
| 140 | + |
| 141 | + |
| 142 | +def test_gen_dla_rand_multiplier(): |
| 143 | + """Test gen_dla with different rand_multiplier values""" |
| 144 | + print("=" * 70) |
| 145 | + print("TEST 5: gen_dla() with different rand_multiplier") |
| 146 | + print("=" * 70) |
| 147 | + |
| 148 | + # Generate with different rand_multiplier values |
| 149 | + M1, C1 = tree.gen_dla(n_dim=10, n_branch=2, branch_length=20, |
| 150 | + rand_multiplier=1, seed=42) |
| 151 | + M2, C2 = tree.gen_dla(n_dim=10, n_branch=2, branch_length=20, |
| 152 | + rand_multiplier=5, seed=42) |
| 153 | + |
| 154 | + # Higher rand_multiplier should generally give larger spread |
| 155 | + spread1 = np.std(M1) |
| 156 | + spread2 = np.std(M2) |
| 157 | + |
| 158 | + print(f"rand_multiplier=1: std={spread1:.4f}") |
| 159 | + print(f"rand_multiplier=5: std={spread2:.4f}") |
| 160 | + |
| 161 | + # Larger multiplier should give larger spread (in most cases) |
| 162 | + assert spread2 > spread1, \ |
| 163 | + f"Expected larger spread with higher rand_multiplier, got {spread1:.4f} vs {spread2:.4f}" |
| 164 | + print(f"✓ Higher rand_multiplier gives larger spread") |
| 165 | + |
| 166 | + print("✓ Test 5 PASSED\n") |
| 167 | + |
| 168 | + |
| 169 | +def test_gen_dla_sigma(): |
| 170 | + """Test gen_dla with different sigma (noise) values""" |
| 171 | + print("=" * 70) |
| 172 | + print("TEST 6: gen_dla() with different sigma (noise)") |
| 173 | + print("=" * 70) |
| 174 | + |
| 175 | + # Generate with no noise |
| 176 | + M1, C1 = tree.gen_dla(n_dim=10, n_branch=2, branch_length=20, |
| 177 | + sigma=0, seed=42) |
| 178 | + |
| 179 | + # Generate with noise |
| 180 | + M2, C2 = tree.gen_dla(n_dim=10, n_branch=2, branch_length=20, |
| 181 | + sigma=10, seed=42) |
| 182 | + |
| 183 | + # Should have same shape |
| 184 | + assert M1.shape == M2.shape |
| 185 | + print(f"✓ Same shape with different sigma") |
| 186 | + |
| 187 | + # Should be different due to noise |
| 188 | + assert not np.array_equal(M1, M2), "Different sigma should give different results" |
| 189 | + print(f"✓ Different sigma produces different results") |
| 190 | + |
| 191 | + # Both should be finite |
| 192 | + assert np.all(np.isfinite(M1)), "sigma=0 data should be finite" |
| 193 | + assert np.all(np.isfinite(M2)), "sigma=10 data should be finite" |
| 194 | + print(f"✓ All data finite for both sigma values") |
| 195 | + |
| 196 | + print("✓ Test 6 PASSED\n") |
| 197 | + |
| 198 | + |
| 199 | +def test_gen_dla_cluster_labels(): |
| 200 | + """Test gen_dla cluster label structure""" |
| 201 | + print("=" * 70) |
| 202 | + print("TEST 7: gen_dla() cluster label structure") |
| 203 | + print("=" * 70) |
| 204 | + |
| 205 | + n_branch = 4 |
| 206 | + branch_length = 25 |
| 207 | + M, C = tree.gen_dla(n_dim=10, n_branch=n_branch, |
| 208 | + branch_length=branch_length, seed=42) |
| 209 | + |
| 210 | + # Each branch should have exactly branch_length points |
| 211 | + for i in range(n_branch): |
| 212 | + n_points_in_branch = np.sum(C == i) |
| 213 | + assert n_points_in_branch == branch_length, \ |
| 214 | + f"Branch {i}: expected {branch_length} points, got {n_points_in_branch}" |
| 215 | + print(f"✓ Branch {i}: {n_points_in_branch} points") |
| 216 | + |
| 217 | + # Labels should be sequential |
| 218 | + # First branch_length points have label 0, next branch_length have label 1, etc. |
| 219 | + for i in range(n_branch): |
| 220 | + start_idx = i * branch_length |
| 221 | + end_idx = (i + 1) * branch_length |
| 222 | + branch_labels = C[start_idx:end_idx] |
| 223 | + assert np.all(branch_labels == i), \ |
| 224 | + f"Branch {i}: labels not all {i} in positions [{start_idx}, {end_idx})" |
| 225 | + |
| 226 | + print(f"✓ Labels are correctly sequential") |
| 227 | + |
| 228 | + print("✓ Test 7 PASSED\n") |
| 229 | + |
| 230 | + |
| 231 | +def test_gen_dla_various_dimensions(): |
| 232 | + """Test gen_dla with various dimensionalities""" |
| 233 | + print("=" * 70) |
| 234 | + print("TEST 8: gen_dla() with various dimensionalities") |
| 235 | + print("=" * 70) |
| 236 | + |
| 237 | + for n_dim in [2, 5, 10, 50, 100, 200]: |
| 238 | + M, C = tree.gen_dla(n_dim=n_dim, n_branch=2, branch_length=20, seed=42) |
| 239 | + |
| 240 | + assert M.shape[1] == n_dim, f"Expected {n_dim} dimensions, got {M.shape[1]}" |
| 241 | + assert M.shape[0] == 40, f"Expected 40 points, got {M.shape[0]}" |
| 242 | + assert np.all(np.isfinite(M)), f"Non-finite values with n_dim={n_dim}" |
| 243 | + print(f"✓ n_dim={n_dim}: shape={M.shape}, all finite") |
| 244 | + |
| 245 | + print("✓ Test 8 PASSED\n") |
| 246 | + |
| 247 | + |
| 248 | +def test_gen_dla_large_dataset(): |
| 249 | + """Test gen_dla with larger dataset""" |
| 250 | + print("=" * 70) |
| 251 | + print("TEST 9: gen_dla() with large dataset") |
| 252 | + print("=" * 70) |
| 253 | + |
| 254 | + # Generate larger dataset |
| 255 | + M, C = tree.gen_dla(n_dim=100, n_branch=50, branch_length=200, seed=42) |
| 256 | + |
| 257 | + expected_n_points = 50 * 200 # 10,000 points |
| 258 | + assert M.shape == (expected_n_points, 100), \ |
| 259 | + f"Expected shape ({expected_n_points}, 100), got {M.shape}" |
| 260 | + print(f"✓ Large dataset shape: {M.shape}") |
| 261 | + |
| 262 | + assert C.shape == (expected_n_points,), \ |
| 263 | + f"Expected {expected_n_points} labels, got {C.shape[0]}" |
| 264 | + print(f"✓ Correct number of labels: {C.shape[0]}") |
| 265 | + |
| 266 | + assert len(np.unique(C)) == 50, f"Expected 50 unique labels, got {len(np.unique(C))}" |
| 267 | + print(f"✓ Correct number of branches: {len(np.unique(C))}") |
| 268 | + |
| 269 | + assert np.all(np.isfinite(M)), "Large dataset contains non-finite values" |
| 270 | + print(f"✓ All values finite") |
| 271 | + |
| 272 | + print("✓ Test 9 PASSED\n") |
| 273 | + |
| 274 | + |
| 275 | +def test_gen_dla_data_structure(): |
| 276 | + """Test that gen_dla produces tree-like structure""" |
| 277 | + print("=" * 70) |
| 278 | + print("TEST 10: gen_dla() produces branching structure") |
| 279 | + print("=" * 70) |
| 280 | + |
| 281 | + n_branch = 3 |
| 282 | + branch_length = 50 |
| 283 | + M, C = tree.gen_dla(n_dim=20, n_branch=n_branch, |
| 284 | + branch_length=branch_length, seed=42) |
| 285 | + |
| 286 | + # Each branch should form a path (cumulative sum of random steps) |
| 287 | + # Points within a branch should be relatively close to each other |
| 288 | + # compared to points in different branches (on average) |
| 289 | + |
| 290 | + for i in range(n_branch): |
| 291 | + # Get points in this branch |
| 292 | + branch_points = M[C == i] |
| 293 | + |
| 294 | + # Check that branch forms a continuous path |
| 295 | + # (consecutive points should be close) |
| 296 | + consecutive_dists = [] |
| 297 | + for j in range(len(branch_points) - 1): |
| 298 | + dist = np.linalg.norm(branch_points[j+1] - branch_points[j]) |
| 299 | + consecutive_dists.append(dist) |
| 300 | + |
| 301 | + mean_consecutive_dist = np.mean(consecutive_dists) |
| 302 | + print(f"Branch {i}: mean consecutive distance = {mean_consecutive_dist:.4f}") |
| 303 | + |
| 304 | + # Consecutive distances should be relatively small (local structure) |
| 305 | + # This is a sanity check that the tree structure makes sense |
| 306 | + assert mean_consecutive_dist < 100, \ |
| 307 | + f"Branch {i}: consecutive points too far apart ({mean_consecutive_dist:.4f})" |
| 308 | + |
| 309 | + print(f"✓ All branches show local continuity") |
| 310 | + |
| 311 | + print("✓ Test 10 PASSED\n") |
| 312 | + |
| 313 | + |
| 314 | +def test_gen_dla_minimal_parameters(): |
| 315 | + """Test gen_dla with minimal/extreme parameters""" |
| 316 | + print("=" * 70) |
| 317 | + print("TEST 11: gen_dla() with minimal parameters") |
| 318 | + print("=" * 70) |
| 319 | + |
| 320 | + # Very small dataset |
| 321 | + M, C = tree.gen_dla(n_dim=2, n_branch=1, branch_length=5, seed=42) |
| 322 | + |
| 323 | + assert M.shape == (5, 2), f"Expected shape (5, 2), got {M.shape}" |
| 324 | + assert C.shape == (5,), f"Expected C shape (5,), got {C.shape}" |
| 325 | + assert np.all(np.isfinite(M)), "Minimal dataset contains non-finite values" |
| 326 | + print(f"✓ Minimal dataset (n_dim=2, n_branch=1, branch_length=5) works") |
| 327 | + |
| 328 | + # Very small dimensionality with multiple branches |
| 329 | + M, C = tree.gen_dla(n_dim=1, n_branch=3, branch_length=10, seed=42) |
| 330 | + |
| 331 | + assert M.shape == (30, 1), f"Expected shape (30, 1), got {M.shape}" |
| 332 | + assert np.all(np.isfinite(M)), "1D dataset contains non-finite values" |
| 333 | + print(f"✓ 1D dataset (n_dim=1) works") |
| 334 | + |
| 335 | + print("✓ Test 11 PASSED\n") |
| 336 | + |
| 337 | + |
| 338 | +##################################################### |
| 339 | +# Integration test |
| 340 | +##################################################### |
| 341 | + |
| 342 | + |
| 343 | +def test_gen_dla_with_phate(): |
| 344 | + """Test that gen_dla output works with PHATE""" |
| 345 | + print("=" * 70) |
| 346 | + print("TEST 12: gen_dla() output works with PHATE") |
| 347 | + print("=" * 70) |
| 348 | + |
| 349 | + # Generate tree data |
| 350 | + M, C = tree.gen_dla(n_dim=50, n_branch=4, branch_length=100, seed=42) |
| 351 | + |
| 352 | + # Should work with PHATE |
| 353 | + phate_op = phate.PHATE(knn=5, t=10, verbose=False, random_state=42) |
| 354 | + Y = phate_op.fit_transform(M) |
| 355 | + |
| 356 | + # Check PHATE output |
| 357 | + assert Y.shape == (400, 2), f"Expected PHATE output (400, 2), got {Y.shape}" |
| 358 | + assert np.all(np.isfinite(Y)), "PHATE embedding contains non-finite values" |
| 359 | + print(f"✓ gen_dla() output works with PHATE") |
| 360 | + print(f" Input shape: {M.shape}, Output shape: {Y.shape}") |
| 361 | + |
| 362 | + print("✓ Test 12 PASSED\n") |
| 363 | + |
| 364 | + |
| 365 | +##################################################### |
| 366 | +# Run all tests |
| 367 | +##################################################### |
| 368 | + |
| 369 | + |
| 370 | +if __name__ == "__main__": |
| 371 | + pytest.main([__file__, "-v"]) |
0 commit comments