@@ -622,6 +622,130 @@ def test_mds_high_dimensions():
622622
623623 print ("✓ Test 16 PASSED\n " )
624624
625+ #####################################################
626+ # SMACOF and SGD 3D Tests
627+ #####################################################
628+
629+ def test_smacof_basic_3d ():
630+ """Test basic SMACOF functionality"""
631+ print ("=" * 70 )
632+ print ("TEST 11: SMACOF basic functionality" )
633+ print ("=" * 70 )
634+
635+ np .random .seed (42 )
636+ n_samples = 80
637+ X = np .random .randn (n_samples , 6 )
638+ D = squareform (pdist (X , "euclidean" ))
639+
640+ # Run SMACOF
641+ Y = smacof (D , n_components = 3 , random_state = 42 , metric = True , max_iter = 300 )
642+
643+ # Check output shape
644+ assert Y .shape == (n_samples , 3 ), f"Expected shape (80, 3), got { Y .shape } "
645+ print (f"✓ Output shape correct: { Y .shape } " )
646+
647+ # Check for NaN or Inf
648+ assert not np .any (np .isnan (Y )), "Output contains NaN values"
649+ assert not np .any (np .isinf (Y )), "Output contains Inf values"
650+ print ("✓ No NaN or Inf values" )
651+
652+ # Compute stress
653+ stress = compute_stress (D , Y )
654+ print (f"Stress: { stress :.6f} " )
655+ print ("✓ SMACOF converged" )
656+
657+ print ("✓ Test 17 PASSED\n " )
658+
659+ def test_sgd_mds_basic_3d ():
660+ """Test basic functionality of SGD-MDS"""
661+ print ("\n " + "=" * 70 )
662+ print ("TEST 1: Basic SGD-MDS functionality" )
663+ print ("=" * 70 )
664+
665+ # Create simple test data
666+ np .random .seed (42 )
667+ n_samples = 100
668+ X = np .random .randn (n_samples , 10 )
669+ D = squareform (pdist (X , "euclidean" ))
670+
671+ # Run SGD-MDS
672+ Y = sgd_mds (D , n_components = 3 , n_iter = 200 , random_state = 42 )
673+
674+ # Check output shape
675+ assert Y .shape == (n_samples , 3 ), f"Expected shape (100, 3), got { Y .shape } "
676+ print (f"✓ Output shape correct: { Y .shape } " )
677+
678+ # Check for NaN or Inf
679+ assert not np .any (np .isnan (Y )), "Output contains NaN values"
680+ assert not np .any (np .isinf (Y )), "Output contains Inf values"
681+ print ("✓ No NaN or Inf values" )
682+
683+ # Check that points are spread out (not collapsed)
684+ variance = np .var (Y , axis = 0 )
685+ assert np .all (variance > 1e-6 ), f"Embedding collapsed: variance={ variance } "
686+ print (f"✓ Embedding has variance: { variance } " )
687+
688+ print ("✓ Test 18 PASSED\n " )
689+
690+ def test_phate_with_sgd_mds_3d ():
691+ """Test PHATE integration with SGD-MDS solver"""
692+ print ("=" * 70 )
693+ print ("TEST 5: PHATE integration with SGD-MDS" )
694+ print ("=" * 70 )
695+
696+ # Generate tree data
697+ np .random .seed (42 )
698+ data , labels = phate .tree .gen_dla (n_dim = 50 , n_branch = 5 , branch_length = 100 )
699+
700+ print (f"Generated tree data: { data .shape } " )
701+
702+ # Test with SGD solver
703+ print ("\n Running PHATE with mds_solver='sgd'..." )
704+ phate_sgd = phate .PHATE (
705+ n_components = 3 ,
706+ knn = 5 ,
707+ t = 20 ,
708+ mds_solver = "sgd" ,
709+ random_state = 42 ,
710+ verbose = 0 ,
711+ )
712+ embedding_sgd = phate_sgd .fit_transform (data )
713+
714+ assert embedding_sgd .shape == (
715+ data .shape [0 ],
716+ 3 ,
717+ ), f"Expected shape { (data .shape [0 ], 3 )} , got { embedding_sgd .shape } "
718+ assert not np .any (np .isnan (embedding_sgd )), "SGD embedding contains NaN"
719+ print (f"✓ SGD embedding shape: { embedding_sgd .shape } " )
720+
721+ # Test with SMACOF solver for comparison
722+ print ("\n Running PHATE with mds_solver='smacof' for comparison..." )
723+ phate_smacof = phate .PHATE (
724+ n_components = 3 ,
725+ knn = 5 ,
726+ t = 20 ,
727+ mds_solver = "smacof" ,
728+ random_state = 42 ,
729+ verbose = 0 ,
730+ )
731+ embedding_smacof = phate_smacof .fit_transform (data )
732+
733+ assert embedding_smacof .shape == (
734+ data .shape [0 ],
735+ 3 ,
736+ ), f"Expected shape { (data .shape [0 ], 3 )} , got { embedding_smacof .shape } "
737+ print (f"✓ SMACOF embedding shape: { embedding_smacof .shape } " )
738+
739+ # Compare embeddings
740+ _ , embedding_sgd_aligned , disparity = procrustes (embedding_smacof , embedding_sgd )
741+ print (f"\n Procrustes disparity between SGD and SMACOF: { disparity :.6f} " )
742+
743+ # PHATE adds additional processing (diffusion, potential), so embeddings
744+ # may differ more than raw MDS. Accept < 0.85 as reasonable.
745+ assert disparity < 0.85 , f"Embeddings too different: { disparity :.6f} "
746+ print (f"✓ PHATE embeddings reasonably similar (disparity={ disparity :.6f} )" )
747+
748+ print ("✓ Test 19 PASSED\n " )
625749
626750def run_all_tests ():
627751 """Run all tests"""
@@ -650,6 +774,10 @@ def run_all_tests():
650774 test_mds_tiny_dataset ,
651775 test_mds_zero_distances ,
652776 test_mds_high_dimensions ,
777+ # 3D SGD_MDS tests,
778+ test_smacof_basic_3d ,
779+ test_sgd_mds_basic_3d ,
780+ test_phate_with_sgd_mds_3d
653781 ]
654782
655783 failed = []
0 commit comments