Skip to content

Commit 69a6c67

Browse files
committed
formatting fixes - crossmatch file
1 parent 86610e9 commit 69a6c67

File tree

1 file changed

+60
-36
lines changed

1 file changed

+60
-36
lines changed

src/lsdb_crossmatch/mag_difference_crossmatch.py

Lines changed: 60 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66
import pandas as pd
77
import pyarrow as pa
8-
98
from lsdb.core.crossmatch.kdtree_match import KdTreeCrossmatch
109
from lsdb.core.crossmatch.kdtree_utils import _find_crossmatch_indices, _get_chord_distance
1110

@@ -14,11 +13,16 @@
1413

1514

1615
class MyCrossmatchAlgorithm(KdTreeCrossmatch):
16+
"""Cross-matching algorithm that extends KdTreeCrossmatch to include
17+
magnitude difference calculations and filtering.
18+
"""
1719

18-
extra_columns = pd.DataFrame({
19-
"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64())),
20-
"magnitude_difference": pd.Series(dtype=pd.ArrowDtype(pa.float64())),
21-
})
20+
extra_columns = pd.DataFrame(
21+
{
22+
"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64())),
23+
"magnitude_difference": pd.Series(dtype=pd.ArrowDtype(pa.float64())),
24+
}
25+
)
2226

2327
@classmethod
2428
def validate(
@@ -29,66 +33,86 @@ def validate(
2933
right_mag_col: str,
3034
radius_arcsec: float = 1,
3135
n_neighbors: int = 1,
32-
**kwargs
33-
):
36+
): # pylint: disable=too-many-arguments,arguments-renamed,too-many-positional-arguments
3437
super().validate(left, right, n_neighbors=n_neighbors, radius_arcsec=radius_arcsec)
3538

36-
if left_mag_col not in left._ddf.columns:
39+
if left_mag_col not in left.columns:
3740
raise ValueError(f"Left catalog must have column '{left_mag_col}'")
38-
if right_mag_col not in right._ddf.columns:
41+
if right_mag_col not in right.columns:
3942
raise ValueError(f"Right catalog must have column '{right_mag_col}'")
4043

44+
def _calculate_magnitude_differences(
45+
self, all_matches_df: pd.DataFrame, left_mag_col: str, right_mag_col: str
46+
) -> pd.DataFrame:
47+
all_matches_df["left_mag"] = self.left.iloc[all_matches_df["left_idx"]][left_mag_col].to_numpy()
48+
all_matches_df["right_mag"] = self.right.iloc[all_matches_df["right_idx"]][right_mag_col].to_numpy()
49+
all_matches_df["magnitude_difference"] = np.abs(
50+
all_matches_df["right_mag"] - all_matches_df["left_mag"]
51+
)
52+
return all_matches_df
53+
54+
def _select_best_matches(self, all_matches_df: pd.DataFrame) -> pd.DataFrame:
55+
best_match_indices_in_all_matches_df = all_matches_df.groupby("left_idx")[
56+
"magnitude_difference"
57+
].idxmin()
58+
return all_matches_df.loc[best_match_indices_in_all_matches_df].reset_index(drop=True)
59+
60+
# pylint: disable=arguments-differ
4161
def perform_crossmatch(
4262
self,
4363
left_mag_col: str,
4464
right_mag_col: str,
4565
radius_arcsec: float,
46-
n_neighbors: int = 1,
47-
**kwargs
66+
n_neighbors: int = 1,
4867
) -> tuple[np.ndarray, np.ndarray, pd.DataFrame]:
49-
5068
max_d_chord = _get_chord_distance(radius_arcsec)
51-
69+
5270
left_xyz, right_xyz = self._get_point_coordinates()
5371

5472
chord_distances_all, left_idx_all, right_idx_all = _find_crossmatch_indices(
5573
left_xyz=left_xyz,
5674
right_xyz=right_xyz,
57-
n_neighbors=n_neighbors,
75+
n_neighbors=n_neighbors,
5876
max_distance=max_d_chord,
5977
)
6078

6179
if len(left_idx_all) == 0:
62-
return np.array([], dtype=np.int64), np.array([], dtype=np.int64), pd.DataFrame({
63-
"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64())),
64-
"magnitude_difference": pd.Series(dtype=pd.ArrowDtype(pa.float64())),
65-
})
80+
return (
81+
np.array([], dtype=np.int64),
82+
np.array([], dtype=np.int64),
83+
pd.DataFrame(
84+
{
85+
"_dist_arcsec": pd.Series(dtype=pd.ArrowDtype(pa.float64())),
86+
"magnitude_difference": pd.Series(dtype=pd.ArrowDtype(pa.float64())),
87+
}
88+
),
89+
)
6690

6791
arc_distances_all = np.degrees(2.0 * np.arcsin(0.5 * chord_distances_all)) * 3600
6892

69-
all_matches_df = pd.DataFrame({
70-
"left_idx": left_idx_all,
71-
"right_idx": right_idx_all,
72-
"arc_dist_arcsec": arc_distances_all,
73-
})
74-
75-
all_matches_df["left_mag"] = self.left.iloc[all_matches_df["left_idx"]][left_mag_col].to_numpy()
76-
all_matches_df["right_mag"] = self.right.iloc[all_matches_df["right_idx"]][right_mag_col].to_numpy()
77-
78-
all_matches_df["magnitude_difference"] = np.abs(all_matches_df["right_mag"] - all_matches_df["left_mag"])
79-
80-
best_match_indices_in_all_matches_df = all_matches_df.groupby("left_idx")["magnitude_difference"].idxmin()
93+
all_matches_df = pd.DataFrame(
94+
{
95+
"left_idx": left_idx_all,
96+
"right_idx": right_idx_all,
97+
"arc_dist_arcsec": arc_distances_all,
98+
}
99+
)
81100

82-
final_matches_df = all_matches_df.loc[best_match_indices_in_all_matches_df].reset_index(drop=True)
101+
all_matches_df = self._calculate_magnitude_differences(all_matches_df, left_mag_col, right_mag_col)
102+
final_matches_df = self._select_best_matches(all_matches_df)
83103

84104
final_left_indices = final_matches_df["left_idx"].to_numpy()
85105
final_right_indices = final_matches_df["right_idx"].to_numpy()
86106
final_distances = final_matches_df["arc_dist_arcsec"].to_numpy()
87107
final_magnitude_differences = final_matches_df["magnitude_difference"].to_numpy()
88108

89-
extra_columns = pd.DataFrame({
90-
"_dist_arcsec": pd.Series(final_distances, dtype=pd.ArrowDtype(pa.float64())),
91-
"magnitude_difference": pd.Series(final_magnitude_differences, dtype=pd.ArrowDtype(pa.float64())),
92-
})
109+
extra_columns = pd.DataFrame(
110+
{
111+
"_dist_arcsec": pd.Series(final_distances, dtype=pd.ArrowDtype(pa.float64())),
112+
"magnitude_difference": pd.Series(
113+
final_magnitude_differences, dtype=pd.ArrowDtype(pa.float64())
114+
),
115+
}
116+
)
93117

94-
return final_left_indices, final_right_indices, extra_columns
118+
return final_left_indices, final_right_indices, extra_columns

0 commit comments

Comments
 (0)