55import numpy as np
66import pandas as pd
77import pyarrow as pa
8-
98from lsdb .core .crossmatch .kdtree_match import KdTreeCrossmatch
109from lsdb .core .crossmatch .kdtree_utils import _find_crossmatch_indices , _get_chord_distance
1110
1413
1514
1615class MyCrossmatchAlgorithm (KdTreeCrossmatch ):
16+ """Cross-matching algorithm that extends KdTreeCrossmatch to include
17+ magnitude difference calculations and filtering.
18+ """
1719
18- extra_columns = pd .DataFrame ({
19- "_dist_arcsec" : pd .Series (dtype = pd .ArrowDtype (pa .float64 ())),
20- "magnitude_difference" : pd .Series (dtype = pd .ArrowDtype (pa .float64 ())),
21- })
20+ extra_columns = pd .DataFrame (
21+ {
22+ "_dist_arcsec" : pd .Series (dtype = pd .ArrowDtype (pa .float64 ())),
23+ "magnitude_difference" : pd .Series (dtype = pd .ArrowDtype (pa .float64 ())),
24+ }
25+ )
2226
2327 @classmethod
2428 def validate (
@@ -29,66 +33,86 @@ def validate(
2933 right_mag_col : str ,
3034 radius_arcsec : float = 1 ,
3135 n_neighbors : int = 1 ,
32- ** kwargs
33- ):
36+ ): # pylint: disable=too-many-arguments,arguments-renamed,too-many-positional-arguments
3437 super ().validate (left , right , n_neighbors = n_neighbors , radius_arcsec = radius_arcsec )
3538
36- if left_mag_col not in left ._ddf . columns :
39+ if left_mag_col not in left .columns :
3740 raise ValueError (f"Left catalog must have column '{ left_mag_col } '" )
38- if right_mag_col not in right ._ddf . columns :
41+ if right_mag_col not in right .columns :
3942 raise ValueError (f"Right catalog must have column '{ right_mag_col } '" )
4043
44+ def _calculate_magnitude_differences (
45+ self , all_matches_df : pd .DataFrame , left_mag_col : str , right_mag_col : str
46+ ) -> pd .DataFrame :
47+ all_matches_df ["left_mag" ] = self .left .iloc [all_matches_df ["left_idx" ]][left_mag_col ].to_numpy ()
48+ all_matches_df ["right_mag" ] = self .right .iloc [all_matches_df ["right_idx" ]][right_mag_col ].to_numpy ()
49+ all_matches_df ["magnitude_difference" ] = np .abs (
50+ all_matches_df ["right_mag" ] - all_matches_df ["left_mag" ]
51+ )
52+ return all_matches_df
53+
54+ def _select_best_matches (self , all_matches_df : pd .DataFrame ) -> pd .DataFrame :
55+ best_match_indices_in_all_matches_df = all_matches_df .groupby ("left_idx" )[
56+ "magnitude_difference"
57+ ].idxmin ()
58+ return all_matches_df .loc [best_match_indices_in_all_matches_df ].reset_index (drop = True )
59+
60+ # pylint: disable=arguments-differ
4161 def perform_crossmatch (
4262 self ,
4363 left_mag_col : str ,
4464 right_mag_col : str ,
4565 radius_arcsec : float ,
46- n_neighbors : int = 1 ,
47- ** kwargs
66+ n_neighbors : int = 1 ,
4867 ) -> tuple [np .ndarray , np .ndarray , pd .DataFrame ]:
49-
5068 max_d_chord = _get_chord_distance (radius_arcsec )
51-
69+
5270 left_xyz , right_xyz = self ._get_point_coordinates ()
5371
5472 chord_distances_all , left_idx_all , right_idx_all = _find_crossmatch_indices (
5573 left_xyz = left_xyz ,
5674 right_xyz = right_xyz ,
57- n_neighbors = n_neighbors ,
75+ n_neighbors = n_neighbors ,
5876 max_distance = max_d_chord ,
5977 )
6078
6179 if len (left_idx_all ) == 0 :
62- return np .array ([], dtype = np .int64 ), np .array ([], dtype = np .int64 ), pd .DataFrame ({
63- "_dist_arcsec" : pd .Series (dtype = pd .ArrowDtype (pa .float64 ())),
64- "magnitude_difference" : pd .Series (dtype = pd .ArrowDtype (pa .float64 ())),
65- })
80+ return (
81+ np .array ([], dtype = np .int64 ),
82+ np .array ([], dtype = np .int64 ),
83+ pd .DataFrame (
84+ {
85+ "_dist_arcsec" : pd .Series (dtype = pd .ArrowDtype (pa .float64 ())),
86+ "magnitude_difference" : pd .Series (dtype = pd .ArrowDtype (pa .float64 ())),
87+ }
88+ ),
89+ )
6690
6791 arc_distances_all = np .degrees (2.0 * np .arcsin (0.5 * chord_distances_all )) * 3600
6892
69- all_matches_df = pd .DataFrame ({
70- "left_idx" : left_idx_all ,
71- "right_idx" : right_idx_all ,
72- "arc_dist_arcsec" : arc_distances_all ,
73- })
74-
75- all_matches_df ["left_mag" ] = self .left .iloc [all_matches_df ["left_idx" ]][left_mag_col ].to_numpy ()
76- all_matches_df ["right_mag" ] = self .right .iloc [all_matches_df ["right_idx" ]][right_mag_col ].to_numpy ()
77-
78- all_matches_df ["magnitude_difference" ] = np .abs (all_matches_df ["right_mag" ] - all_matches_df ["left_mag" ])
79-
80- best_match_indices_in_all_matches_df = all_matches_df .groupby ("left_idx" )["magnitude_difference" ].idxmin ()
93+ all_matches_df = pd .DataFrame (
94+ {
95+ "left_idx" : left_idx_all ,
96+ "right_idx" : right_idx_all ,
97+ "arc_dist_arcsec" : arc_distances_all ,
98+ }
99+ )
81100
82- final_matches_df = all_matches_df .loc [best_match_indices_in_all_matches_df ].reset_index (drop = True )
101+ all_matches_df = self ._calculate_magnitude_differences (all_matches_df , left_mag_col , right_mag_col )
102+ final_matches_df = self ._select_best_matches (all_matches_df )
83103
84104 final_left_indices = final_matches_df ["left_idx" ].to_numpy ()
85105 final_right_indices = final_matches_df ["right_idx" ].to_numpy ()
86106 final_distances = final_matches_df ["arc_dist_arcsec" ].to_numpy ()
87107 final_magnitude_differences = final_matches_df ["magnitude_difference" ].to_numpy ()
88108
89- extra_columns = pd .DataFrame ({
90- "_dist_arcsec" : pd .Series (final_distances , dtype = pd .ArrowDtype (pa .float64 ())),
91- "magnitude_difference" : pd .Series (final_magnitude_differences , dtype = pd .ArrowDtype (pa .float64 ())),
92- })
109+ extra_columns = pd .DataFrame (
110+ {
111+ "_dist_arcsec" : pd .Series (final_distances , dtype = pd .ArrowDtype (pa .float64 ())),
112+ "magnitude_difference" : pd .Series (
113+ final_magnitude_differences , dtype = pd .ArrowDtype (pa .float64 ())
114+ ),
115+ }
116+ )
93117
94- return final_left_indices , final_right_indices , extra_columns
118+ return final_left_indices , final_right_indices , extra_columns
0 commit comments