2323from fink_science .tester import spark_unit_tests
2424from fink_utils .photometry .conversion import mag2fluxcal_snana
2525import astropy .units as u
26+ from dust_extinction .parameter_averages import F99
2627from astropy .cosmology import LambdaCDM
2728from astropy .coordinates import SkyCoord
2829from dustmaps .sfd import SFDQuery
@@ -78,7 +79,38 @@ def compute_flux(pdf):
7879 return pdf
7980
8081
81- def abs_peak (app_peak , z , zerr , ebv ):
82+ def compute_milky_way_extinction (ebv , lambda_angstrom , Rv = 3.1 ):
83+ """Compute the milky way extinction
84+
85+ Parameters
86+ ----------
87+ ebv: float
88+ E(B-V) extinction.
89+ lambda_angstrom: float
90+ Effective wavelength of the telescope filter expressed in Angstrom
91+ Rv: float
92+ Parameter describing the shape of the extinction curve.
93+ Rv = 3.1 is a standard value in many cases.
94+
95+ Examples
96+ --------
97+ >>> round(compute_milky_way_extinction(0.5, 6000), 2)
98+ 1.34
99+ """
100+ # Filter effective wavelength
101+ lambda_eff = lambda_angstrom * u .AA
102+
103+ # Extinction law
104+ ext = F99 (Rv = Rv )
105+
106+ # Total extinction
107+ R_lambda = ext (lambda_eff ) * Rv
108+ A_lambda = R_lambda * ebv
109+
110+ return A_lambda
111+
112+
113+ def abs_peak (app_peak , lambda_angstrom , z , zerr , ebv ):
82114 """Compute the peak absolute magnitude based on redshift, assuming a cosmology
83115
84116 Notes
@@ -87,8 +119,11 @@ def abs_peak(app_peak, z, zerr, ebv):
87119
88120 Parameters
89121 ----------
90- app_peak: float
91- Apparent peak magnitude.
122+ app_peak: list
123+ Apparent peak magnitudes in each passband
124+ lambda_angstrom: list
125+ Effective wavelength associated to the peak apparent magnitudes,
126+ expressed in Angstrom.
92127 z: float
93128 Redshift
94129 zerr: float
@@ -98,35 +133,54 @@ def abs_peak(app_peak, z, zerr, ebv):
98133
99134 Examples
100135 --------
101- >>> abs_peak(19, 0.2, 0.05, 0.1)
102- array([-20.49163613 , -21.1351084 , -21.63604614 ])
103- >>> abs_peak(19, 0.2, 0.05, -1)
104- array([-20.18163613 , -20.8251084 , -21.32604614 ])
105- >>> abs_peak(19, 0.2, np.nan, 0.1)
136+ >>> abs_peak(19, 4000, 0.2, 0.05, 0.1)
137+ array([-20.92638971 , -21.66227902 , -22.25186059 ])
138+ >>> abs_peak(19, 4000, 0.2, 0.05, -1)
139+ array([-20.48512533 , -21.22101463 , -21.81059621 ])
140+ >>> abs_peak(19, 4000, 0.2, np.nan, 0.1)
106141 array([ nan, nan, nan])
107- >>> abs_peak(19, np.nan, 0.05, 0.1)
142+ >>> abs_peak(19, 4000, np.nan, 0.05, 0.1)
108143 array([ nan, nan, nan])
144+ >>> abs_peak([18, 18], [4400, 6600], 0.12, 0.01, 0.5)
145+ array([-22.74727368, -22.96008329, -23.15747603])
109146 """
147+ # In case the user gives a single value instead of a list
148+ app_peak_is_num = (type (app_peak ) is float ) | (type (app_peak ) is int )
149+ lambda_angstrom_is_num = (type (lambda_angstrom ) is float ) | (
150+ type (lambda_angstrom ) is int
151+ )
152+
153+ if app_peak_is_num & lambda_angstrom_is_num :
154+ app_peak = [app_peak ]
155+ lambda_angstrom = [lambda_angstrom ]
156+
157+ # In case a negative E(B-V) value is provided
110158 if ebv < 0 :
111159 ebv = 0
112160
113161 if (z == z ) and (zerr == zerr ):
114162 cosmo = LambdaCDM (H0 = 67.8 , Om0 = 0.308 , Ode0 = 0.692 )
115- Rv = 3.1
116-
117- Ms = []
118- for k in [- 1 , 0 , 1 ]:
119- effective_z = max (z + k * zerr , 1e-3 )
120- D_L = cosmo .luminosity_distance (effective_z ).to ("pc" ).value
121- M = (
122- app_peak
123- - 5 * np .log10 (D_L / 10 )
124- + 2.5 * np .log10 (1 + effective_z )
125- - Rv * ebv
126- )
127- Ms .append (M )
128163
129- return np .array (Ms )
164+ Ms_lambda = []
165+
166+ for band in range (len (app_peak )):
167+ Ms = []
168+ for k in [- 1 , 0 , 1 ]:
169+ effective_z = max (z + k * zerr , 1e-3 )
170+ D_L = cosmo .luminosity_distance (effective_z ).to ("pc" ).value
171+ M = (
172+ app_peak [band ]
173+ - 5 * np .log10 (D_L / 10 )
174+ - 2.5 * np .log10 (1 + effective_z )
175+ - compute_milky_way_extinction (ebv , lambda_angstrom [band ])
176+ )
177+ Ms .append (M )
178+ Ms_lambda .append (Ms )
179+
180+ # Find the band with the highest absolute magnitude
181+ brightest = np .argmin (np .array (Ms_lambda )[:, 1 ])
182+
183+ return np .array (Ms_lambda [brightest ])
130184
131185 return np .array ([np .nan , np .nan , np .nan ])
132186
@@ -271,7 +325,7 @@ def add_all_ebv(pdf):
271325 Parameters
272326 ----------
273327 pdf: pd.DataFrame
274- Must at leat include objectId, ra, dec columns
328+ Must at least include objectId, ra, dec columns
275329
276330 Returns
277331 -------
@@ -331,6 +385,45 @@ def remove_nan(pdf):
331385 return pdf
332386
333387
388+ def remove_bad_bands (pdf ):
389+ """Keep only the g and r bands
390+
391+ Parameters
392+ ----------
393+ pdf: pd.DataFrame
394+ Must at least include cfid, based
395+ on which it will remove unwanted bands from the columns:
396+ "cjd","cmagpsf","csigmapsf","cfid","csigflux","cflux"
397+
398+ Returns
399+ -------
400+ pd.DataFrame
401+ Original DataFrame with nan/None removed.
402+
403+ Examples
404+ --------
405+ >>> pdf = pd.DataFrame(data={"cflux":[[10, 20, 30, 40]],"cfid":[[1, 2, 3, 3]]})
406+ >>> result = remove_bad_bands(pdf)
407+ >>> expected = pd.DataFrame(data={"cflux":[[10, 20]],"cfid":[[1, 2]]})
408+ >>> pd.testing.assert_frame_equal(result, expected)
409+ """
410+ for k in ["cjd" , "cmagpsf" , "csigmapsf" , "csigflux" , "cflux" , "cfid" ]:
411+ if k in pdf .columns :
412+ pdf .loc [:, k ] = pdf .apply (
413+ lambda row : np .array ([
414+ a
415+ for a , b in zip (
416+ row [k ],
417+ (np .isin (row ["cfid" ], list (kern .band_wave_aa .keys ()))), # noqa: E711
418+ )
419+ if b
420+ ]),
421+ axis = 1 ,
422+ )
423+
424+ return pdf
425+
426+
334427def fit_rainbow (lc , rainbow_model ):
335428 """Perform a rainbow fit (Russeil et al. 2024) on a light curve.
336429
@@ -364,9 +457,6 @@ def fit_rainbow(lc, rainbow_model):
364457 np .array (lc ["cfid" ]),
365458 )
366459
367- # t_scaler = Scaler.from_time(lc["cjd"])
368- # m_scaler = MultiBandScaler.from_flux(lc["cflux"], lc["cfid"], with_baseline=False)
369-
370460 try :
371461 result , errors = rainbow_model ._eval_and_get_errors (
372462 t = lc ["cjd" ],
@@ -443,7 +533,7 @@ def statistical_features(lc):
443533 -------
444534 list
445535 List of statistical features
446- [amplitude, kurtosis, max_slope, skew, peak_magn , std_flux, q15_time, q85_time]
536+ [amplitude, kurtosis, max_slope, skew, peak_mag_g, peak_mag_r , std_flux, q15_time, q85_time]
447537 """
448538 amplitude = lcpckg .Amplitude ()
449539 kurtosis = lcpckg .Kurtosis ()
@@ -464,11 +554,14 @@ def statistical_features(lc):
464554
465555 normed_flux = lc ["cflux" ] / np .max (lc ["cflux" ])
466556 shifted_time = lc ["cjd" ] - np .min (lc ["cjd" ])
467- peak_mag = np .min (lc ["cmagpsf" ])
557+
558+ peak_mag_g = np .min (lc ["cmagpsf" ][lc ["cfid" ] == 1 ], initial = 99 )
559+ peak_mag_r = np .min (lc ["cmagpsf" ][lc ["cfid" ] == 2 ], initial = 99 )
560+
468561 std = np .std (normed_flux )
469562 q15 = np .quantile (shifted_time , 0.15 )
470563 q85 = np .quantile (shifted_time , 0.85 )
471- return list (result ) + [peak_mag , std , q15 , q85 ]
564+ return list (result ) + [peak_mag_g , peak_mag_r , std , q15 , q85 ]
472565
473566
474567def quiet_model ():
@@ -562,17 +655,18 @@ def extract_features(data):
562655 >>> salt_features = quiet_fit_salt(lc, salt_model)
563656
564657 # Check their values
565- >>> np.testing.assert_allclose(stat_features,[ 8.307904e+02,
658+ >>> np.testing.assert_allclose(stat_features,[ 8.307904e+02,
566659 ... 4.843807e-02, 7.573933e+03, -7.161292e-01,
567- ... 1.875300e+01, 1.383518e-01, 9.992026e+00, 2.499306e+01], rtol=1e-3)
660+ ... 1.875300e+01, 1.882850e+01, 1.383518e-01,
661+ ... 9.992026e+00, 2.499306e+01], rtol=1e-3)
568662 >>> np.testing.assert_allclose(salt_features,[ 1.374512e-01,
569663 ... -1.201602e+01, 3.522748e-03, 9.219506e+00,
570664 ... 3.321469e-02, 4.337947e+01], rtol=5e-2)
571665 >>> np.testing.assert_allclose(rainbow_features,
572- ... [ -2.161259e +00, 4.886508e +03, 2.196836e+01, 2.740976e +01,
573- ... 9.102432e +03, 9.948595e +03, 1.403806e +00, -5.663001e -01,
574- ... 1.050990e +01, 6.421245e +00, 1.106539e +00, 7.157673e +00,
575- ... 1.364669e +01, 1.184238e +00, 1.194966e-01], rtol=5e-2)
666+ ... [ -2.161261e +00, 4.886507e +03, 2.196836e+01, 2.740982e +01,
667+ ... 9.102431e +03, 9.948591e +03, 1.403805e +00, -5.663014e -01,
668+ ... 1.050993e +01, 6.421246e +00, 1.106546e +00, 7.157723e +00,
669+ ... 1.364673e +01, 1.184242e +00, 1.194966e-01], rtol=5e-2)
576670
577671 # Check full feature extraction function
578672 >>> pdf_check = pdf.copy()
@@ -581,15 +675,15 @@ def extract_features(data):
581675 # Only the fake alert should pass the cuts
582676 >>> np.testing.assert_equal(
583677 ... np.array(np.sum(full_features.iloc[-30:].isnull(), axis=1)),
584- ... np.array([ 29, 29, 29 , 29, 29, 29 , 29, 29, 29 , 29, 29, 29 , 29 ,
585- ... 29, 29, 29 , 29, 29, 29 , 29, 29, 29, 29, 29, 29 , 29, 29, 29, 29 , 0]))
678+ ... np.array([ 30, 30, 30 , 30, 30, 30 , 30, 30, 30 , 30, 30, 30 , 30 ,
679+ ... 30, 30, 30 , 30, 30, 30 , 30, 30, 30, 30, 30, 30 , 30, 30, 30, 30 , 0]))
586680
587681 >>> list(full_features.columns) == ["distnr", "ra", "dec", "ebv", "duration",
588- ... "flux_amplitude", "kurtosis", "max_slope", "skew", "peak_mag ", "std_flux", "q15 ",
589- ... "q85", "reference_time", "amplitude", "rise_time", "fall_time", "Tmin ",
590- ... "Tmax", "t_color", "snr_reference_time", "snr_amplitude", "snr_rise_time",
591- ... "snr_fall_time", "snr_Tmin", "snr_Tmax", "snr_t_color", "chi2_rainbow",
592- ... "z", "t0", " x0", "x1", "c", "chi2_salt"]
682+ ... "flux_amplitude", "kurtosis", "max_slope", "skew", "peak_mag_g ", "peak_mag_r ",
683+ ... "std_flux", "q15", " q85", "reference_time", "amplitude", "rise_time", "fall_time",
684+ ... "Tmin", " Tmax", "t_color", "snr_reference_time", "snr_amplitude", "snr_rise_time",
685+ ... "snr_fall_time", "snr_Tmin", "snr_Tmax", "snr_t_color", "chi2_rainbow", "z", "t0",
686+ ... "x0", "x1", "c", "chi2_salt"]
593687 True
594688 """
595689 data = add_all_ebv (data )
@@ -617,7 +711,8 @@ def extract_features(data):
617711 "kurtosis" ,
618712 "max_slope" ,
619713 "skew" ,
620- "peak_mag" ,
714+ "peak_mag_g" ,
715+ "peak_mag_r" ,
621716 "std_flux" ,
622717 "q15" ,
623718 "q85" ,
@@ -636,6 +731,7 @@ def extract_features(data):
636731 kern .min_points_perband
637732 <= np .array ([sum (lc ["cfid" ] == band ) for band in list (kern .band_wave_aa )])
638733 )
734+
639735 enough_total_points = len (lc ["cjd" ]) > kern .min_points_total
640736 duration = np .ptp (lc ["cjd" ])
641737 enough_duration = duration > kern .min_duration
0 commit comments