Skip to content

Commit dd82cbc

Browse files
author
Joao Felipe Rocha
committed
Added 3d unittests
1 parent eb9cfcf commit dd82cbc

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed

Python/test/test_mds.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,130 @@ def test_mds_high_dimensions():
622622

623623
print("✓ Test 16 PASSED\n")
624624

625+
#####################################################
626+
# SMACOF and SGD 3D Tests
627+
#####################################################
628+
629+
def test_smacof_basic_3d():
630+
"""Test basic SMACOF functionality"""
631+
print("=" * 70)
632+
print("TEST 11: SMACOF basic functionality")
633+
print("=" * 70)
634+
635+
np.random.seed(42)
636+
n_samples = 80
637+
X = np.random.randn(n_samples, 6)
638+
D = squareform(pdist(X, "euclidean"))
639+
640+
# Run SMACOF
641+
Y = smacof(D, n_components=3, random_state=42, metric=True, max_iter=300)
642+
643+
# Check output shape
644+
assert Y.shape == (n_samples, 3), f"Expected shape (80, 3), got {Y.shape}"
645+
print(f"✓ Output shape correct: {Y.shape}")
646+
647+
# Check for NaN or Inf
648+
assert not np.any(np.isnan(Y)), "Output contains NaN values"
649+
assert not np.any(np.isinf(Y)), "Output contains Inf values"
650+
print("✓ No NaN or Inf values")
651+
652+
# Compute stress
653+
stress = compute_stress(D, Y)
654+
print(f"Stress: {stress:.6f}")
655+
print("✓ SMACOF converged")
656+
657+
print("✓ Test 17 PASSED\n")
658+
659+
def test_sgd_mds_basic_3d():
660+
"""Test basic functionality of SGD-MDS"""
661+
print("\n" + "=" * 70)
662+
print("TEST 1: Basic SGD-MDS functionality")
663+
print("=" * 70)
664+
665+
# Create simple test data
666+
np.random.seed(42)
667+
n_samples = 100
668+
X = np.random.randn(n_samples, 10)
669+
D = squareform(pdist(X, "euclidean"))
670+
671+
# Run SGD-MDS
672+
Y = sgd_mds(D, n_components=3, n_iter=200, random_state=42)
673+
674+
# Check output shape
675+
assert Y.shape == (n_samples, 3), f"Expected shape (100, 3), got {Y.shape}"
676+
print(f"✓ Output shape correct: {Y.shape}")
677+
678+
# Check for NaN or Inf
679+
assert not np.any(np.isnan(Y)), "Output contains NaN values"
680+
assert not np.any(np.isinf(Y)), "Output contains Inf values"
681+
print("✓ No NaN or Inf values")
682+
683+
# Check that points are spread out (not collapsed)
684+
variance = np.var(Y, axis=0)
685+
assert np.all(variance > 1e-6), f"Embedding collapsed: variance={variance}"
686+
print(f"✓ Embedding has variance: {variance}")
687+
688+
print("✓ Test 18 PASSED\n")
689+
690+
def test_phate_with_sgd_mds_3d():
691+
"""Test PHATE integration with SGD-MDS solver"""
692+
print("=" * 70)
693+
print("TEST 5: PHATE integration with SGD-MDS")
694+
print("=" * 70)
695+
696+
# Generate tree data
697+
np.random.seed(42)
698+
data, labels = phate.tree.gen_dla(n_dim=50, n_branch=5, branch_length=100)
699+
700+
print(f"Generated tree data: {data.shape}")
701+
702+
# Test with SGD solver
703+
print("\nRunning PHATE with mds_solver='sgd'...")
704+
phate_sgd = phate.PHATE(
705+
n_components=3,
706+
knn=5,
707+
t=20,
708+
mds_solver="sgd",
709+
random_state=42,
710+
verbose=0,
711+
)
712+
embedding_sgd = phate_sgd.fit_transform(data)
713+
714+
assert embedding_sgd.shape == (
715+
data.shape[0],
716+
3,
717+
), f"Expected shape {(data.shape[0], 3)}, got {embedding_sgd.shape}"
718+
assert not np.any(np.isnan(embedding_sgd)), "SGD embedding contains NaN"
719+
print(f"✓ SGD embedding shape: {embedding_sgd.shape}")
720+
721+
# Test with SMACOF solver for comparison
722+
print("\nRunning PHATE with mds_solver='smacof' for comparison...")
723+
phate_smacof = phate.PHATE(
724+
n_components=3,
725+
knn=5,
726+
t=20,
727+
mds_solver="smacof",
728+
random_state=42,
729+
verbose=0,
730+
)
731+
embedding_smacof = phate_smacof.fit_transform(data)
732+
733+
assert embedding_smacof.shape == (
734+
data.shape[0],
735+
3,
736+
), f"Expected shape {(data.shape[0], 3)}, got {embedding_smacof.shape}"
737+
print(f"✓ SMACOF embedding shape: {embedding_smacof.shape}")
738+
739+
# Compare embeddings
740+
_, embedding_sgd_aligned, disparity = procrustes(embedding_smacof, embedding_sgd)
741+
print(f"\nProcrustes disparity between SGD and SMACOF: {disparity:.6f}")
742+
743+
# PHATE adds additional processing (diffusion, potential), so embeddings
744+
# may differ more than raw MDS. Accept < 0.85 as reasonable.
745+
assert disparity < 0.85, f"Embeddings too different: {disparity:.6f}"
746+
print(f"✓ PHATE embeddings reasonably similar (disparity={disparity:.6f})")
747+
748+
print("✓ Test 19 PASSED\n")
625749

626750
def run_all_tests():
627751
"""Run all tests"""
@@ -650,6 +774,10 @@ def run_all_tests():
650774
test_mds_tiny_dataset,
651775
test_mds_zero_distances,
652776
test_mds_high_dimensions,
777+
# 3D SGD_MDS tests,
778+
test_smacof_basic_3d,
779+
test_sgd_mds_basic_3d,
780+
test_phate_with_sgd_mds_3d
653781
]
654782

655783
failed = []

0 commit comments

Comments
 (0)