29
29
from matplotlib .figure import Figure
30
30
from matplotlib .pyplot import subplots , setp
31
31
from matplotlib .ticker import LinearLocator , FuncFormatter
32
- from numpy import isfinite , median , where , all , zeros_like , diff , asarray , interp , arange , floor , ndarray , \
33
- ceil , newaxis , inf , array , ones , poly1d , polyfit , nanpercentile , atleast_2d , nan , linspace , any , sqrt , nanmedian
32
+ from numpy import any , isfinite , median , where , all , zeros_like , diff , asarray , interp , arange , floor , ndarray , \
33
+ ceil , newaxis , inf , array , ones , poly1d , polyfit , nanpercentile , atleast_2d , nan , linspace , any , sqrt , nanmedian , \
34
+ nanmean
34
35
from pytransit .orbits import fold
35
36
from scipy .ndimage import median_filter
36
37
from scipy .signal import medfilt
@@ -45,9 +46,10 @@ class TSData:
45
46
fluxes, and errors. It provides methods for manipulating and analyzing the data.
46
47
"""
47
48
def __init__ (self , time : Sequence , wavelength : Sequence , fluxes : Sequence , errors : Sequence , name : str ,
48
- noise_group : str = 'a' , wl_edges : Sequence | None = None , tm_edges : Sequence | None = None ,
49
+ noise_group : int = 0 , wl_edges : Sequence | None = None , tm_edges : Sequence | None = None ,
49
50
transit_mask : ndarray | None = None , ephemeris : Ephemeris | None = None , n_baseline : int = 1 ,
50
- mask : ndarray = None , ephemeris_group : int = 0 , offset_group : int = 0 ) -> None :
51
+ mask : ndarray = None , epoch_group : int = 0 , offset_group : int = 0 ,
52
+ mask_nonfinite_errors : bool = True ) -> None :
51
53
"""
52
54
Parameters
53
55
----------
@@ -81,8 +83,11 @@ def __init__(self, time: Sequence, wavelength: Sequence, fluxes: Sequence, error
81
83
if n_baseline < 1 :
82
84
raise ValueError ("n_baseline must be greater than zero." )
83
85
84
- if ephemeris_group < 0 :
85
- raise ValueError ("ephemeris_group must be a non-negative integer." )
86
+ if noise_group < 0 :
87
+ raise ValueError ("noise_group must be a positive integer." )
88
+
89
+ if epoch_group < 0 :
90
+ raise ValueError ("epoch_group must be a non-negative integer." )
86
91
87
92
if offset_group < 0 :
88
93
raise ValueError ("offset_group must be a non-negative integer." )
@@ -94,17 +99,20 @@ def __init__(self, time: Sequence, wavelength: Sequence, fluxes: Sequence, error
94
99
raise ValueError ("The wavelength array must contain only finite values." )
95
100
96
101
self .name : str = name
102
+ self .mask_nonfinite_errors : bool = mask_nonfinite_errors
97
103
self .time : ndarray = time .copy ()
98
104
self .wavelength : ndarray = wavelength
99
- self .mask : ndarray = mask if mask is not None else isfinite (fluxes ) & isfinite (errors )
105
+ self .mask : ndarray = mask if mask is not None else isfinite (fluxes )
106
+ if self .mask_nonfinite_errors :
107
+ self .mask &= isfinite (errors )
100
108
self .fluxes : ndarray = where (self .mask , fluxes , nan )
101
109
self .errors : ndarray = where (self .mask , errors , nan )
102
110
self .transit_mask : ndarray = transit_mask if transit_mask is not None else ones (time .size , dtype = bool )
103
111
self .ngid : int = 0
104
- self .ephemeris : Ephemeris | None = ephemeris
112
+ self ._ephemeris : Ephemeris | None = ephemeris
105
113
self .n_baseline : int = n_baseline
106
114
self ._noise_group : str = noise_group
107
- self .ephemeris_group : int = ephemeris_group
115
+ self .epoch_group : int = epoch_group
108
116
self .offset_group : int = offset_group
109
117
self ._dataset : Optional ['TSDataSet' ] = None
110
118
self ._update ()
@@ -143,7 +151,7 @@ def export_fits(self) -> pf.HDUList:
143
151
mask = pf .ImageHDU (self .mask .astype (int ), name = f'mask_{ self .name } ' )
144
152
data .header ['ngroup' ] = self .noise_group
145
153
data .header ['nbasel' ] = self .n_baseline
146
- data .header ['epgroup' ] = self .ephemeris_group
154
+ data .header ['epgroup' ] = self .epoch_group
147
155
data .header ['offgroup' ] = self .offset_group
148
156
#TODO: export ephemeris
149
157
return pf .HDUList ([time , wave , data , ootm , mask ])
@@ -191,7 +199,7 @@ def import_fits(name: str, hdul: pf.HDUList) -> 'TSData':
191
199
192
200
#TODO: import ephemeris
193
201
return TSData (time , wave , data [0 ], data [1 ], name = name , noise_group = noise_group , transit_mask = ootm ,
194
- n_baseline = n_baseline , mask = mask , ephemeris_group = ephemeris_group , offset_group = offset_group )
202
+ n_baseline = n_baseline , mask = mask , epoch_group = ephemeris_group , offset_group = offset_group )
195
203
196
204
def __repr__ (self ) -> str :
197
205
return f"TSData Name:'{ self .name } ' [{ self .wavelength [0 ]:.2f} - { self .wavelength [- 1 ]:.2f} ] nwl={ self .nwl } npt={ self .npt } "
@@ -207,6 +215,16 @@ def noise_group(self, ng: str) -> None:
207
215
if self ._dataset is not None :
208
216
self ._dataset ._update_nids ()
209
217
218
+ @property
219
+ def ephemeris (self ) -> Ephemeris :
220
+ """Ephemeris."""
221
+ return self ._ephemeris
222
+
223
+ @ephemeris .setter
224
+ def ephemeris (self , ep : Ephemeris ) -> None :
225
+ self ._ephemeris = ep
226
+ self .mask_transit (ephemeris = ep )
227
+
210
228
def mask_transit (self , t0 : float | None = None , p : float | None = None , t14 : float | None = None ,
211
229
ephemeris : Ephemeris | None = None , elims : tuple [int , int ] | None = None ) -> 'TSData' :
212
230
"""Create a transit mask based on a given ephemeris or exposure index limits.
@@ -226,9 +244,9 @@ def mask_transit(self, t0: float | None = None, p: float | None = None, t14: flo
226
244
"""
227
245
if (t0 and p and t14 ) or ephemeris is not None :
228
246
if ephemeris is not None :
229
- self .ephemeris = ephemeris
247
+ self ._ephemeris = ephemeris
230
248
else :
231
- self .ephemeris = Ephemeris (t0 , p , t14 )
249
+ self ._ephemeris = Ephemeris (t0 , p , t14 )
232
250
phase = fold (self .time , self .ephemeris .period , self .ephemeris .zero_epoch )
233
251
self .transit_mask = abs (phase ) > 0.502 * self .ephemeris .duration
234
252
elif elims is not None :
@@ -257,6 +275,15 @@ def _update(self) -> None:
257
275
self .nwl = self .wavelength .size
258
276
self .npt = self .time .size
259
277
self .wllims = self .wavelength .min (), self .wavelength .max ()
278
+ if self ._ephemeris is not None :
279
+ self .mask_transit (ephemeris = self ._ephemeris )
280
+
281
+ def _update_data_mask (self ) -> None :
282
+ self .mask = isfinite (self .fluxes )
283
+ if self .mask_nonfinite_errors :
284
+ self .mask &= isfinite (self .errors )
285
+ self .fluxes = where (self .mask , self .fluxes , nan )
286
+ self .errors = where (self .mask , self .errors , nan )
260
287
261
288
def normalize_to_poly (self , deg : int = 1 ) -> 'TSData' :
262
289
"""Normalize the baseline flux for each spectroscopic light curve.
@@ -289,6 +316,7 @@ def normalize_to_poly(self, deg: int = 1) -> 'TSData':
289
316
deg = deg ))(self .time )
290
317
self .fluxes [ipb , :] /= bl
291
318
self .errors [ipb , :] /= bl
319
+ self ._update_data_mask ()
292
320
return self
293
321
294
322
def normalize_to_median (self , s : slice ) -> 'TSData' :
@@ -317,20 +345,22 @@ def partition_time(self, tlims: tuple[tuple[float,float]]) -> 'TSDataSet':
317
345
d = TSData (name = f'{ self .name } _1' , time = self .time [m ], wavelength = self .wavelength ,
318
346
fluxes = self .fluxes [:, m ], errors = self .errors [:, m ], mask = self .mask [:, m ],
319
347
noise_group = self .noise_group ,
320
- ephemeris_group = self .ephemeris_group ,
348
+ epoch_group = self .epoch_group ,
321
349
offset_group = self .offset_group ,
322
350
transit_mask = self .transit_mask [m ],
323
351
ephemeris = self .ephemeris ,
324
- n_baseline = self .n_baseline )
352
+ n_baseline = self .n_baseline ,
353
+ mask_nonfinite_errors = self .mask_nonfinite_errors )
325
354
for i , m in enumerate (masks [1 :]):
326
355
d = d + TSData (name = f'{ self .name } _{ i + 2 } ' , time = self .time [m ], wavelength = self .wavelength ,
327
356
fluxes = self .fluxes [:, m ], errors = self .errors [:, m ], mask = self .mask [:, m ],
328
357
noise_group = self .noise_group ,
329
- ephemeris_group = self .ephemeris_group ,
358
+ epoch_group = self .epoch_group ,
330
359
offset_group = self .offset_group ,
331
360
transit_mask = self .transit_mask [m ],
332
361
ephemeris = self .ephemeris ,
333
- n_baseline = self .n_baseline )
362
+ n_baseline = self .n_baseline ,
363
+ mask_nonfinite_errors = self .mask_nonfinite_errors )
334
364
return d
335
365
336
366
def crop_wavelength (self , lmin : float , lmax : float , inplace : bool = True ) -> 'TSData' :
@@ -362,12 +392,13 @@ def crop_wavelength(self, lmin: float, lmax: float, inplace: bool = True) -> 'TS
362
392
errors = self .errors [m ],
363
393
mask = self .mask [m ],
364
394
noise_group = self .noise_group ,
365
- ephemeris_group = self .ephemeris_group ,
395
+ epoch_group = self .epoch_group ,
366
396
offset_group = self .offset_group ,
367
397
wl_edges = (self ._wl_l_edges [m ], self ._wl_r_edges [m ]),
368
398
tm_edges = (self ._tm_l_edges , self ._tm_r_edges ),
369
399
transit_mask = self .transit_mask , ephemeris = self .ephemeris ,
370
- n_baseline = self .n_baseline )
400
+ n_baseline = self .n_baseline ,
401
+ mask_nonfinite_errors = self .mask_nonfinite_errors )
371
402
372
403
def crop_time (self , tmin : float , tmax : float , inplace : bool = True ) -> 'TSData' :
373
404
"""Crop the data to include only the time range between lmin and lmax.
@@ -399,19 +430,20 @@ def crop_time(self, tmin: float, tmax: float, inplace: bool = True) -> 'TSData':
399
430
errors = self .errors [:, m ],
400
431
mask = self .mask [:, m ],
401
432
noise_group = self .noise_group ,
402
- ephemeris_group = self .ephemeris_group ,
433
+ epoch_group = self .epoch_group ,
403
434
offset_group = self .offset_group ,
404
435
wl_edges = (self ._wl_l_edges , self ._wl_r_edges ),
405
436
tm_edges = (self ._tm_l_edges [m ], self ._tm_r_edges [m ]),
406
437
transit_mask = self .transit_mask [m ], ephemeris = self .ephemeris ,
407
- n_baseline = self .n_baseline )
438
+ n_baseline = self .n_baseline ,
439
+ mask_nonfinite_errors = self .mask_nonfinite_errors )
408
440
409
- def remove_outliers (self , sigma : float = 5.0 ) -> 'TSData' :
410
- """Remove outliers along the wavelength axis.
441
+ # TODO: separate mask into bad data mask and outlier mask.
442
+ def mask_outliers (self , sigma : float = 5.0 ) -> 'TSData' :
443
+ """Mask outliers along the wavelength axis.
411
444
412
- Replace outliers along the wavelength axis with the value of a 5-point running median filter. Outliers are
413
- defined as data points that deviate from the median by more than sigma times the median absolute deviation
414
- along the wavelength axis.
445
+ Outliers are defined as data points that deviate from the running 5-point median by more
446
+ than sigma times the median absolute deviation along the wavelength axis.
415
447
416
448
Parameters
417
449
----------
@@ -422,13 +454,18 @@ def remove_outliers(self, sigma: float = 5.0) -> 'TSData':
422
454
----
423
455
The data will be modified in place.
424
456
"""
425
- fm = median (self .fluxes , axis = 0 )
426
- fe = mad_std (self .fluxes , axis = 0 )
457
+ fm = nanmedian (self .fluxes , axis = 0 )
458
+ fe = mad_std (self .fluxes , axis = 0 , ignore_nan = True )
427
459
self .mask &= abs (self .fluxes - fm ) / fe < sigma
428
460
self .fluxes = where (self .mask , self .fluxes , nan )
429
461
self .errors = where (self .mask , self .errors , nan )
430
462
return self
431
463
464
+ @deprecated ("0.10" , alternative = "TSData.mask_outliers" )
465
+ def remove_outliers (self , sigma : float = 5.0 ) -> 'TSData' :
466
+ """Remove outliers along the wavelength axis."""
467
+ self .mask_outliers (sigma = sigma )
468
+
432
469
def plot (self , ax = None , vmin : float = None , vmax : float = None , cmap = None , figsize = None , data = None ,
433
470
plims : tuple [float , float ] | None = None ) -> Figure :
434
471
"""Plot the spectroscopic light curves as a 2D image.
@@ -528,7 +565,7 @@ def plot_white(self, ax: Axes | None = None, figsize: tuple[float, float] | None
528
565
fig = ax .figure
529
566
tref = floor (self .time .min ())
530
567
531
- ax .plot (self .time , self .fluxes . mean ( 0 ))
568
+ ax .plot (self .time , nanmean ( self .fluxes , 0 ))
532
569
if self .ephemeris is not None :
533
570
[ax .axvline (tl , ls = '--' , c = 'k' ) for tl in self .ephemeris .transit_limits (self .time .mean ())]
534
571
@@ -620,7 +657,7 @@ def bin_wavelength(self, binning: Optional[Union[Binning, CompoundBinning]] = No
620
657
name = self .name ,
621
658
tm_edges = (self ._tm_l_edges , self ._tm_r_edges ),
622
659
noise_group = self .noise_group ,
623
- ephemeris_group = self .ephemeris_group ,
660
+ epoch_group = self .epoch_group ,
624
661
offset_group = self .offset_group ,
625
662
transit_mask = self .transit_mask ,
626
663
ephemeris = self .ephemeris ,
@@ -662,7 +699,7 @@ def bin_time(self, binning: Optional[Union[Binning, CompoundBinning]] = None,
662
699
noise_group = self .noise_group ,
663
700
ephemeris = self .ephemeris ,
664
701
n_baseline = self .n_baseline ,
665
- ephemeris_group = self .ephemeris_group ,
702
+ epoch_group = self .epoch_group ,
666
703
offset_group = self .offset_group )
667
704
if self .ephemeris is not None :
668
705
d .mask_transit (ephemeris = self .ephemeris )
@@ -739,6 +776,11 @@ def offset_groups(self) -> list[int]:
739
776
"""List of offset groups."""
740
777
return [d .offset_group for d in self .data ]
741
778
779
+ @property
780
+ def epoch_groups (self ) -> list [int ]:
781
+ """List of epoch groups."""
782
+ return [d .epoch_group for d in self .data ]
783
+
742
784
@property
743
785
def n_baselines (self ) -> list [int ]:
744
786
"""Number of baseline coefficients for each data set."""
0 commit comments