Skip to content

Commit 1eb9194

Browse files
authored
Merge pull request #87 from SWIFTSIM/box-size-correction
Box size correction
2 parents 9a48eb5 + 647c155 commit 1eb9194

File tree

4 files changed

+227
-2
lines changed

4 files changed

+227
-2
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#!/usr/bin/env python
2+
3+
"""
4+
Compute a box size correction file that can be used as the 'box_size_correction'
5+
argument for an autoplotter plot.
6+
7+
Usage:
8+
velociraptor-compute-box-size-correction \
9+
smallbox largebox plotname plottype output
10+
11+
with:
12+
- smallbox/largebox: data*.yml output file from a pipeline run
13+
- plotname: Name of a particular plot in the data*.yml files
14+
- plottype: Type of plot (currently supported: mass_function)
15+
- output: Name of an output .yml file. If the .yml extension is missing, it is
16+
added.
17+
"""
18+
19+
import argparse
20+
import os
21+
import yaml
22+
import numpy as np
23+
import scipy.interpolate as interpol
24+
25+
# Plot types this script can handle, and the subsets whose x / y values are
# interpolated in log10 space rather than linear space.
SUPPORTED_PLOT_TYPES = ("mass_function",)
LOG_X_PLOT_TYPES = ("mass_function",)
LOG_Y_PLOT_TYPES = ("mass_function",)


def parse_arguments(argv=None):
    """
    Parse the command line arguments of the script.

    Parameters
    ----------
    argv
        Argument list to parse; ``None`` (the default) lets argparse read
        ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``smallbox``, ``largebox``,
        ``plotname``, ``plottype`` and ``output``.
    """
    argparser = argparse.ArgumentParser("Compute the box size correction for a plot.")
    argparser.add_argument(
        "smallbox", help="Pipeline output for the small box that needs to be corrected."
    )
    argparser.add_argument(
        "largebox", help="Pipeline output for the large box that we want to correct to."
    )
    argparser.add_argument("plotname", help="Name of the plot that we want to correct.")
    argparser.add_argument("plottype", help="Type of the plot we want to correct.")
    argparser.add_argument(
        "output", help="Name of the output file that will store the correction."
    )
    return argparser.parse_args(argv)


def load_plot_data(filename, plotname, plottype):
    """
    Read the line data of one plot from a pipeline data*.yml file.

    Parameters
    ----------
    filename
        Path of the pipeline output file.
    plotname
        Name of the plot inside that file.
    plottype
        Name of the line type (e.g. ``mass_function``) inside the plot's
        ``lines`` entry.

    Returns
    -------
    The entry stored under ``[plotname]["lines"][plottype]``.

    Raises
    ------
    AttributeError
        If the plot or the line type is not present in the file.
    """
    with open(filename, "r") as handle:
        pipeline_output = yaml.safe_load(handle)

    # TypeError covers an empty file (safe_load returns None) or an entry that
    # is not a mapping.
    try:
        lines = pipeline_output[plotname]["lines"]
    except (KeyError, TypeError) as err:
        raise AttributeError(f"Could not find {plotname} in {filename}!") from err
    try:
        return lines[plottype]
    except (KeyError, TypeError) as err:
        raise AttributeError(
            f"{plottype} not found in plot {plotname} in {filename}!"
        ) from err


