Skip to content

Commit 619751a

Browse files
authored
Merge pull request #694 from lincc-frameworks/chaining
Add return values to post processing functions
2 parents eb485aa + f4420e7 commit 619751a

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

src/lightcurvelynx/utils/post_process_results.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,11 @@ def augment_single_lightcurve(results, *, min_snr=0.0, t0=None):
189189
Minimum SNR required to mark an entry as a detection. Default is 0.0.
190190
t0 : float or None, optional
191191
Reference time for the lightcurve.
192+
193+
Returns
194+
-------
195+
results : pandas.DataFrame
196+
The modified DataFrame (to enable chaining).
192197
"""
193198
if "flux" not in results.columns or "fluxerr" not in results.columns:
194199
raise ValueError("flux and fluxerr must be present in the light curve DataFrame.")
@@ -206,6 +211,8 @@ def augment_single_lightcurve(results, *, min_snr=0.0, t0=None):
206211
if t0 is not None and "mjd" in results.columns:
207212
results["time_rel"] = results["mjd"] - t0
208213

214+
return results
215+
209216

210217
def results_augment_lightcurves(results, *, min_snr=0.0):
211218
"""Add columns to the results DataFrame with additional information
@@ -228,6 +235,11 @@ def results_augment_lightcurves(results, *, min_snr=0.0):
228235
The DataFrame containing lightcurve data. Modified in place.
229236
min_snr : float, optional
230237
Minimum SNR required to mark an entry as a detection. Default is 0.0.
238+
239+
Returns
240+
-------
241+
results : pandas.DataFrame or nested_pandas.NestedFrame
242+
The modified DataFrame (to enable chaining).
231243
"""
232244
if not isinstance(results, NestedFrame) or "lightcurve" not in results.columns:
233245
raise ValueError("results must be a NestedFrame with a 'lightcurve' column.")
@@ -260,6 +272,8 @@ def results_augment_lightcurves(results, *, min_snr=0.0):
260272
t0_idx = np.array(results["lightcurve"]["mjd"].index)
261273
results["lightcurve.time_rel"] = results["lightcurve.mjd"] - t0[t0_idx]
262274

275+
return results
276+
263277

264278
def results_use_full_filter_names(results, passbands):
265279
"""Modifies the 'filter' column in the results DataFrame to include
@@ -272,6 +286,11 @@ def results_use_full_filter_names(results, passbands):
272286
passbands : list of PassbandGroup
273287
The list of PassbandGroups used in the simulation, in the same order
274288
as in the simulation.
289+
290+
Returns
291+
-------
292+
results : pandas.DataFrame or nested_pandas.NestedFrame
293+
The modified DataFrame (to enable chaining).
275294
"""
276295
if not isinstance(results, NestedFrame) or "lightcurve" not in results.columns:
277296
raise ValueError("results must be a NestedFrame with a 'lightcurve' column.")
@@ -293,3 +312,5 @@ def results_use_full_filter_names(results, passbands):
293312
full_name = passbands[s_idx][fil].full_name
294313
filter_names[mask] = full_name
295314
results["lightcurve.filter"] = filter_names
315+
316+
return results

tests/lightcurvelynx/utils/test_post_process_results.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,10 @@ def test_results_augment_lightcurves():
304304
assert "magerr" not in results["lightcurve"].nest.columns
305305
assert results["object_id"].tolist() == [0, 1, 2]
306306

307-
# Augmenting the lightcurves should add the new columns.
308-
results_augment_lightcurves(results, min_snr=5)
307+
# Augmenting the lightcurves should add the new columns. The results object is modified in place,
308+
# but check that it also returns a reference for chaining.
309+
res2 = results_augment_lightcurves(results, min_snr=5)
310+
assert res2 is results
309311
assert len(results) == 3
310312

311313
# Check the SNR and detection markings.
@@ -524,8 +526,10 @@ def test_results_use_full_filter_names():
524526
assert np.array_equal(results["lightcurve.filter"][0].tolist(), ["g", "r"])
525527
assert np.array_equal(results["lightcurve.filter"][1].tolist(), ["g", "g", "r", "r"])
526528

527-
# Transform to full filter names.
528-
results_use_full_filter_names(results, [passbands1, passbands2])
529+
# Transform to full filter names. The results object is modified in place,
530+
# but check that it also returns a reference for chaining.
531+
res2 = results_use_full_filter_names(results, [passbands1, passbands2])
532+
assert res2 is results
529533
assert len(results) == 2
530534
assert np.array_equal(
531535
results["lightcurve.filter"][0].tolist(),

0 commit comments

Comments
 (0)