
Commit 2a50e33

Author: martinkilbinger committed
added optional weighted R matrix
1 parent 9a438eb commit 2a50e33

1 file changed

notebooks/calibrate_comprehensive_cat.py

Lines changed: 87 additions & 64 deletions
@@ -1,22 +1,18 @@
-# ---
-# jupyter:
-#   jupytext:
-#     text_representation:
-#       extension: .py
-#       format_name: light
-#       format_version: '1.5'
-#     jupytext_version: 1.15.1
-#   kernelspec:
-#     display_name: Python 3
-#     language: python
-#     name: python3
-# ---
-
-# # Calibrate comprehensive catalogue
-
-# %reload_ext autoreload
-# %autoreload 2
+# %%
+# Calibrate comprehensive catalogue
 
+# %%
+from IPython import get_ipython
+
+# %%
+# enable autoreload for interactive sessions
+ipython = get_ipython()
+if ipython is not None:
+    ipython.run_line_magic("load_ext", "autoreload")
+    ipython.run_line_magic("autoreload", "2")
+    ipython.run_line_magic("load_ext", "log_cell_time")
+
+# %%
 import sys
 import os
 import numpy as np
@@ -29,40 +25,47 @@
 from sp_validation import calibration
 import sp_validation.cat as cat
 
+# %%
 # Initialize calibration class instance
 obj = sp_joint.CalibrateCat()
 
+# %%
 # Read configuration file and set parameters
 config = obj.read_config_set_params("config_mask.yaml")
 
-# !pwd
-
-# +
+# %%
 # Get data. Set load_into_memory to False for very large files
-
-
 dat, dat_ext = obj.read_cat(load_into_memory=False)
-# -
 
-n_test = -1
-#n_test = 100000
+# %%
+#n_test = -1
+n_test = 100000
 if n_test > 0:
     print(f"MKDEBUG testing only first {n_test} objects")
     dat = dat[:n_test]
     dat_ext = dat_ext[:n_test]
 
+
+
+# %%
 # ## Masking
 
+# %%
 # ### Pre-processing ShapePipe flags
-
-masks, labels = sp_joint.get_masks_from_config(config, dat, dat_ext)
+masks, labels = sp_joint.get_masks_from_config(
+    config,
+    dat,
+    dat_ext,
+    verbose=True
+)
 
 mask_combined = sp_joint.Mask.from_list(
     masks,
     label="combined",
     verbose=obj._params["verbose"],
 )
 
+# %%
 # Output some mask statistics
 sp_joint.print_mask_stats(dat.shape[0], masks, mask_combined)
 
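Note: `sp_joint.Mask.from_list` presumably combines the individual selection masks into a single boolean array. A minimal NumPy sketch of that idea (an AND-combination with simple statistics), assuming each mask is a boolean "keep" array over the catalogue rows; this is illustrative only, not the sp_joint implementation:

import numpy as np

# placeholder boolean masks over 1000 catalogue rows (illustrative data)
rng = np.random.default_rng(0)
masks_bool = [rng.random(1000) > thresh for thresh in (0.05, 0.10, 0.20)]

# an object survives the combined mask only if every individual mask keeps it
mask_combined_bool = np.logical_and.reduce(masks_bool)

# per-mask and combined statistics, similar in spirit to print_mask_stats
for idx, m in enumerate(masks_bool):
    print(f"mask {idx}: {m.sum()}/{m.size} kept ({m.mean():.2%})")
print(f"combined: {mask_combined_bool.sum()}/{mask_combined_bool.size} kept")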
@@ -74,13 +77,13 @@
 
 sp_joint.sky_plots(dat, masks, labels, zoom_ra, zoom_dec)
 
+# %%
 # ### Calibration
 
-# +
 # Call metacal
-
 cm = config["metacal"]
 
+# %%
 gal_metacal = metacal(
     dat,
     mask_combined._mask,
@@ -90,19 +93,23 @@
     rel_size_max=cm["gal_rel_size_max"],
     size_corr_ell=cm["gal_size_corr_ell"],
     sigma_eps=cm["sigma_eps_prior"],
+    global_R_weight=cm["global_R_weight"],
     col_2d=False,
     verbose=True,
 )
-# -
 
-g_corr_mc, g_uncorr, w, mask_metacal, c, c_err = calibration.get_calibrated_m_c(gal_metacal)
+# %%
+g_corr_mc, g_uncorr, w, mask_metacal, c, c_err = (
+    calibration.get_calibrated_m_c(gal_metacal)
+)
 
 num_ok = len(g_corr_mc[0])
+num_obj = len(dat)
 sp_joint.Mask.print_strings(
     "metacal", "gal selection", str(num_ok), f"{num_ok / num_obj:10.2%}"
 )
 
-# +
+# %%
 # Compute DES weights
 
 cat_gal = {}
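Note on the commit's main change: `global_R_weight` presumably switches the global metacalibration response from an unweighted mean of the per-object response matrices to a weight-averaged one. A minimal sketch of the two estimates, assuming a per-object response array of shape (2, 2, n_gal) as in the removed `R_shear = np.mean(gal_metacal.R_shear, 2)` line further down, and a per-galaxy weight array `w`; placeholder data, not the metacal internals:

import numpy as np

rng = np.random.default_rng(42)
R_per_obj = rng.normal(0.7, 0.05, size=(2, 2, 1000))  # assumed shape (2, 2, n_gal)
w = rng.uniform(0.5, 1.5, size=1000)                   # per-galaxy weights (assumed)

# unweighted global response matrix (mean over the galaxy axis)
R_unweighted = np.mean(R_per_obj, axis=2)

# weighted global response matrix: each galaxy contributes proportionally to w
R_weighted = np.sum(R_per_obj * w, axis=2) / np.sum(w)

print(R_unweighted)
print(R_weighted)

The calibrated shear is obtained by applying the inverse of this 2x2 matrix to the (weighted) mean measured ellipticity, which is why the choice of weighting enters the multiplicative calibration.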
@@ -117,7 +124,7 @@
     num_bins=20,
 )
 
-# +
+# %%
 # Correct for PSF leakage
 
 alpha_1, alpha_2 = sp_joint.compute_PSF_leakage(
@@ -128,27 +135,26 @@
     mask_metacal,
     num_bins=20,
 )
-# -
 
 # Compute leakage-corrected ellipticities
 e1_leak_corrected = g_corr_mc[0] - alpha_1 * cat_gal["e1_PSF"]
 e2_leak_corrected = g_corr_mc[1] - alpha_2 * cat_gal["e2_PSF"]
 
+# %%
 # Get some memory back
 for mask in masks:
     del mask
 
-# +
+# %%
+# %%
 # Additional quantities
-R_shear = np.mean(gal_metacal.R_shear, 2)
-
 ra = cat.get_col(dat, "RA", mask_combined._mask, mask_metacal)
 dec = cat.get_col(dat, "Dec", mask_combined._mask, mask_metacal)
 mag = cat.get_col(dat, "mag", mask_combined._mask, mask_metacal)
 
 
-# +
-
+# %%
+# Additional data columns to write to output cat
 add_cols = [
     "w_iv",
     "FLUX_RADIUS",
@@ -161,36 +167,57 @@
     "FLUXERR_AUTO",
     "FLUX_APER",
     "FLUXERR_APER",
+    "NGMIX_T_NOSHEAR",
+    "NGMIX_Tpsf_NOSHEAR",
+    "fwhm_PSF",
 ]
 add_cols_data = {}
 for key in add_cols:
-    add_cols_data[key] = dat[key][mask_combined._mask][mask_metacal]
-
-# +
+    add_cols_data[key] = cat.get_col(
+        dat,
+        key,
+        mask_combined._mask,
+        mask_metacal
+    )
+    #add_cols_data[key] = dat[key][mask_combined._mask][mask_metacal]
+
+# %%
+# Additional post-processing columns to write to output cat
+add_cols_post = [
+    "R_g11",
+    "R_g12",
+    "R_g21",
+    "R_g22",
+    "e1_PSF",
+    "e2_PSF",
+]
+for key in add_cols_post:
+    add_cols_data[key] = cat_gal[key]
 
+# %%
+# Other additional columns
 add_cols_data["e1_leak_corrected"] = e1_leak_corrected
 add_cols_data["e2_leak_corrected"] = e2_leak_corrected
 
-add_cols_data["e1_PSF"] = cat_gal["e1_PSF"]
-add_cols_data["e2_PSF"] = cat_gal["e2_PSF"]
-add_cols_data["fwhm_PSF"] = cat.get_col(
-    dat, "fwhm_PSF", mask_combined._mask, mask_metacal
-)
-
-# +
+# %%
 # Add information to FITS header
 
 # Generate new header
 header = fits.Header()
 
-# Add general config information to FITS header
-obj.add_params_to_FITS_header(header)
+# Add general and metacal config information to FITS header
+obj.add_params_to_FITS_header(header, cm=cm)
 
+# %%
 # Add mask information to FITS header
 for my_mask in masks:
     my_mask.add_summary_to_FITS_header(header)
 
-# +
+
+# %%
+header
+
+# %%
 output_shape_cat_path = obj._params["input_path"].replace(
     "comprehensive", "cut"
 )
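Note: `add_params_to_FITS_header` now also receives the metacal config `cm`, presumably so the calibration settings end up as header keywords of the output catalogue. With astropy, writing keyword/value/comment cards looks like this (the keyword names below are illustrative, not the actual output of the method):

from astropy.io import fits

hdr = fits.Header()
# each card is a (value, comment) pair; short keys are used here for illustration
hdr["SIGEPS"] = (0.3, "metacal: sigma_eps_prior")
hdr["RWEIGHT"] = (True, "metacal: global_R_weight")
print(repr(hdr))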
@@ -215,27 +242,23 @@
     add_cols=add_cols_data,
     add_header=header,
 )
-# -
 
+# %%
 with open("masks.txt", "w") as f_out:
     for my_mask in masks:
         my_mask.print_summary(f_out)
 
+# %%
 from scipy import stats
 
-#
-
 all_masks = masks[:-3]
 
-# +
 if not obj._params["cmatrices"]:
     print("Skipping cmatric calculations")
     sys.exit(0)
 
 r_val, r_cl = sp_joint.correlation_matrix(all_masks)
 
-# +
-
 n = len(all_masks)
 keys = [my_mask._label for my_mask in all_masks]
 
@@ -246,9 +269,7 @@
 plt.colorbar()
 plt.savefig("correlation_matrix.png")
 
-# -
-
-
+# %%
 n_key = len(all_masks)
 cms = np.zeros((n_key, n_key, 2, 2))
 for idx in range(n_key):
@@ -261,7 +282,7 @@
         res = sp_joint.confusion_matrix(masks[idx]._mask, masks[jdx]._mask)
         cms[idx][jdx] = res["cmn"]
 
-# +
+# %%
 import seaborn as sns
 
 fig = plt.figure(figsize=(30, 30))
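Note: `correlation_matrix` and `confusion_matrix` quantify how strongly pairs of masks overlap. A plain NumPy sketch of both statistics for two boolean masks (illustrative only; the sp_joint versions, the exact normalisation behind `res["cmn"]`, and the second return value `r_cl` are not reproduced here):

import numpy as np

rng = np.random.default_rng(1)
mask_a = rng.random(5000) > 0.2   # placeholder boolean masks
mask_b = rng.random(5000) > 0.4

# Pearson correlation between masks treated as 0/1 vectors (cf. correlation_matrix)
r = np.corrcoef(mask_a.astype(float), mask_b.astype(float))[0, 1]
print(f"correlation: {r:.3f}")

# normalised 2x2 confusion matrix: fraction of objects in each (kept, kept) combination
cm2 = np.empty((2, 2))
for a in (0, 1):
    for b in (0, 1):
        cm2[a, b] = np.mean((mask_a == bool(a)) & (mask_b == bool(b)))
print(cm2)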
@@ -294,6 +315,8 @@
 
 plt.show(block=False)
 plt.savefig("confusion_matrix.png")
-# -
+
 
 obj.close_hd5()
+
+# %%