def _prepare_curve(x, y, log_x, log_y):
    """Convert one curve to float arrays in interpolation space, dropping unusable points."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Empty or non-positive bins have no logarithm: silence the warnings here
    # and filter the resulting -inf/NaN entries out below, since the spline
    # fit rejects non-finite input.
    with np.errstate(divide="ignore", invalid="ignore"):
        if log_x:
            x = np.log10(x)
        if log_y:
            y = np.log10(y)
    valid = np.isfinite(x) & np.isfinite(y)
    return x[valid], y[valid]


def compute_correction(
    small_x, small_y, large_x, large_y, log_x=True, log_y=True, num_points=100
):
    """
    Compute the ratio large box / small box on a common, regular x grid.

    Both curves are interpolated with a spline (in log10 space for the axes
    that are flagged as logarithmic) and evaluated on ``num_points`` equally
    spaced points spanning the x range that the two curves have in common.

    Parameters
    ----------
    small_x, small_y
        Bin centers and values of the small box curve (any array-like).
    large_x, large_y
        Bin centers and values of the large box curve (any array-like).
    log_x, log_y
        Interpolate the corresponding axis in log10 space.
    num_points
        Number of points of the output grid.

    Returns
    -------
    (xmin, xmax, x_range, correction)
        Limits of the common x range (floats, in interpolation space, i.e.
        log10 values if ``log_x`` is set), the grid on which the correction
        was evaluated and the correction factors on that grid.

    Raises
    ------
    ValueError
        If a curve has too few usable points for a spline fit or if the two
        curves do not overlap in x.
    """
    small_x, small_y = _prepare_curve(small_x, small_y, log_x, log_y)
    large_x, large_y = _prepare_curve(large_x, large_y, log_x, log_y)

    # InterpolatedUnivariateSpline fits a cubic spline by default, which needs
    # at least 4 data points.
    if min(small_x.size, large_x.size) < 4:
        raise ValueError(
            "Need at least 4 valid data points per box to compute a correction!"
        )

    small_spline = interpol.InterpolatedUnivariateSpline(small_x, small_y)
    large_spline = interpol.InterpolatedUnivariateSpline(large_x, large_y)

    # Only the range covered by both curves can be corrected without
    # extrapolating either spline.
    xmin = max(small_x.min(), large_x.min())
    xmax = min(small_x.max(), large_x.max())
    if not xmin < xmax:
        raise ValueError("The small box and large box curves do not overlap in x!")

    x_range = np.linspace(xmin, xmax, num_points)
    small_y_range = small_spline(x_range)
    large_y_range = large_spline(x_range)

    if log_y:
        small_y_range = 10.0**small_y_range
        large_y_range = 10.0**large_y_range

    correction = large_y_range / small_y_range
    return float(xmin), float(xmax), x_range, correction


def main(argv=None):
    """
    Entry point: read both pipeline outputs, compute the box size correction
    for the requested plot and write it to the output .yml file.

    Raises
    ------
    AttributeError
        For an unsupported plot type, a missing input file, an output file
        that cannot be written, or a plot that is missing from an input file.
    """
    args = parse_arguments(argv)

    if args.plottype not in SUPPORTED_PLOT_TYPES:
        raise AttributeError(
            f"Cannot compute box size correction for plot type {args.plottype}!"
        )

    log_x = args.plottype in LOG_X_PLOT_TYPES
    log_y = args.plottype in LOG_Y_PLOT_TYPES

    for file in [args.smallbox, args.largebox]:
        if not os.path.exists(file):
            raise AttributeError(f"File {file} could not be found!")

    output_file = args.output
    if not output_file.endswith(".yml"):
        output_file += ".yml"
    try:
        # Fail early if the output cannot be written. Append mode is used so
        # that an existing file is not emptied if a later step fails.
        with open(output_file, "a"):
            pass
    except OSError as err:
        raise AttributeError(f"Can not write to {output_file}!") from err

    small_box_plot_data = load_plot_data(args.smallbox, args.plotname, args.plottype)
    large_box_plot_data = load_plot_data(args.largebox, args.plotname, args.plottype)

    xmin, xmax, x_range, correction = compute_correction(
        small_box_plot_data["centers"],
        small_box_plot_data["values"],
        large_box_plot_data["centers"],
        large_box_plot_data["values"],
        log_x=log_x,
        log_y=log_y,
    )

    correction_data = {}
    correction_data["plot_name"] = args.plotname
    correction_data["plot_type"] = args.plottype
    # Record the flag that was actually used, so that the x values below are
    # interpreted consistently by whoever reads the file.
    correction_data["is_log_x"] = log_x
    correction_data["x_units"] = small_box_plot_data["centers_units"]
    correction_data["x_limits"] = [xmin, xmax]
    correction_data["x"] = x_range.tolist()
    correction_data["y"] = correction.tolist()
    with open(output_file, "w") as handle:
        yaml.safe_dump(correction_data, handle)


if __name__ == "__main__":
    main()
velociraptor/autoplotter/box_size_correction.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""
2+
Functionality to apply a mass dependent correction to quantities that have been
3+
binned in mass bins (e.g. a mass function).
4+
"""
5+
6+
import numpy as np
7+
import yaml
8+
import os
9+
import scipy.interpolate as interpol
10+
import unyt
11+
from typing import Tuple
12+
13+
14+
class VelociraptorBoxSizeCorrection:
    """
    Multiplicative correction, tabulated as a function of x, that can be
    applied to a binned quantity such as a mass function.

    The correction is read from a .yml file and interpolated with a spline.
    It is only applied inside the x range given by the file; outside that
    range values are left unchanged.
    """

    def __init__(self, filename: str, correction_directory: str):
        """
        Load the correction from a file.

        Parameters
        ----------
        filename
            Name of the correction .yml file. The file has to provide the
            keys ``is_log_x``, ``x_limits`` (a pair: lower and upper limit),
            ``x`` and ``y``.
        correction_directory
            Directory that contains the correction file.

        Raises
        ------
        FileNotFoundError
            If the correction file does not exist.
        KeyError
            If one of the required keys is missing from the file.
        """
        # os.path.join rather than string concatenation, so that an empty
        # directory does not turn the file name into a path relative to "/".
        correction_file = os.path.join(correction_directory, filename)
        if not os.path.exists(correction_file):
            raise FileNotFoundError(f"Could not find {correction_file}!")
        with open(correction_file, "r") as handle:
            correction_data = yaml.safe_load(handle)
        # If set, x, x_min and x_max are log10 values.
        self.is_log_x = correction_data["is_log_x"]
        self.x_min, self.x_max = correction_data["x_limits"]
        x = np.array(correction_data["x"])
        y = np.array(correction_data["y"])
        self.correction_spline = interpol.InterpolatedUnivariateSpline(x, y)

    def apply_mass_function_correction(
        self,
        mass_function_output: Tuple[
            "unyt.unyt_array", "unyt.unyt_array", "unyt.unyt_array"
        ],
    ) -> Tuple["unyt.unyt_array", "unyt.unyt_array", "unyt.unyt_array"]:
        """
        Apply the correction to the output of a mass function calculation.

        Parameters
        ----------
        mass_function_output
            Tuple ``(bin_centers, mass_function, error)``.

        Returns
        -------
        Tuple ``(bin_centers, corrected_mass_function, error)``. The bin
        centers and the error are passed through unchanged; the mass function
        is multiplied by the interpolated correction for all bins whose
        center lies within ``[x_min, x_max]`` and keeps its ``name``.
        """

        bin_centers, mass_function, error = mass_function_output

        # Work on the bare numbers. The limits and the spline are plain
        # floats, and comparing them with an array that carries units would
        # fail for linear (non-log) corrections.
        # NOTE(review): the numbers are taken in whatever units bin_centers
        # currently has; no conversion to the units of the correction file is
        # done here.
        x_vals = np.asarray(bin_centers)
        if self.is_log_x:
            # Non-positive centers give -inf/NaN, which fall outside the mask
            # below; there is no need to warn about them.
            with np.errstate(divide="ignore", invalid="ignore"):
                x_vals = np.log10(x_vals)

        correction = np.ones(x_vals.shape)
        # only apply the correction to bins that are within the range for which
        # the correction is valid
        x_mask = (self.x_min <= x_vals) & (x_vals <= self.x_max)
        correction[x_mask] = self.correction_spline(x_vals[x_mask])

        corrected_mass_function = mass_function * correction
        # the multiplication creates a new array, so carry the name over
        corrected_mass_function.name = mass_function.name

        return bin_centers, corrected_mass_function, error

