Skip to content

Commit 6cde531

Browse files
erusseilEtienne Russeil
andauthored
SLSN bug fixes (#671)
* Fix i-band issue + bad galactic extinction computation * Fixed positive K-correction * Updated model with fixed K-correction * add dust_extinction * Add exception to ruff * Removed comments --------- Co-authored-by: Etienne Russeil <etru7215@localhost.localdomain>
1 parent fe4f064 commit 6cde531

File tree

7 files changed

+153
-52
lines changed

7 files changed

+153
-52
lines changed

.github/workflows/run_test_ztf.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ jobs:
4444
- name: Run test suites for ZTF
4545
run: |
4646
pip install onnxruntime==1.16.3
47+
pip install dust_extinction
4748
rm -f /tmp/forest_*.onnx
4849
./run_tests.sh -s ztf
4950
curl -s https://codecov.io/bash | bash

.ruff.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ ignore = [
4444
"D400", "D401", "D100", "D102", "D103", "D104", "D415", "D419", "D301",
4545
"E731",
4646
"N812", "N806", "N803",
47-
"PLR0913"
47+
"PLR0913",
48+
"D420"
4849
]
4950

5051
# Allow fix for all enabled rules (when `--fix`) is provided.
-12 KB
Binary file not shown.

fink_science/ztf/superluminous/kernel.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
curdir = os.path.dirname(os.path.abspath(__file__))
2020

2121
classifier_path = curdir + "/data/models/superluminous_classifier.joblib"
22-
# Declare i-band with dummy value
2322
band_wave_aa = {1: 4770.0, 2: 6231.0}
2423
temperature = "sigmoid"
2524
bolometric = "bazin"

fink_science/ztf/superluminous/processor.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,17 +157,13 @@ def superluminous_score(
157157
probas_total = np.zeros(len(objectId), dtype=float) - 1
158158
transient_mask = pdf["is_transient"]
159159
old_enough_mask = pdf["jd"] - pdf["jdstarthist"] >= kern.min_duration
160-
161-
# FIXME: kernel & model do not handle i-band
162-
contains_i_band = pdf["cfid"].apply(lambda bands: 3 in bands)
163-
mask_valid = transient_mask & old_enough_mask & ~contains_i_band
160+
mask_valid = transient_mask & old_enough_mask
164161

165162
if sum(mask_valid) == 0:
166163
return pd.Series([-1.0] * len(objectId))
167164

168165
# select only transient alerts
169166
pdf_valid = pdf[mask_valid].copy().reset_index()
170-
171167
valid_ids = list(pdf_valid["objectId"])
172168

173169
# Use Fink API to get the full light curves history
@@ -218,6 +214,9 @@ def superluminous_score(
218214
lcs = slsn.compute_flux(lcs)
219215
lcs = slsn.remove_nan(lcs)
220216

217+
# keep only g and r band, that were used for training
218+
lcs = slsn.remove_bad_bands(lcs)
219+
221220
# Perform feature extraction
222221
features = slsn.extract_features(lcs)
223222

@@ -244,7 +243,11 @@ def superluminous_score(
244243
upper_M = np.array(
245244
SLSN_features.apply(
246245
lambda x: slsn.abs_peak(
247-
x["peak_mag"], x["photoz"], x["photozerr"], x["ebv"]
246+
[x["peak_mag_g"], x["peak_mag_r"]],
247+
[kern.band_wave_aa[1], kern.band_wave_aa[2]],
248+
x["photoz"],
249+
x["photozerr"],
250+
x["ebv"],
248251
)[2],
249252
axis=1,
250253
)

fink_science/ztf/superluminous/slsn_classifier.py

Lines changed: 140 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from fink_science.tester import spark_unit_tests
2424
from fink_utils.photometry.conversion import mag2fluxcal_snana
2525
import astropy.units as u
26+
from dust_extinction.parameter_averages import F99
2627
from astropy.cosmology import LambdaCDM
2728
from astropy.coordinates import SkyCoord
2829
from dustmaps.sfd import SFDQuery
@@ -78,7 +79,38 @@ def compute_flux(pdf):
7879
return pdf
7980

8081

81-
def abs_peak(app_peak, z, zerr, ebv):
82+
def compute_milky_way_extinction(ebv, lambda_angstrom, Rv=3.1):
83+
"""Compute the milky way extinction
84+
85+
Parameters
86+
----------
87+
ebv: float
88+
E(B-V) extinction.
89+
lambda_angstrom: float
90+
Effective wavelength of the telescope filter expressed in Angstrom
91+
Rv: float
92+
Parameter describing the shape of the extinction curve.
93+
Rv = 3.1 is a standard value in many cases.
94+
95+
Examples
96+
--------
97+
>>> round(compute_milky_way_extinction(0.5, 6000), 2)
98+
1.34
99+
"""
100+
# Filter effective wavelength
101+
lambda_eff = lambda_angstrom * u.AA
102+
103+
# Extinction law
104+
ext = F99(Rv=Rv)
105+
106+
# Total extinction
107+
R_lambda = ext(lambda_eff) * Rv
108+
A_lambda = R_lambda * ebv
109+
110+
return A_lambda
111+
112+
113+
def abs_peak(app_peak, lambda_angstrom, z, zerr, ebv):
82114
"""Compute the peak absolute magnitude based on redshift, assuming a cosmology
83115
84116
Notes
@@ -87,8 +119,11 @@ def abs_peak(app_peak, z, zerr, ebv):
87119
88120
Parameters
89121
----------
90-
app_peak: float
91-
Apparent peak magnitude.
122+
app_peak: list
123+
Apparent peak magnitudes in each passband
124+
lambda_angstrom: list
125+
Effective wavelength associated to the peak apparent magnitudes,
126+
expressed in Angstrom.
92127
z: float
93128
Redshift
94129
zerr: float
@@ -98,35 +133,54 @@ def abs_peak(app_peak, z, zerr, ebv):
98133
99134
Examples
100135
--------
101-
>>> abs_peak(19, 0.2, 0.05, 0.1)
102-
array([-20.49163613, -21.1351084 , -21.63604614])
103-
>>> abs_peak(19, 0.2, 0.05, -1)
104-
array([-20.18163613, -20.8251084 , -21.32604614])
105-
>>> abs_peak(19, 0.2, np.nan, 0.1)
136+
>>> abs_peak(19, 4000, 0.2, 0.05, 0.1)
137+
array([-20.92638971, -21.66227902, -22.25186059])
138+
>>> abs_peak(19, 4000, 0.2, 0.05, -1)
139+
array([-20.48512533, -21.22101463, -21.81059621])
140+
>>> abs_peak(19, 4000, 0.2, np.nan, 0.1)
106141
array([ nan, nan, nan])
107-
>>> abs_peak(19, np.nan, 0.05, 0.1)
142+
>>> abs_peak(19, 4000, np.nan, 0.05, 0.1)
108143
array([ nan, nan, nan])
144+
>>> abs_peak([18, 18], [4400, 6600], 0.12, 0.01, 0.5)
145+
array([-22.74727368, -22.96008329, -23.15747603])
109146
"""
147+
# In case the user gives a single value instead of a list
148+
app_peak_is_num = (type(app_peak) is float) | (type(app_peak) is int)
149+
lambda_angstrom_is_num = (type(lambda_angstrom) is float) | (
150+
type(lambda_angstrom) is int
151+
)
152+
153+
if app_peak_is_num & lambda_angstrom_is_num:
154+
app_peak = [app_peak]
155+
lambda_angstrom = [lambda_angstrom]
156+
157+
# In case a negative E(B-V) value is provided
110158
if ebv < 0:
111159
ebv = 0
112160

113161
if (z == z) and (zerr == zerr):
114162
cosmo = LambdaCDM(H0=67.8, Om0=0.308, Ode0=0.692)
115-
Rv = 3.1
116-
117-
Ms = []
118-
for k in [-1, 0, 1]:
119-
effective_z = max(z + k * zerr, 1e-3)
120-
D_L = cosmo.luminosity_distance(effective_z).to("pc").value
121-
M = (
122-
app_peak
123-
- 5 * np.log10(D_L / 10)
124-
+ 2.5 * np.log10(1 + effective_z)
125-
- Rv * ebv
126-
)
127-
Ms.append(M)
128163

129-
return np.array(Ms)
164+
Ms_lambda = []
165+
166+
for band in range(len(app_peak)):
167+
Ms = []
168+
for k in [-1, 0, 1]:
169+
effective_z = max(z + k * zerr, 1e-3)
170+
D_L = cosmo.luminosity_distance(effective_z).to("pc").value
171+
M = (
172+
app_peak[band]
173+
- 5 * np.log10(D_L / 10)
174+
- 2.5 * np.log10(1 + effective_z)
175+
- compute_milky_way_extinction(ebv, lambda_angstrom[band])
176+
)
177+
Ms.append(M)
178+
Ms_lambda.append(Ms)
179+
180+
# Find the band with the highest absolute magnitude
181+
brightest = np.argmin(np.array(Ms_lambda)[:, 1])
182+
183+
return np.array(Ms_lambda[brightest])
130184

131185
return np.array([np.nan, np.nan, np.nan])
132186

@@ -271,7 +325,7 @@ def add_all_ebv(pdf):
271325
Parameters
272326
----------
273327
pdf: pd.DataFrame
274-
Must at leat include objectId, ra, dec columns
328+
Must at least include objectId, ra, dec columns
275329
276330
Returns
277331
-------
@@ -331,6 +385,45 @@ def remove_nan(pdf):
331385
return pdf
332386

333387

388+
def remove_bad_bands(pdf):
389+
"""Keep only the g and r bands
390+
391+
Parameters
392+
----------
393+
pdf: pd.DataFrame
394+
Must at least include cfid, based
395+
on which it will remove unwanted bands from the columns:
396+
"cjd","cmagpsf","csigmapsf","cfid","csigflux","cflux"
397+
398+
Returns
399+
-------
400+
pd.DataFrame
401+
Original DataFrame with nan/None removed.
402+
403+
Examples
404+
--------
405+
>>> pdf = pd.DataFrame(data={"cflux":[[10, 20, 30, 40]],"cfid":[[1, 2, 3, 3]]})
406+
>>> result = remove_bad_bands(pdf)
407+
>>> expected = pd.DataFrame(data={"cflux":[[10, 20]],"cfid":[[1, 2]]})
408+
>>> pd.testing.assert_frame_equal(result, expected)
409+
"""
410+
for k in ["cjd", "cmagpsf", "csigmapsf", "csigflux", "cflux", "cfid"]:
411+
if k in pdf.columns:
412+
pdf.loc[:, k] = pdf.apply(
413+
lambda row: np.array([
414+
a
415+
for a, b in zip(
416+
row[k],
417+
(np.isin(row["cfid"], list(kern.band_wave_aa.keys()))), # noqa: E711
418+
)
419+
if b
420+
]),
421+
axis=1,
422+
)
423+
424+
return pdf
425+
426+
334427
def fit_rainbow(lc, rainbow_model):
335428
"""Perform a rainbow fit (Russeil et al. 2024) on a light curve.
336429
@@ -364,9 +457,6 @@ def fit_rainbow(lc, rainbow_model):
364457
np.array(lc["cfid"]),
365458
)
366459

367-
# t_scaler = Scaler.from_time(lc["cjd"])
368-
# m_scaler = MultiBandScaler.from_flux(lc["cflux"], lc["cfid"], with_baseline=False)
369-
370460
try:
371461
result, errors = rainbow_model._eval_and_get_errors(
372462
t=lc["cjd"],
@@ -443,7 +533,7 @@ def statistical_features(lc):
443533
-------
444534
list
445535
List of statistical features
446-
[amplitude, kurtosis, max_slope, skew, peak_magn, std_flux, q15_time, q85_time]
536+
[amplitude, kurtosis, max_slope, skew, peak_mag_g, peak_mag_r, std_flux, q15_time, q85_time]
447537
"""
448538
amplitude = lcpckg.Amplitude()
449539
kurtosis = lcpckg.Kurtosis()
@@ -464,11 +554,14 @@ def statistical_features(lc):
464554

465555
normed_flux = lc["cflux"] / np.max(lc["cflux"])
466556
shifted_time = lc["cjd"] - np.min(lc["cjd"])
467-
peak_mag = np.min(lc["cmagpsf"])
557+
558+
peak_mag_g = np.min(lc["cmagpsf"][lc["cfid"] == 1], initial=99)
559+
peak_mag_r = np.min(lc["cmagpsf"][lc["cfid"] == 2], initial=99)
560+
468561
std = np.std(normed_flux)
469562
q15 = np.quantile(shifted_time, 0.15)
470563
q85 = np.quantile(shifted_time, 0.85)
471-
return list(result) + [peak_mag, std, q15, q85]
564+
return list(result) + [peak_mag_g, peak_mag_r, std, q15, q85]
472565

473566

474567
def quiet_model():
@@ -562,17 +655,18 @@ def extract_features(data):
562655
>>> salt_features = quiet_fit_salt(lc, salt_model)
563656
564657
# Check their values
565-
>>> np.testing.assert_allclose(stat_features,[ 8.307904e+02,
658+
>>> np.testing.assert_allclose(stat_features,[ 8.307904e+02,
566659
... 4.843807e-02, 7.573933e+03, -7.161292e-01,
567-
... 1.875300e+01, 1.383518e-01, 9.992026e+00, 2.499306e+01], rtol=1e-3)
660+
... 1.875300e+01, 1.882850e+01, 1.383518e-01,
661+
... 9.992026e+00, 2.499306e+01], rtol=1e-3)
568662
>>> np.testing.assert_allclose(salt_features,[ 1.374512e-01,
569663
... -1.201602e+01, 3.522748e-03, 9.219506e+00,
570664
... 3.321469e-02, 4.337947e+01], rtol=5e-2)
571665
>>> np.testing.assert_allclose(rainbow_features,
572-
... [ -2.161259e+00, 4.886508e+03, 2.196836e+01, 2.740976e+01,
573-
... 9.102432e+03, 9.948595e+03, 1.403806e+00, -5.663001e-01,
574-
... 1.050990e+01, 6.421245e+00, 1.106539e+00, 7.157673e+00,
575-
... 1.364669e+01, 1.184238e+00, 1.194966e-01], rtol=5e-2)
666+
... [ -2.161261e+00, 4.886507e+03, 2.196836e+01, 2.740982e+01,
667+
... 9.102431e+03, 9.948591e+03, 1.403805e+00, -5.663014e-01,
668+
... 1.050993e+01, 6.421246e+00, 1.106546e+00, 7.157723e+00,
669+
... 1.364673e+01, 1.184242e+00, 1.194966e-01], rtol=5e-2)
576670
577671
# Check full feature extraction function
578672
>>> pdf_check = pdf.copy()
@@ -581,15 +675,15 @@ def extract_features(data):
581675
# Only the fake alert should pass the cuts
582676
>>> np.testing.assert_equal(
583677
... np.array(np.sum(full_features.iloc[-30:].isnull(), axis=1)),
584-
... np.array([ 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
585-
... 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 0]))
678+
... np.array([ 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
679+
... 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 0]))
586680
587681
>>> list(full_features.columns) == ["distnr", "ra", "dec", "ebv", "duration",
588-
... "flux_amplitude", "kurtosis", "max_slope", "skew", "peak_mag", "std_flux", "q15",
589-
... "q85", "reference_time", "amplitude", "rise_time", "fall_time", "Tmin",
590-
... "Tmax", "t_color", "snr_reference_time", "snr_amplitude", "snr_rise_time",
591-
... "snr_fall_time", "snr_Tmin", "snr_Tmax", "snr_t_color", "chi2_rainbow",
592-
... "z", "t0", "x0", "x1", "c", "chi2_salt"]
682+
... "flux_amplitude", "kurtosis", "max_slope", "skew", "peak_mag_g", "peak_mag_r",
683+
... "std_flux", "q15", "q85", "reference_time", "amplitude", "rise_time", "fall_time",
684+
... "Tmin", "Tmax", "t_color", "snr_reference_time", "snr_amplitude", "snr_rise_time",
685+
... "snr_fall_time", "snr_Tmin", "snr_Tmax", "snr_t_color", "chi2_rainbow", "z", "t0",
686+
... "x0", "x1", "c", "chi2_salt"]
593687
True
594688
"""
595689
data = add_all_ebv(data)
@@ -617,7 +711,8 @@ def extract_features(data):
617711
"kurtosis",
618712
"max_slope",
619713
"skew",
620-
"peak_mag",
714+
"peak_mag_g",
715+
"peak_mag_r",
621716
"std_flux",
622717
"q15",
623718
"q85",
@@ -636,6 +731,7 @@ def extract_features(data):
636731
kern.min_points_perband
637732
<= np.array([sum(lc["cfid"] == band) for band in list(kern.band_wave_aa)])
638733
)
734+
639735
enough_total_points = len(lc["cjd"]) > kern.min_points_total
640736
duration = np.ptp(lc["cjd"])
641737
enough_duration = duration > kern.min_duration

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ light-curve==0.10.4
4444
sncosmo==2.12.1
4545
xgboost==2.1.4
4646
dustmaps==1.0.14
47+
dust_extinction==1.5
4748

4849
#for ELAsTiCC
4950
light-curve[full]

0 commit comments

Comments
 (0)