Skip to content

Commit 521ad8e

Browse files
authored
Merge pull request #4 from astronomy-commons/new_crossmatch_and_tests
magnitude difference crossmatch algorithm and tests
2 parents 5b62162 + b19d856 commit 521ad8e

23 files changed

+1269
-9
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import numpy as np
6+
import pandas as pd
7+
import pyarrow as pa
8+
from lsdb.core.crossmatch.kdtree_match import KdTreeCrossmatch
9+
from lsdb.core.crossmatch.kdtree_utils import _find_crossmatch_indices, _get_chord_distance
10+
11+
if TYPE_CHECKING:
12+
from lsdb.catalog import Catalog
13+
14+
15+
class MagnitudeDifferenceCrossmatch(KdTreeCrossmatch):
16+
"""Cross-matching algorithm that extends KdTreeCrossmatch to include
17+
magnitude difference calculations and filtering.
18+
"""
19+
20+
extra_columns = pd.DataFrame(
21+
{
22+
"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64())),
23+
"_magnitude_difference": pd.Series(dtype=pd.ArrowDtype(pa.float64())),
24+
}
25+
)
26+
27+
@classmethod
28+
def validate(
29+
cls,
30+
left: Catalog,
31+
right: Catalog,
32+
left_mag_col: str,
33+
right_mag_col: str,
34+
radius_arcsec: float = 1,
35+
n_neighbors: int = 1,
36+
): # pylint: disable=too-many-arguments,arguments-renamed,too-many-positional-arguments
37+
super().validate(left, right, n_neighbors=n_neighbors, radius_arcsec=radius_arcsec)
38+
39+
if left_mag_col not in left.columns:
40+
raise ValueError(f"Left catalog must have column '{left_mag_col}'")
41+
if right_mag_col not in right.columns:
42+
raise ValueError(f"Right catalog must have column '{right_mag_col}'")
43+
44+
def _calculate_magnitude_differences(
45+
self, all_matches_df: pd.DataFrame, left_mag_col: str, right_mag_col: str
46+
) -> pd.DataFrame:
47+
all_matches_df["left_mag"] = self.left.iloc[all_matches_df["left_idx"]][left_mag_col].to_numpy()
48+
all_matches_df["right_mag"] = self.right.iloc[all_matches_df["right_idx"]][right_mag_col].to_numpy()
49+
all_matches_df["_magnitude_difference"] = np.abs(
50+
all_matches_df["right_mag"] - all_matches_df["left_mag"]
51+
)
52+
return all_matches_df
53+
54+
def _select_best_matches(self, all_matches_df: pd.DataFrame) -> pd.DataFrame:
55+
best_match_indices_in_all_matches_df = all_matches_df.groupby("left_idx")[
56+
"_magnitude_difference"
57+
].idxmin()
58+
return all_matches_df.loc[best_match_indices_in_all_matches_df].reset_index(drop=True)
59+
60+
# pylint: disable=arguments-differ
61+
def perform_crossmatch(
62+
self,
63+
left_mag_col: str,
64+
right_mag_col: str,
65+
radius_arcsec: float,
66+
n_neighbors: int = 1,
67+
) -> tuple[np.ndarray, np.ndarray, pd.DataFrame]:
68+
max_d_chord = _get_chord_distance(radius_arcsec)
69+
70+
left_xyz, right_xyz = self._get_point_coordinates()
71+
72+
chord_distances_all, left_idx_all, right_idx_all = _find_crossmatch_indices(
73+
left_xyz=left_xyz,
74+
right_xyz=right_xyz,
75+
n_neighbors=n_neighbors,
76+
max_distance=max_d_chord,
77+
)
78+
79+
arc_distances_all = np.degrees(2.0 * np.arcsin(0.5 * chord_distances_all)) * 3600
80+
81+
all_matches_df = pd.DataFrame(
82+
{
83+
"left_idx": left_idx_all,
84+
"right_idx": right_idx_all,
85+
"arc_dist_arcsec": arc_distances_all,
86+
}
87+
)
88+
89+
all_matches_df = self._calculate_magnitude_differences(all_matches_df, left_mag_col, right_mag_col)
90+
final_matches_df = self._select_best_matches(all_matches_df)
91+
92+
final_left_indices = final_matches_df["left_idx"].to_numpy()
93+
final_right_indices = final_matches_df["right_idx"].to_numpy()
94+
final_distances = final_matches_df["arc_dist_arcsec"].to_numpy()
95+
final_magnitude_differences = final_matches_df["_magnitude_difference"].to_numpy()
96+
97+
extra_columns = pd.DataFrame(
98+
{
99+
"_dist_arcsec": pd.Series(final_distances, dtype=pd.ArrowDtype(pa.float64())),
100+
"_magnitude_difference": pd.Series(
101+
final_magnitude_differences, dtype=pd.ArrowDtype(pa.float64())
102+
),
103+
}
104+
)
105+
106+
return final_left_indices, final_right_indices, extra_columns

tests/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pathlib import Path
22

3+
import pandas as pd
34
import pytest
45
from dask.distributed import Client, LocalCluster
56

@@ -29,3 +30,18 @@ def m67_delve_dir(test_data_dir):
2930
@pytest.fixture
3031
def m67_ps1_dir(test_data_dir):
3132
return test_data_dir / "m67" / "ps1_cone"
33+
34+
35+
@pytest.fixture
36+
def m67_delve_small_dir(test_data_dir):
37+
return test_data_dir / "m67" / "delve_cone_small"
38+
39+
40+
@pytest.fixture
41+
def m67_ps1_small_dir(test_data_dir):
42+
return test_data_dir / "m67" / "ps1_cone_small"
43+
44+
45+
@pytest.fixture
46+
def xmatch_mags(test_data_dir):
47+
return pd.read_csv(test_data_dir / "expected_results" / "xmatch_mags_rband.csv")

0 commit comments

Comments
 (0)