velociraptor/autoplotter/lines.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from velociraptor.tools.histogram import create_histogram_given_bins
2020
from velociraptor.tools.adaptive import create_adaptive_bins
21+
from velociraptor.autoplotter.box_size_correction import VelociraptorBoxSizeCorrection
2122

2223
valid_line_types = [
2324
"median",
@@ -61,6 +62,8 @@ class VelociraptorLine(object):
6162
bins: unyt_array = None
6263
# Scatter can be: "none", "errorbar", or "shaded"
6364
scatter: str
65+
# Box size correction?
66+
box_size_correction: Union[None, VelociraptorBoxSizeCorrection]
6467
# Output: centers, values, scatter, additional_x, additional_y - initialised here
6568
# to prevent crashes in other code.
6669
output: Tuple[unyt_array] = (
@@ -71,7 +74,12 @@ class VelociraptorLine(object):
7174
unyt_array([]),
7275
)
7376

74-
def __init__(self, line_type: str, line_data: Dict[str, Union[Dict, str]]):
77+
def __init__(
78+
self,
79+
line_type: str,
80+
line_data: Dict[str, Union[Dict, str]],
81+
box_size_correction: Union[None, VelociraptorBoxSizeCorrection] = None,
82+
):
7583
"""
7684
Initialise a line with data from the yaml file.
7785
"""
@@ -82,6 +90,8 @@ def __init__(self, line_type: str, line_data: Dict[str, Union[Dict, str]]):
8290
self.data = line_data
8391
self._parse_data()
8492

93+
self.box_size_correction = box_size_correction
94+
8595
return
8696

8797
def _parse_line_type(self):
@@ -265,6 +275,10 @@ def create_line(
265275
mass_function_output = create_mass_function_given_bins(
266276
masked_x, self.bins, box_volume=box_volume
267277
)
278+
if self.box_size_correction is not None:
279+
mass_function_output = self.box_size_correction.apply_mass_function_correction(
280+
mass_function_output
281+
)
268282
self.output = (
269283
*mass_function_output,
270284
unyt_array([], units=mass_function_output[0].units),
@@ -309,6 +323,10 @@ def create_line(
309323
box_volume=box_volume,
310324
return_bin_edges=True,
311325
)
326+
if self.box_size_correction is not None:
327+
mass_function_output = self.box_size_correction.apply_mass_function_correction(
328+
mass_function_output
329+
)
312330
self.output = (
313331
*mass_function_output,
314332
unyt_array([], units=mass_function_output[0].units),

velociraptor/autoplotter/objects.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from velociraptor import VelociraptorCatalogue
66
from velociraptor.autoplotter.lines import VelociraptorLine, valid_line_types
7+
from velociraptor.autoplotter.box_size_correction import VelociraptorBoxSizeCorrection
78
from velociraptor.exceptions import AutoPlotterError
89
from velociraptor.observations import load_observations
910

@@ -87,6 +88,9 @@ class VelociraptorPlot(object):
8788
exclude_structure_type: Union[None, int]
8889
structure_mask: Union[None, array]
8990
selection_mask: Union[None, array]
91+
# Apply a box size correction to the plot?
92+
correction_directory: str
93+
box_size_correction: Union[None, VelociraptorBoxSizeCorrection]
9094
# Where should the legend and z, a information be placed?
9195
legend_loc: str
9296
redshift_loc: str
@@ -103,13 +107,15 @@ def __init__(
103107
filename: str,
104108
data: Dict[str, Union[Dict, str]],
105109
observational_data_directory: str,
110+
correction_directory: str,
106111
):
107112
"""
108113
Initialise the plot object variables.
109114
"""
110115
self.filename = filename
111116
self.data = data
112117
self.observational_data_directory = observational_data_directory
118+
self.correction_directory = correction_directory
113119

114120
self._parse_data()
115121

@@ -523,6 +529,13 @@ def _parse_massfunction(self) -> None:
523529

524530
self._parse_common_histogramtype()
525531

532+
try:
533+
box_size_correction = str(self.data["box_size_correction"])
534+
self.box_size_correction = VelociraptorBoxSizeCorrection(
535+
box_size_correction, self.correction_directory
536+
)
537+
except KeyError:
538+
self.box_size_correction = None
526539
# A bit of a hacky workaround - improve this in the future
527540
# by combining this functionality properly into the
528541
# VelociraptorLine methods.
@@ -535,6 +548,7 @@ def _parse_massfunction(self) -> None:
535548
start=dict(value=self.x_lim[0].value, units=self.x_lim[0].units),
536549
end=dict(value=self.x_lim[1].value, units=self.x_lim[1].units),
537550
),
551+
box_size_correction=self.box_size_correction,
538552
)
539553

540554
return
@@ -550,6 +564,13 @@ def _parse_adaptivemassfunction(self) -> None:
550564

551565
self._parse_common_histogramtype()
552566

567+
try:
568+
box_size_correction = str(self.data["box_size_correction"])
569+
self.box_size_correction = VelociraptorBoxSizeCorrection(
570+
box_size_correction, self.correction_directory
571+
)
572+
except KeyError:
573+
self.box_size_correction = None
553574
# A bit of a hacky workaround - improve this in the future
554575
# by combining this functionality properly into the
555576
# VelociraptorLine methods.
@@ -563,6 +584,7 @@ def _parse_adaptivemassfunction(self) -> None:
563584
end=dict(value=self.x_lim[1].value, units=self.x_lim[1].units),
564585
adaptive=True,
565586
),
587+
box_size_correction=self.box_size_correction,
566588
)
567589

