Skip to content

Commit 8e37885

Browse files
authored
Merge pull request #712 from lincc-frameworks/instrument_name
Add option to save full filter names
2 parents e571508 + b591deb commit 8e37885

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

src/lightcurvelynx/simulate.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class SimulationInfo:
6161
output_file_path : str or Path, optional
6262
The file path and name of where to save the results. If provided the results
6363
are saved to this file instead of being returned directly.
64+
save_full_filter_names : bool
65+
Whether to save the full filter names in the results (including survey prefix).
6466
kwargs : dict
6567
Additional keyword arguments to pass to the simulation function.
6668
"""
@@ -79,6 +81,7 @@ def __init__(
7981
sample_offset=0,
8082
rng=None,
8183
output_file_path=None,
84+
save_full_filter_names=False,
8285
**kwargs,
8386
):
8487
self.model = model
@@ -93,6 +96,7 @@ def __init__(
9396
self.rng = rng
9497
self.kwargs = kwargs
9598
self.output_file_path = None
99+
self.save_full_filter_names = save_full_filter_names
96100

97101
if self.num_samples <= 0:
98102
raise ValueError("Number of samples must be a positive integer.")
@@ -176,6 +180,7 @@ def split(self, num_batches=None, batch_size=None):
176180
sample_offset=self.sample_offset + start_idx,
177181
rng=batch_rng,
178182
output_file_path=batch_output_file_path,
183+
save_full_filter_names=self.save_full_filter_names,
179184
**self.kwargs,
180185
)
181186
batches.append(batch_info)
@@ -349,6 +354,8 @@ def _simulate_lightcurves_batch(simulation_info):
349354
obstable_save_cols = []
350355
for col in obstable_save_cols:
351356
nested_dict[col] = []
357+
if simulation_info.save_full_filter_names:
358+
nested_dict["full_filter_name"] = []
352359

353360
# Determine which of the of the simulated positions match the pointings from each ObsTable.
354361
logger.info("Performing range searches to find matching observations.")
@@ -422,6 +429,15 @@ def _simulate_lightcurves_batch(simulation_info):
422429
total_num_obs += nobs
423430
nested_index.extend([idx] * nobs)
424431

432+
# Add the survey name from the passband information if we chose to save it.
433+
if simulation_info.save_full_filter_names:
434+
obs_filters = np.asarray(obs_filters)
435+
full_filter_names = np.empty_like(obs_filters, dtype=object)
436+
for filter_name in np.unique(obs_filters):
437+
pb_obj = passbands[survey_idx][filter_name]
438+
full_filter_names[obs_filters == filter_name] = pb_obj.full_name
439+
nested_dict["full_filter_name"].extend(list(full_filter_names))
440+
425441
# The number of observations is the total across all surveys.
426442
results_dict["nobs"][idx] = total_num_obs
427443

@@ -486,6 +502,7 @@ def simulate_lightcurves(
486502
executor=None,
487503
num_jobs=None,
488504
batch_size=100_000,
505+
save_full_filter_names=False,
489506
):
490507
"""Generate a number of simulations of the given model and information
491508
from one or more surveys. The result data can either be returned directly
@@ -547,6 +564,8 @@ def simulate_lightcurves(
547564
batch_size : int, optional
548565
The number of samples to process in each batch when using multiprocessing.
549566
Default is 100_000.
567+
save_full_filter_names : bool
568+
Whether to save the full filter names in the results (including survey prefix).
550569
551570
Returns
552571
-------
@@ -570,6 +589,7 @@ def simulate_lightcurves(
570589
obstable_save_cols=obstable_save_cols,
571590
param_cols=param_cols,
572591
output_file_path=output_file_path,
592+
save_full_filter_names=save_full_filter_names,
573593
)
574594

575595
# If we do not have any parallelization information, perform in serial.

tests/lightcurvelynx/test_simulate.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,9 @@ def test_simulate_lightcurves(test_data_dir):
246246
assert len(results.loc[idx]["lightcurve"]["obs_idx"]) == num_obs
247247
assert len(np.unique(results.loc[idx]["lightcurve"]["obs_idx"])) == num_obs
248248

249+
# We do not include the full filter name unless that option is selected.
250+
assert "full_filter_name" not in results.loc[idx]["lightcurve"]
251+
249252
# Check that we extract one of the parameters.
250253
assert results["source_brightness"][idx] == given_brightness[idx]
251254

@@ -414,6 +417,7 @@ def test_simulate_parallel_threads(test_data_dir):
414417
param_cols=["source.brightness"],
415418
executor=executor,
416419
batch_size=10,
420+
save_full_filter_names=True,
417421
)
418422
assert len(results) == 100
419423
assert np.all(results["nobs"].values >= 1)
@@ -427,6 +431,14 @@ def test_simulate_parallel_threads(test_data_dir):
427431
assert num_obs >= 1
428432
assert len(results["lightcurve"][idx]["flux"]) == num_obs
429433

434+
# Check that we got the full filter names correct.
435+
for filter_name, full_name in zip(
436+
results["lightcurve"][idx]["filter"],
437+
results["lightcurve"][idx]["full_filter_name"],
438+
strict=False,
439+
):
440+
assert full_name == f"LSST_{filter_name}"
441+
430442

431443
def test_simulate_parallel_processes(test_data_dir):
432444
"""Test an end to end run of simulating the light curves paralle with processes."""
@@ -817,6 +829,7 @@ def compute_sed(self, times, wavelengths, graph_state, **kwargs):
817829
[obstable1, obstable2],
818830
[passband_group1, passband_group2],
819831
obstable_save_cols=["zp", "custom_col"],
832+
save_full_filter_names=True,
820833
)
821834
assert len(results) == 1
822835
assert results["nobs"][0] == 6
@@ -825,6 +838,11 @@ def compute_sed(self, times, wavelengths, graph_state, **kwargs):
825838
lightcurve = results["lightcurve"][0]
826839
assert np.all(lightcurve["flux_perfect"][0:3] > 800.0)
827840
assert np.all(lightcurve["flux_perfect"][3:6] < 500.0)
841+
assert np.array_equal(lightcurve["filter"], ["r"] * 6)
842+
assert np.array_equal(
843+
lightcurve["full_filter_name"],
844+
["survey1_r", "survey1_r", "survey1_r", "survey2_r", "survey2_r", "survey2_r"],
845+
)
828846

829847

830848
def test_compute_noise_free_lightcurves_single(test_data_dir):

0 commit comments

Comments
 (0)