|
8 | 8 | from load_tests import generate_swiss_roll |
9 | 9 | from load_tests import graphtools |
10 | 10 | from load_tests import np |
| 11 | +from load_tests import sp |
11 | 12 |
|
12 | 13 | import pygsp |
13 | 14 | import warnings |
@@ -487,6 +488,126 @@ def test_random_landmarking_with_precomputed_distance(): |
487 | 488 | assert G.landmark_op.shape == (n_landmark, n_landmark) |
488 | 489 |
|
489 | 490 |
|
| 491 | +def test_random_landmarking_with_sparse_precomputed_affinity(): |
| 492 | + """Random landmarking should work with sparse precomputed affinity matrices""" |
| 493 | + affinity = np.array( |
| 494 | + [ |
| 495 | + [1.0, 0.8, 0.1, 0.0, 0.0, 0.0], |
| 496 | + [0.8, 1.0, 0.2, 0.0, 0.0, 0.0], |
| 497 | + [0.1, 0.2, 1.0, 0.9, 0.4, 0.0], |
| 498 | + [0.0, 0.0, 0.9, 1.0, 0.5, 0.2], |
| 499 | + [0.0, 0.0, 0.4, 0.5, 1.0, 0.9], |
| 500 | + [0.0, 0.0, 0.0, 0.2, 0.9, 1.0], |
| 501 | + ] |
| 502 | + ) |
| 503 | + affinity = (affinity + affinity.T) / 2 # ensure symmetry |
| 504 | + affinity_sparse = sp.csr_matrix(affinity) |
| 505 | + n_landmark = 3 |
| 506 | + random_state = 42 |
| 507 | + |
| 508 | + G = graphtools.Graph( |
| 509 | + affinity_sparse, |
| 510 | + precomputed="affinity", |
| 511 | + n_landmark=n_landmark, |
| 512 | + random_landmarking=True, |
| 513 | + random_state=random_state, |
| 514 | + knn=3, |
| 515 | + thresh=0, |
| 516 | + ) |
| 517 | + |
| 518 | + # Trigger landmark construction |
| 519 | + _ = G.landmark_op |
| 520 | + |
| 521 | + rng = np.random.default_rng(random_state) |
| 522 | + landmark_indices = rng.choice(affinity.shape[0], n_landmark, replace=False) |
| 523 | + expected_clusters = np.asarray( |
| 524 | + G.kernel[:, landmark_indices].argmax(axis=1) |
| 525 | + ).reshape(-1) |
| 526 | + |
| 527 | + assert np.array_equal(G.clusters, expected_clusters) |
| 528 | + assert G.transitions.shape == (affinity.shape[0], n_landmark) |
| 529 | + assert G.landmark_op.shape == (n_landmark, n_landmark) |
| 530 | + |
| 531 | + |
| 532 | +def test_random_landmarking_with_sparse_precomputed_distance(): |
| 533 | + """Random landmarking should work with sparse precomputed distance matrices""" |
| 534 | + dist = np.array( |
| 535 | + [ |
| 536 | + [0, 1, 4, 4, 4, 4], |
| 537 | + [1, 0, 4, 4, 4, 4], |
| 538 | + [4, 4, 0, 1, 4, 4], |
| 539 | + [4, 4, 1, 0, 4, 4], |
| 540 | + [4, 4, 4, 4, 0, 1], |
| 541 | + [4, 4, 4, 4, 1, 0], |
| 542 | + ] |
| 543 | + ) |
| 544 | + dist_sparse = sp.csr_matrix(dist) |
| 545 | + |
| 546 | + n_landmark = 3 |
| 547 | + random_state = 42 |
| 548 | + |
| 549 | + G = graphtools.Graph( |
| 550 | + dist_sparse, |
| 551 | + precomputed="distance", |
| 552 | + n_landmark=n_landmark, |
| 553 | + random_landmarking=True, |
| 554 | + random_state=random_state, |
| 555 | + bandwidth=1, # deterministic affinity: exp(-dist) |
| 556 | + decay=1, |
| 557 | + thresh=0, |
| 558 | + knn=3, |
| 559 | + ) |
| 560 | + |
| 561 | + # Trigger landmark construction |
| 562 | + _ = G.landmark_op |
| 563 | + |
| 564 | + rng = np.random.default_rng(random_state) |
| 565 | + landmark_indices = rng.choice(dist.shape[0], n_landmark, replace=False) |
| 566 | + expected_clusters = np.asarray( |
| 567 | + G.kernel[:, landmark_indices].argmax(axis=1) |
| 568 | + ).reshape(-1) |
| 569 | + |
| 570 | + assert np.array_equal(G.clusters, expected_clusters) |
| 571 | + assert G.transitions.shape == (dist.shape[0], n_landmark) |
| 572 | + assert G.landmark_op.shape == (n_landmark, n_landmark) |
| 573 | + |
| 574 | + |
| 575 | +def test_random_landmarking_zero_affinity_warning(): |
| 576 | + """Test warning when samples have zero affinity to all landmarks""" |
| 577 | + # Create an affinity matrix where point 5 has no connection to other points |
| 578 | + affinity = np.array( |
| 579 | + [ |
| 580 | + [1.0, 0.8, 0.1, 0.0, 0.0, 0.0], |
| 581 | + [0.8, 1.0, 0.2, 0.0, 0.0, 0.0], |
| 582 | + [0.1, 0.2, 1.0, 0.9, 0.4, 0.0], |
| 583 | + [0.0, 0.0, 0.9, 1.0, 0.5, 0.0], |
| 584 | + [0.0, 0.0, 0.4, 0.5, 1.0, 0.0], |
| 585 | + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # isolated point |
| 586 | + ] |
| 587 | + ) |
| 588 | + affinity = (affinity + affinity.T) / 2 # ensure symmetry |
| 589 | + n_landmark = 2 |
| 590 | + random_state = 42 # This seed selects landmarks that don't include point 5 |
| 591 | + |
| 592 | + # Should warn about zero affinity |
| 593 | + with warnings.catch_warnings(record=True) as w: |
| 594 | + warnings.simplefilter("always") |
| 595 | + G = graphtools.Graph( |
| 596 | + affinity, |
| 597 | + precomputed="affinity", |
| 598 | + n_landmark=n_landmark, |
| 599 | + random_landmarking=True, |
| 600 | + random_state=random_state, |
| 601 | + knn=3, |
| 602 | + thresh=0, |
| 603 | + ) |
| 604 | + _ = G.landmark_op |
| 605 | + |
| 606 | + assert len(w) == 1 |
| 607 | + assert issubclass(w[0].category, RuntimeWarning) |
| 608 | + assert "zero affinity to all randomly selected landmarks" in str(w[0].message) |
| 609 | + |
| 610 | + |
490 | 611 | ############# |
491 | 612 | # Test API |
492 | 613 | ############# |
|
0 commit comments