568590
return
@@ -1042,6 +1064,8 @@ class AutoPlotter(object):
10421064
plots: List[VelociraptorPlot]
10431065
# Directory containing the observational data.
10441066
observational_data_directory: str
1067+
# Directory containing box size correction data
1068+
correction_directory: str
10451069
# Whether or not the plots were created successfully.
10461070
created_successfully: List[bool]
10471071
# global mask
@@ -1051,6 +1075,7 @@ def __init__(
10511075
self,
10521076
filename: Union[str, List[str]],
10531077
observational_data_directory: Union[None, str] = None,
1078+
correction_directory: Union[None, str] = None,
10541079
) -> None:
10551080
"""
10561081
Initialises the AutoPlotter object with the yaml filename(s).
@@ -1075,6 +1100,9 @@ def __init__(
10751100
if observational_data_directory is not None
10761101
else ""
10771102
)
1103+
self.correction_directory = Path(
1104+
correction_directory if correction_directory is not None else ""
1105+
)
10781106

10791107
self.load_yaml()
10801108
self.parse_yaml()
@@ -1105,7 +1133,12 @@ def parse_yaml(self):
11051133
"""
11061134

11071135
self.plots = [
1108-
VelociraptorPlot(filename, plot, self.observational_data_directory)
1136+
VelociraptorPlot(
1137+
filename,
1138+
plot,
1139+
self.observational_data_directory,
1140+
self.correction_directory,
1141+
)
11091142
for filename, plot in self.yaml.items()
11101143
]
11111144

0 commit comments

Comments
 (0)