Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changes/968.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Speed up and improve the robustness of ``fill_bad_time_intervals``
152 changes: 129 additions & 23 deletions stingray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
get_random_state,
find_nearest,
rebin_data,
show_progress,
)
from .gti import (
create_gti_mask,
Expand All @@ -36,6 +37,7 @@
get_total_gti_length,
bin_intervals_from_gtis,
time_intervals_from_gtis,
eliminate_short_gtis,
)
from typing import TYPE_CHECKING, Type, TypeVar, Union

Expand Down Expand Up @@ -2276,9 +2278,12 @@ def fill_bad_time_intervals(
Random seed to use for the simulation. If None, a random seed is generated.

"""

# Initialize random number generator for reproducibility
rs = get_random_state(seed)

# Determine which attributes should be filled with random samples from the buffer,
# and which should be left alone (filled with NaN or default values).
# Time and mask are always treated specially and excluded from randomization.
if attrs_to_randomize is None:
attrs_to_randomize = self.array_attrs() + self.internal_array_attrs()
for attr in ["time", "_mask"]:
Expand All @@ -2291,17 +2296,21 @@ def fill_bad_time_intervals(
if a not in attrs_to_randomize
]

# Set default max_length: only fill gaps shorter than 1/100th of the longest GTI.
# This ensures we only fill very short gaps where white noise injection is acceptable.
if max_length is None:
max_length = np.max(self.gti[:, 1] - self.gti[:, 0]) / 100

btis = get_btis(self.gti, self.time[0], self.time[-1])
if len(btis) == 0:
#
# If there's only one GTI (or none), there are no gaps between GTIs to fill
if len(self.gti) <= 1:
logger.info("No bad time intervals to fill")
return copy.deepcopy(self)

# Get only the valid (non-masked) time stamps to work with
filtered_times = self.time[self.mask]

new_times = [filtered_times.copy()]
new_attrs = {}
# Determine if data is evenly sampled by comparing median time separation to dt.
# This affects how we generate new time stamps in the gaps.
mean_data_separation = np.median(np.diff(filtered_times))
if even_sampling is None:
# The time series is considered evenly sampled if the median separation between
Expand All @@ -2311,84 +2320,172 @@ def fill_bad_time_intervals(
even_sampling = True
logger.info(f"Data are {'not' if not even_sampling else ''} evenly sampled")

# Estimate how many samples typically fit in a gap and in a GTI.
# This helps determine a reasonable buffer size for sampling statistics.
gti_lengths = self.gti[:, 1] - self.gti[:, 0]
if even_sampling:
est_samples_in_gap = int(max_length / self.dt)
max_samples_per_gti = int(gti_lengths.max() / self.dt)
else:
est_samples_in_gap = int(max_length / mean_data_separation)
max_samples_per_gti = int(gti_lengths.max() / mean_data_separation)

if buffer_size is None:
buffer_size = max(100, est_samples_in_gap)
# Ideal buffer size: a lot longer than the estimated number of samples in the gap
# (10x) to get good statistics for random sampling
buffer_size = est_samples_in_gap * 10
# However, this might be too much for short GTIs, so we set an upper limit to the
# buffer size equal to the length of the longest GTI.
buffer_size = min(buffer_size, max_samples_per_gti)
# Still, no less than 1
buffer_size = max(buffer_size, 1)

# Eliminate series of GTIs that are shorter than max_length, to avoid particularly
# problematic situations (ex. in NICER where many short GTIs can occur)
gti = eliminate_short_gtis(self.gti, max_length / 2, only_repeated=True)
if len(gti) < len(self.gti):
new_ts = self.apply_gtis(gti, inplace=False)
filtered_times = new_ts.time
else:
new_ts = self

# Initialize output containers: start with the original valid data
new_times = [filtered_times.copy()]
new_attrs = {}

# Get Bad Time Intervals (BTIs) - the gaps between GTIs that we might fill
btis = get_btis(gti, self.time[0], self.time[-1])

# Track which new GTIs we're adding (one for each filled gap)
added_gtis = []

total_filled_time = 0
for bti in btis:
for bti in show_progress(btis):
length = bti[1] - bti[0]

# Skip gaps that are too long - we only fill short gaps
if length > max_length:
continue

logger.info(f"Filling bad time interval {bti} ({length:.4f} s)")
epsilon = 1e-5 * length
added_gtis.append([bti[0] - epsilon, bti[1] + epsilon])

# Find the indices in filtered_times closest to the gap boundaries.
# filt_low_idx: index of the last valid point before the gap
# filt_hig_idx: index of the first valid point after the gap
filt_low_t, filt_low_idx = find_nearest(filtered_times, bti[0])
filt_hig_t, filt_hig_idx = find_nearest(filtered_times, bti[1], side="right")

# ---------- Generate new time stamps for the gap ----------
if even_sampling:
local_new_times = np.arange(bti[0] + self.dt / 2, bti[1], self.dt)
# For evenly sampled data: create a regular time grid within the gap.
# Start at bti[0] + dt/2 to place bin centers correctly.
local_new_times = np.arange(bti[0] + new_ts.dt / 2, bti[1], new_ts.dt)
nevents = local_new_times.size
else:
# For unevenly sampled data (e.g., event lists): estimate count rate
# from the buffer regions on either side of the gap, then draw a
# Poisson-distributed number of events and place them uniformly.

# Get buffer_size points before the gap (but stay within valid range)
low_time_arr = filtered_times[max(filt_low_idx - buffer_size, 0) : filt_low_idx]
# Only keep points that are within buffer_size seconds of the gap
low_time_arr = low_time_arr[low_time_arr > bti[0] - buffer_size]

# Get buffer_size points after the gap
high_time_arr = filtered_times[filt_hig_idx : buffer_size + filt_hig_idx]
high_time_arr = high_time_arr[high_time_arr < bti[1] + buffer_size]

# Calculate count rate from the buffer before the gap
# count_rate = number_of_events / time_span
if len(low_time_arr) > 0 and (filt_low_t - low_time_arr[0]) > 0:
ctrate_low = np.count_nonzero(low_time_arr) / (filt_low_t - low_time_arr[0])
ctrate_low = np.size(low_time_arr) / (filt_low_t - low_time_arr[0])
else:
ctrate_low = np.nan

# Calculate count rate from the buffer after the gap
if len(high_time_arr) > 0 and (high_time_arr[-1] - filt_hig_t) > 0:
ctrate_high = np.count_nonzero(high_time_arr) / (high_time_arr[-1] - filt_hig_t)
ctrate_high = np.size(high_time_arr) / (high_time_arr[-1] - filt_hig_t)
else:
ctrate_high = np.nan

# If we couldn't estimate count rate from either side, skip this gap
if not np.isfinite(ctrate_low) and not np.isfinite(ctrate_high):
warnings.warn(
f"No valid data around to simulate the time series in interval "
f"{bti[0]:g}-{bti[1]:g}. Skipping. Please check that the buffer size is "
f"adequate."
)
continue

# Average the count rates (nanmean ignores NaN values)
ctrate = np.nanmean([ctrate_low, ctrate_high])
nevents = rs.poisson(ctrate * (bti[1] - bti[0]))
local_new_times = rs.uniform(bti[0], bti[1], nevents)
expected_counts = ctrate * (bti[1] - bti[0])
# Draw number of events from Poisson distribution with expected value = rate * duration
nevents = rs.poisson(expected_counts)
# Place events uniformly within the gap
local_new_times = np.sort(rs.uniform(bti[0], bti[1], nevents))

# Create a new GTI that covers this gap (with tiny epsilon margin to avoid
# floating-point edge issues when merging GTIs later)
epsilon = 1e-5 * length
added_gtis.append([bti[0] - epsilon, bti[1] + epsilon])

# Add the new times to our collection
new_times.append(local_new_times)

# ---------- Fill array attributes with random samples from buffer ----------
for attr in attrs_to_randomize:
low_arr = getattr(self, attr)[max(buffer_size - filt_low_idx, 0) : filt_low_idx]
high_arr = getattr(self, attr)[filt_hig_idx : buffer_size + filt_hig_idx]
# Get attribute values from buffer before the gap
# The slice starts at max(filt_low_idx - buffer_size, 0), i.e. the same index
# window used for low_time_arr above, so at most buffer_size samples are taken.
low_arr = getattr(new_ts, attr)[max(filt_low_idx - buffer_size, 0) : filt_low_idx]
# Get attribute values from buffer after the gap
high_arr = getattr(new_ts, attr)[filt_hig_idx : buffer_size + filt_hig_idx]

# Initialize the output list with the original valid data (only once)
if attr not in new_attrs:
new_attrs[attr] = [getattr(self, attr)[self.mask]]
new_attrs[attr] = [getattr(new_ts, attr)[new_ts.mask]]

# Randomly sample from the combined buffer to fill the gap.
# This preserves the empirical distribution of values near the gap.
new_attrs[attr].append(rs.choice(np.concatenate([low_arr, high_arr]), nevents))

# ---------- Handle attributes that shouldn't be randomized ----------
for attr in attrs_to_leave_alone:
if attr not in new_attrs:
new_attrs[attr] = [getattr(self, attr)[self.mask]]
new_attrs[attr] = [getattr(new_ts, attr)[new_ts.mask]]

# For mask: new points are always valid (True)
# For other attrs: fill with NaN to indicate simulated/unknown values
if attr == "_mask":
new_attrs[attr].append(np.ones(nevents, dtype=bool))
else:
new_attrs[attr].append(np.zeros(nevents) + np.nan)

total_filled_time += length

logger.info(f"A total of {total_filled_time} s of data were simulated")

new_gtis = join_gtis(self.gti, added_gtis)
# ---------- Assemble output ----------
# Merge the original GTIs with the new GTIs covering filled gaps
new_gtis = join_gtis(gti, added_gtis)

# Concatenate all time arrays (original + filled gaps) and sort by time
new_times = np.concatenate(new_times)
order = np.argsort(new_times)
new_obj = type(self)()

# Create new time series object of the same type as self
new_obj = type(new_ts)()
new_obj.time = new_times[order]

# Copy all metadata attributes (scalars like mjdref, dt, etc.)
for attr in self.meta_attrs():
setattr(new_obj, attr, getattr(self, attr))

# Set array attributes, applying the same sort order used for time
for attr, values in new_attrs.items():
setattr(new_obj, attr, np.concatenate(values)[order])

new_obj.gti = new_gtis
return new_obj

Expand Down Expand Up @@ -2506,7 +2603,7 @@ def plot(
alpha=0.5,
facecolor="r",
zorder=1,
edgecolor="none",
edgecolor="r",
)
return ax

Expand Down Expand Up @@ -2660,12 +2757,21 @@ def analyze_segments(self, func, segment_size, fraction_step=1, **kwargs):
True
"""

even_sampling = False
mean_data_separation = np.median(np.diff(self.time))
if (
self.dt is not None
and self.dt > 0
and np.isclose(mean_data_separation, self.dt, rtol=0.01)
):
even_sampling = True

if segment_size is None:
start_times = self.gti[:, 0]
stop_times = self.gti[:, 1]
start = np.searchsorted(self.time, start_times)
stop = np.searchsorted(self.time, stop_times)
elif self.dt > 0:
elif even_sampling:
start, stop = bin_intervals_from_gtis(
self.gti, segment_size, self.time, fraction_step=fraction_step, dt=self.dt
)
Expand Down
70 changes: 70 additions & 0 deletions stingray/gti.py
Original file line number Diff line number Diff line change
Expand Up @@ -1848,3 +1848,73 @@ def split_gtis_by_exposure(gtis, exposure_per_chunk, new_interval_if_gti_sep=Non

vals = split_gtis_at_indices(gtis, index_list)
return vals


def eliminate_short_gtis(gtis, min_length, only_repeated=False):
    """Eliminate GTIs shorter than a given length.

    Parameters
    ----------
    gtis : 2-d float array
        List of GTIs of the form ``[[gti0_0, gti0_1], [gti1_0, gti1_1], ...]``.
        An empty list is accepted and yields an empty ``(0, 2)`` array.
    min_length : float
        Minimum length of GTIs to keep, in seconds

    Other Parameters
    ----------------
    only_repeated : bool
        If True, only eliminate GTIs that are shorter than the given length and are
        separated by less than the given length from another problematic GTI. The
        separation is measured between the *start times* of consecutive GTIs (rows are
        compared in the order given, so the input is presumably sorted in time).
        If False, eliminate all GTIs shorter than the given length, regardless of their
        separation from other GTIs.

    Returns
    -------
    new_gtis : 2-d float array
        List of GTIs of the form ``[[gti0_0, gti0_1], [gti1_0, gti1_1], ...]``,
        with GTIs shorter than the given length removed

    Examples
    --------
    ``min_length`` is shorter than all GTIs, so no GTI is removed:

    >>> gtis = [[0, 30], [40, 41], [90, 120], [130, 160]]
    >>> new_gtis = eliminate_short_gtis(gtis, 0.5)
    >>> assert np.allclose(new_gtis, [[0, 30], [40, 41], [90, 120], [130, 160]])

    ``min_length`` is longer than one GTI, so that GTI is removed:

    >>> new_gtis = eliminate_short_gtis(gtis, 2)
    >>> assert np.allclose(new_gtis, [[0, 30], [90, 120], [130, 160]])

    ``min_length`` is longer than one GTI, but ``only_repeated=True`` only removes GTIs
    that are part of a repeated pattern of short GTIs:

    >>> new_gtis = eliminate_short_gtis(gtis, 2, only_repeated=True)
    >>> assert np.allclose(new_gtis, [[0, 30], [40, 41], [90, 120], [130, 160]])

    Two short GTIs in a row, separated by less than ``min_length``, are removed when
    ``only_repeated=True``:

    >>> gtis = [[0, 30], [40, 41], [41.5, 42], [90, 120], [130, 160]]
    >>> new_gtis = eliminate_short_gtis(gtis, 2, only_repeated=True)
    >>> assert np.allclose(new_gtis, [[0, 30], [90, 120], [130, 160]])

    Four short GTIs in a row and ``only_repeated=True``, but only three of them are separated by
    less than ``min_length``. Those three are removed, while the one that is separated by more
    than ``min_length`` from the others is kept:

    >>> gtis = [[0, 30], [31, 32], [40, 40.5], [41, 41.3], [41.5, 42], [90, 120], [130, 160]]
    >>> new_gtis = eliminate_short_gtis(gtis, 2, only_repeated=True)
    >>> assert np.allclose(new_gtis, [[0, 30], [31, 32], [90, 120], [130, 160]])

    An empty GTI list is returned as an empty 2-d array:

    >>> eliminate_short_gtis([], 2).shape
    (0, 2)

    """
    gtis = np.asanyarray(gtis)
    if gtis.size == 0:
        # ``np.asanyarray([])`` is 1-d, so the column indexing below would raise
        # IndexError. Nothing can be removed from an empty list anyway.
        return gtis.reshape(0, 2)

    lengths = gtis[:, 1] - gtis[:, 0]

    if not only_repeated:
        return gtis[lengths >= min_length]

    is_short = lengths < min_length
    # Distance between the start times of consecutive GTIs.
    start_separation = np.diff(gtis[:, 0])

    # bad_pair[i] is True when GTIs i and i + 1 are both short and close together.
    bad_pair = is_short[:-1] & is_short[1:] & (start_separation < min_length)

    # Both members of every bad pair are dropped; an isolated short GTI is kept.
    keep = np.ones(len(gtis), dtype=bool)
    keep[:-1] &= ~bad_pair
    keep[1:] &= ~bad_pair
    return gtis[keep]
Loading
Loading