Skip to content

Commit cd0c203

Browse files
committed
added tests for dla tree
1 parent 2752b2f commit cd0c203

File tree

1 file changed

+371
-0
lines changed

1 file changed

+371
-0
lines changed

Python/test/test_tree.py

Lines changed: 371 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,371 @@
1+
#!/usr/bin/env python
2+
"""
3+
Comprehensive test suite for phate.tree module
4+
Tests DLA tree generation for synthetic test data
5+
"""
6+
7+
import numpy as np
8+
import sys
9+
import os
10+
11+
# Add parent directory to path
12+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
13+
14+
import phate
15+
import phate.tree as tree
16+
import pytest
17+
18+
19+
#####################################################
20+
# Tests for gen_dla()
21+
#####################################################
22+
23+
24+
def test_gen_dla_basic():
25+
"""Test basic gen_dla functionality with default parameters"""
26+
print("\n" + "=" * 70)
27+
print("TEST 1: gen_dla() basic functionality")
28+
print("=" * 70)
29+
30+
# Generate with default parameters
31+
M, C = tree.gen_dla()
32+
33+
# Check return types
34+
assert isinstance(M, np.ndarray), f"Expected M to be ndarray, got {type(M)}"
35+
assert isinstance(C, np.ndarray), f"Expected C to be ndarray, got {type(C)}"
36+
print(f"✓ Returns numpy arrays")
37+
38+
# Check shapes with defaults: n_branch=20, branch_length=100
39+
expected_n_points = 20 * 100 # n_branch * branch_length
40+
assert M.shape[0] == expected_n_points, \
41+
f"Expected {expected_n_points} points, got {M.shape[0]}"
42+
print(f"✓ Correct number of points: {M.shape[0]}")
43+
44+
# Default n_dim=100
45+
assert M.shape[1] == 100, f"Expected 100 dimensions, got {M.shape[1]}"
46+
print(f"✓ Correct dimensionality: {M.shape[1]}")
47+
48+
# Cluster labels should match data
49+
assert C.shape[0] == M.shape[0], \
50+
f"Mismatched shapes: M has {M.shape[0]} points, C has {C.shape[0]} labels"
51+
print(f"✓ Cluster labels match data points")
52+
53+
# Cluster labels should be integers from 0 to n_branch-1
54+
assert np.issubdtype(C.dtype, np.integer), f"Expected integer labels, got {C.dtype}"
55+
assert np.min(C) == 0, f"Expected min label 0, got {np.min(C)}"
56+
assert np.max(C) == 19, f"Expected max label 19 (n_branch-1), got {np.max(C)}"
57+
print(f"✓ Cluster labels in correct range [0, 19]")
58+
59+
# All values should be finite
60+
assert np.all(np.isfinite(M)), "M contains non-finite values"
61+
print(f"✓ All data values are finite")
62+
63+
print("✓ Test 1 PASSED\n")
64+
65+
66+
def test_gen_dla_custom_parameters():
67+
"""Test gen_dla with custom parameters"""
68+
print("=" * 70)
69+
print("TEST 2: gen_dla() with custom parameters")
70+
print("=" * 70)
71+
72+
# Test with custom n_dim
73+
M, C = tree.gen_dla(n_dim=50, n_branch=3, branch_length=20, seed=42)
74+
75+
assert M.shape == (60, 50), f"Expected shape (60, 50), got {M.shape}"
76+
assert C.shape == (60,), f"Expected C shape (60,), got {C.shape}"
77+
assert len(np.unique(C)) == 3, f"Expected 3 branches, got {len(np.unique(C))}"
78+
print(f"✓ n_dim=50, n_branch=3, branch_length=20 works correctly")
79+
80+
# Test with different branch_length
81+
M, C = tree.gen_dla(n_dim=30, n_branch=5, branch_length=50, seed=42)
82+
83+
assert M.shape == (250, 30), f"Expected shape (250, 30), got {M.shape}"
84+
assert C.shape == (250,), f"Expected C shape (250,), got {C.shape}"
85+
assert len(np.unique(C)) == 5, f"Expected 5 branches, got {len(np.unique(C))}"
86+
print(f"✓ n_dim=30, n_branch=5, branch_length=50 works correctly")
87+
88+
print("✓ Test 2 PASSED\n")
89+
90+
91+
def test_gen_dla_single_branch():
92+
"""Test gen_dla with single branch (n_branch=1)"""
93+
print("=" * 70)
94+
print("TEST 3: gen_dla() with single branch")
95+
print("=" * 70)
96+
97+
M, C = tree.gen_dla(n_dim=10, n_branch=1, branch_length=50, seed=42)
98+
99+
assert M.shape == (50, 10), f"Expected shape (50, 10), got {M.shape}"
100+
assert C.shape == (50,), f"Expected C shape (50,), got {C.shape}"
101+
102+
# With single branch, all labels should be 0
103+
assert np.all(C == 0), f"Expected all labels to be 0, got {np.unique(C)}"
104+
print(f"✓ Single branch: all labels are 0")
105+
106+
assert np.all(np.isfinite(M)), "M contains non-finite values"
107+
print(f"✓ Data is finite")
108+
109+
print("✓ Test 3 PASSED\n")
110+
111+
112+
def test_gen_dla_reproducibility():
113+
"""Test gen_dla reproducibility with same seed"""
114+
print("=" * 70)
115+
print("TEST 4: gen_dla() reproducibility with seed")
116+
print("=" * 70)
117+
118+
# Generate twice with same seed
119+
M1, C1 = tree.gen_dla(n_dim=20, n_branch=3, branch_length=30, seed=42)
120+
M2, C2 = tree.gen_dla(n_dim=20, n_branch=3, branch_length=30, seed=42)
121+
122+
# Should be identical
123+
assert np.array_equal(M1, M2), "Same seed should produce identical data"
124+
assert np.array_equal(C1, C2), "Same seed should produce identical labels"
125+
print(f"✓ Same seed produces identical results")
126+
127+
# Different seed should produce different results
128+
M3, C3 = tree.gen_dla(n_dim=20, n_branch=3, branch_length=30, seed=999)
129+
130+
# Should be different (very unlikely to be identical by chance)
131+
assert not np.array_equal(M1, M3), "Different seeds should produce different data"
132+
print(f"✓ Different seed produces different results")
133+
134+
# But should have same shape and label structure
135+
assert M1.shape == M3.shape, "Should have same shape"
136+
assert C1.shape == C3.shape, "Should have same label shape"
137+
print(f"✓ Different seeds maintain consistent structure")
138+
139+
print("✓ Test 4 PASSED\n")
140+
141+
142+
def test_gen_dla_rand_multiplier():
143+
"""Test gen_dla with different rand_multiplier values"""
144+
print("=" * 70)
145+
print("TEST 5: gen_dla() with different rand_multiplier")
146+
print("=" * 70)
147+
148+
# Generate with different rand_multiplier values
149+
M1, C1 = tree.gen_dla(n_dim=10, n_branch=2, branch_length=20,
150+
rand_multiplier=1, seed=42)
151+
M2, C2 = tree.gen_dla(n_dim=10, n_branch=2, branch_length=20,
152+
rand_multiplier=5, seed=42)
153+
154+
# Higher rand_multiplier should generally give larger spread
155+
spread1 = np.std(M1)
156+
spread2 = np.std(M2)
157+
158+
print(f"rand_multiplier=1: std={spread1:.4f}")
159+
print(f"rand_multiplier=5: std={spread2:.4f}")
160+
161+
# Larger multiplier should give larger spread (in most cases)
162+
assert spread2 > spread1, \
163+
f"Expected larger spread with higher rand_multiplier, got {spread1:.4f} vs {spread2:.4f}"
164+
print(f"✓ Higher rand_multiplier gives larger spread")
165+
166+
print("✓ Test 5 PASSED\n")
167+
168+
169+
def test_gen_dla_sigma():
170+
"""Test gen_dla with different sigma (noise) values"""
171+
print("=" * 70)
172+
print("TEST 6: gen_dla() with different sigma (noise)")
173+
print("=" * 70)
174+
175+
# Generate with no noise
176+
M1, C1 = tree.gen_dla(n_dim=10, n_branch=2, branch_length=20,
177+
sigma=0, seed=42)
178+
179+
# Generate with noise
180+
M2, C2 = tree.gen_dla(n_dim=10, n_branch=2, branch_length=20,
181+
sigma=10, seed=42)
182+
183+
# Should have same shape
184+
assert M1.shape == M2.shape
185+
print(f"✓ Same shape with different sigma")
186+
187+
# Should be different due to noise
188+
assert not np.array_equal(M1, M2), "Different sigma should give different results"
189+
print(f"✓ Different sigma produces different results")
190+
191+
# Both should be finite
192+
assert np.all(np.isfinite(M1)), "sigma=0 data should be finite"
193+
assert np.all(np.isfinite(M2)), "sigma=10 data should be finite"
194+
print(f"✓ All data finite for both sigma values")
195+
196+
print("✓ Test 6 PASSED\n")
197+
198+
199+
def test_gen_dla_cluster_labels():
200+
"""Test gen_dla cluster label structure"""
201+
print("=" * 70)
202+
print("TEST 7: gen_dla() cluster label structure")
203+
print("=" * 70)
204+
205+
n_branch = 4
206+
branch_length = 25
207+
M, C = tree.gen_dla(n_dim=10, n_branch=n_branch,
208+
branch_length=branch_length, seed=42)
209+
210+
# Each branch should have exactly branch_length points
211+
for i in range(n_branch):
212+
n_points_in_branch = np.sum(C == i)
213+
assert n_points_in_branch == branch_length, \
214+
f"Branch {i}: expected {branch_length} points, got {n_points_in_branch}"
215+
print(f"✓ Branch {i}: {n_points_in_branch} points")
216+
217+
# Labels should be sequential
218+
# First branch_length points have label 0, next branch_length have label 1, etc.
219+
for i in range(n_branch):
220+
start_idx = i * branch_length
221+
end_idx = (i + 1) * branch_length
222+
branch_labels = C[start_idx:end_idx]
223+
assert np.all(branch_labels == i), \
224+
f"Branch {i}: labels not all {i} in positions [{start_idx}, {end_idx})"
225+
226+
print(f"✓ Labels are correctly sequential")
227+
228+
print("✓ Test 7 PASSED\n")
229+
230+
231+
def test_gen_dla_various_dimensions():
232+
"""Test gen_dla with various dimensionalities"""
233+
print("=" * 70)
234+
print("TEST 8: gen_dla() with various dimensionalities")
235+
print("=" * 70)
236+
237+
for n_dim in [2, 5, 10, 50, 100, 200]:
238+
M, C = tree.gen_dla(n_dim=n_dim, n_branch=2, branch_length=20, seed=42)
239+
240+
assert M.shape[1] == n_dim, f"Expected {n_dim} dimensions, got {M.shape[1]}"
241+
assert M.shape[0] == 40, f"Expected 40 points, got {M.shape[0]}"
242+
assert np.all(np.isfinite(M)), f"Non-finite values with n_dim={n_dim}"
243+
print(f"✓ n_dim={n_dim}: shape={M.shape}, all finite")
244+
245+
print("✓ Test 8 PASSED\n")
246+
247+
248+
def test_gen_dla_large_dataset():
249+
"""Test gen_dla with larger dataset"""
250+
print("=" * 70)
251+
print("TEST 9: gen_dla() with large dataset")
252+
print("=" * 70)
253+
254+
# Generate larger dataset
255+
M, C = tree.gen_dla(n_dim=100, n_branch=50, branch_length=200, seed=42)
256+
257+
expected_n_points = 50 * 200 # 10,000 points
258+
assert M.shape == (expected_n_points, 100), \
259+
f"Expected shape ({expected_n_points}, 100), got {M.shape}"
260+
print(f"✓ Large dataset shape: {M.shape}")
261+
262+
assert C.shape == (expected_n_points,), \
263+
f"Expected {expected_n_points} labels, got {C.shape[0]}"
264+
print(f"✓ Correct number of labels: {C.shape[0]}")
265+
266+
assert len(np.unique(C)) == 50, f"Expected 50 unique labels, got {len(np.unique(C))}"
267+
print(f"✓ Correct number of branches: {len(np.unique(C))}")
268+
269+
assert np.all(np.isfinite(M)), "Large dataset contains non-finite values"
270+
print(f"✓ All values finite")
271+
272+
print("✓ Test 9 PASSED\n")
273+
274+
275+
def test_gen_dla_data_structure():
276+
"""Test that gen_dla produces tree-like structure"""
277+
print("=" * 70)
278+
print("TEST 10: gen_dla() produces branching structure")
279+
print("=" * 70)
280+
281+
n_branch = 3
282+
branch_length = 50
283+
M, C = tree.gen_dla(n_dim=20, n_branch=n_branch,
284+
branch_length=branch_length, seed=42)
285+
286+
# Each branch should form a path (cumulative sum of random steps)
287+
# Points within a branch should be relatively close to each other
288+
# compared to points in different branches (on average)
289+
290+
for i in range(n_branch):
291+
# Get points in this branch
292+
branch_points = M[C == i]
293+
294+
# Check that branch forms a continuous path
295+
# (consecutive points should be close)
296+
consecutive_dists = []
297+
for j in range(len(branch_points) - 1):
298+
dist = np.linalg.norm(branch_points[j+1] - branch_points[j])
299+
consecutive_dists.append(dist)
300+
301+
mean_consecutive_dist = np.mean(consecutive_dists)
302+
print(f"Branch {i}: mean consecutive distance = {mean_consecutive_dist:.4f}")
303+
304+
# Consecutive distances should be relatively small (local structure)
305+
# This is a sanity check that the tree structure makes sense
306+
assert mean_consecutive_dist < 100, \
307+
f"Branch {i}: consecutive points too far apart ({mean_consecutive_dist:.4f})"
308+
309+
print(f"✓ All branches show local continuity")
310+
311+
print("✓ Test 10 PASSED\n")
312+
313+
314+
def test_gen_dla_minimal_parameters():
315+
"""Test gen_dla with minimal/extreme parameters"""
316+
print("=" * 70)
317+
print("TEST 11: gen_dla() with minimal parameters")
318+
print("=" * 70)
319+
320+
# Very small dataset
321+
M, C = tree.gen_dla(n_dim=2, n_branch=1, branch_length=5, seed=42)
322+
323+
assert M.shape == (5, 2), f"Expected shape (5, 2), got {M.shape}"
324+
assert C.shape == (5,), f"Expected C shape (5,), got {C.shape}"
325+
assert np.all(np.isfinite(M)), "Minimal dataset contains non-finite values"
326+
print(f"✓ Minimal dataset (n_dim=2, n_branch=1, branch_length=5) works")
327+
328+
# Very small dimensionality with multiple branches
329+
M, C = tree.gen_dla(n_dim=1, n_branch=3, branch_length=10, seed=42)
330+
331+
assert M.shape == (30, 1), f"Expected shape (30, 1), got {M.shape}"
332+
assert np.all(np.isfinite(M)), "1D dataset contains non-finite values"
333+
print(f"✓ 1D dataset (n_dim=1) works")
334+
335+
print("✓ Test 11 PASSED\n")
336+
337+
338+
#####################################################
339+
# Integration test
340+
#####################################################
341+
342+
343+
def test_gen_dla_with_phate():
344+
"""Test that gen_dla output works with PHATE"""
345+
print("=" * 70)
346+
print("TEST 12: gen_dla() output works with PHATE")
347+
print("=" * 70)
348+
349+
# Generate tree data
350+
M, C = tree.gen_dla(n_dim=50, n_branch=4, branch_length=100, seed=42)
351+
352+
# Should work with PHATE
353+
phate_op = phate.PHATE(knn=5, t=10, verbose=False, random_state=42)
354+
Y = phate_op.fit_transform(M)
355+
356+
# Check PHATE output
357+
assert Y.shape == (400, 2), f"Expected PHATE output (400, 2), got {Y.shape}"
358+
assert np.all(np.isfinite(Y)), "PHATE embedding contains non-finite values"
359+
print(f"✓ gen_dla() output works with PHATE")
360+
print(f" Input shape: {M.shape}, Output shape: {Y.shape}")
361+
362+
print("✓ Test 12 PASSED\n")
363+
364+
365+
#####################################################
366+
# Run all tests
367+
#####################################################
368+
369+
370+
if __name__ == "__main__":
371+
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)