Skip to content

Commit ae8eabd

Browse files
committed
Merge branch 'develop'
2 parents c93547b + bb0d9e0 commit ae8eabd

File tree

10 files changed

+150
-185
lines changed

10 files changed

+150
-185
lines changed

config.yaml

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ probe_selection:
6969
nz:
7070
# Paths to source (for WL) and lens (for GCph) redshift distributions.
7171
# These must have shape (z_points, zbins + 1) where z_points is the number of
72-
# redshift values over which the n(z) is measured, and format .txt or .dat
72+
# redshift values over which the n(z) is measured, and format .txt, .dat or .ascii
7373
# OR
7474
# as .fits files, following the LE3/Heracles/euclidlib format
7575
nz_sources_filename: ../common_data/Spaceborne_jobs/develop/input/nzTab-EP03-zedMin02-zedMax25-mag245.dat # Type: str
@@ -101,6 +101,12 @@ binning:
101101
ell_max: 3000 # Type: int | float. Maximum ell
102102
ell_bins: 32 # Type: int. Number of ell bins. Not used in the 'unbinned' case
103103

104+
# Path to user-provided ell binning scheme. Used only if 'binning_type' is set to 'from_input'.
105+
# The file should have the following columns:
106+
# ell, delta_ell, ell_lower_edges, ell_upper_edges
107+
# with `ell` the ell bin center, `delta_ell` its width,
108+
# `ell_lower_edges` its lower edge, and `ell_upper_edges` its upper edge.
109+
# The file can be in .txt, .dat or .ascii format.
104110
ell_bins_filename: ../common_data/Spaceborne_jobs/develop/input/ell_values_3x2pt.txt
105111

106112
# theta binning for the real-space covariance matrix
@@ -337,20 +343,27 @@ precision:
337343
# see https://namaster.readthedocs.io/en/latest/api/pymaster.field.html#pymaster.field.NmtField
338344
n_iter_nmt: null # Type: None | int.
339345

346+
347+
# ======================================= JAX ========================================
348+
# Whether to enable 64-bit precision for JAX computations. Setting this to False can result
349+
# in noticeable differences with respect to the plain numpy implementation, since the
350+
# latter defaults to 64-bit precision.
351+
jax_enable_x64: True # Type: bool.
352+
353+
340354
misc:
341355
# note if you work from an interactive environment (e.g. jupyter notebook), you might
342356
# have to restart the kernel after changing the num_threads, and jax_platform
343357
num_threads: 50 # Type: int. How many threads to use for the parallel computations
344358

345-
# jax-specific settings
346-
jax_platform: auto # Type: str. 'auto', 'cpu' or 'gpu'
347-
jax_enable_x64: True # Type: bool. Whether to enable 64-bit precision for JAX computations
359+
# Device used for JAX-accelerated computation. Can be one of 'auto', 'cpu' or 'gpu'
360+
jax_platform: auto # Type: str.
348361

349362
# some sanity checks on the covariance matrix. These could take a while to run.
350-
test_numpy_inversion: True # Type: bool. Test for errors in np.linalg.inv
351-
test_condition_number: True # Type: bool. Check if condition number is above 1e10
352-
test_cholesky_decomposition: True # Type: bool. Test if cholesky decomposition fails
353-
test_symmetry: True # Type: bool. Test if covariance matrix is symmetric (cov = cov.T)
363+
test_numpy_inversion: False # Type: bool. Test for errors in np.linalg.inv
364+
test_condition_number: False # Type: bool. Check if condition number is above 1e10
365+
test_cholesky_decomposition: False # Type: bool. Test if cholesky decomposition fails
366+
test_symmetry: False # Type: bool. Test if covariance matrix is symmetric (cov = cov.T)
354367

355368
# Whether to produce a triangle plot of the Cls, especially useful when using input
356369
# Cls for a quick visual comparison. Can be slow, set to False for a faster runtime.

main.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# ruff: noqa: E402 (ignore module import not on top of the file warnings)
22
import argparse
33
import contextlib
4-
import itertools
54
import os
65
import sys
6+
import warnings
77

88
import yaml
99

@@ -39,9 +39,26 @@ def load_config(_config_path):
3939

4040

4141
cfg = load_config('config.yaml')
42-
# Set jax platform
42+
43+
# JAX settings
4344
if cfg['misc']['jax_platform'] != 'auto':
4445
os.environ['JAX_PLATFORMS'] = cfg['misc']['jax_platform']
46+
if cfg['precision']['jax_enable_x64']:
47+
os.environ['JAX_ENABLE_X64'] = '1'
48+
else:
49+
os.environ['JAX_ENABLE_X64'] = '0'
50+
warnings.warn(
51+
'JAX 64-bit precision is disabled. This may lead to '
52+
'noticeable differences with respect to the numpy '
53+
'implementation, which uses 64-bit precision by default.',
54+
stacklevel=2,
55+
)
56+
57+
# Import JAX after environment variables are set, then print device info
58+
import jax
59+
60+
print(f'JAX devices: {jax.devices()}')
61+
print(f'JAX backend: {jax.default_backend()}')
4562

4663
# if using the CPU, set the number of threads
4764
num_threads = cfg['misc']['num_threads']
@@ -65,9 +82,9 @@ def load_config(_config_path):
6582
# override in cfg as well
6683
cfg['misc']['num_threads'] = num_threads
6784

85+
import itertools
6886
import pprint
6987
import time
70-
import warnings
7188
from functools import partial
7289

7390
import matplotlib
@@ -2370,7 +2387,12 @@ def plot_cls():
23702387
else 'Gauss'
23712388
)
23722389
cov = covs_3x2pt_2d_tosave_dict[key]
2373-
print(f'Testing cov {key}...')
2390+
2391+
print(
2392+
f'Performing sanity checks on cov {key}.\n'
2393+
'This can take some time for large matrices. '
2394+
'Please note that your files have already been saved.\n'
2395+
)
23742396

23752397
if cfg['misc']['test_condition_number']:
23762398
cond_number = np.linalg.cond(cov)
@@ -2418,16 +2440,16 @@ def plot_cls():
24182440

24192441
# note that this is *not* compatible with %matplotlib inline in the interactive window!
24202442
if cfg['misc']['save_figs']:
2421-
output_dir = f'{output_path}/figs'
2422-
os.makedirs(output_dir, exist_ok=True)
2443+
output_path_figs = f'{output_path}/figs'
2444+
os.makedirs(output_path_figs, exist_ok=True)
24232445
for i, fig_num in enumerate(plt.get_fignums()):
24242446
fig = plt.figure(fig_num)
24252447
fig.savefig(
2426-
os.path.join(output_dir, f'fig_{i:03d}.pdf'),
2448+
os.path.join(output_path_figs, f'fig_{i:03d}.pdf'),
24272449
bbox_inches='tight',
24282450
pad_inches=0.1,
24292451
)
2430-
print(f'Figures saved in {output_dir}\n')
2452+
print(f'Figures saved in {output_path_figs}\n')
24312453

24322454

24332455
print(f'Finished in {(time.perf_counter() - script_start_time) / 60:.2f} minutes')

spaceborne/ccl_interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ def compute_ng_cov_3x2pt(
766766
symmetrize_zpairs = (probe_a, probe_b) == (probe_c, probe_d)
767767

768768
tqdm.write(
769-
f'CCL {which_ng_cov}: computing probe combination {probe_ab, probe_cd}'
769+
f'CCL {which_ng_cov} cov: computing probe combination {probe_ab, probe_cd}'
770770
)
771771

772772
self.cov_dict[ng_term][probe_2tpl]['4d'] = self.compute_ng_cov_probe_block(
@@ -794,7 +794,7 @@ def compute_ng_cov_3x2pt(
794794
nbx=len(ells),
795795
zbins=None,
796796
ind_dict=ind_dict,
797-
msg=f'CCL {which_ng_cov}: ',
797+
msg=f'CCL {which_ng_cov} cov: ',
798798
)
799799

800800
def check_cov_blocks_symmetry(self):

spaceborne/config_checker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,9 @@ def check_types(self) -> None:
548548
assert isinstance(precision_cfg.get('cov_rs_int_method'), str), (
549549
'precision: cov_rs_int_method must be a string'
550550
)
551+
assert isinstance(precision_cfg.get('jax_enable_x64'), bool), (
552+
'precision: jax_enable_x64 must be a boolean'
553+
)
551554

552555
# misc
553556
assert isinstance(self.cfg.get('misc'), dict), (
@@ -578,9 +581,6 @@ def check_types(self) -> None:
578581
assert isinstance(misc_cfg.get('jax_platform'), str), (
579582
'misc: jax_platform must be a string'
580583
)
581-
assert isinstance(misc_cfg.get('jax_enable_x64'), bool), (
582-
'misc: jax_enable_x64 must be a boolean'
583-
)
584584
assert isinstance(misc_cfg.get('cl_triangle_plot'), bool), (
585585
'misc: cl_triangle_plot must be a boolean'
586586
)

spaceborne/cov_harmonic_space.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def consistency_checks(self):
119119

120120
def set_gauss_cov(self, ccl_obj: CCLInterface):
121121
start = time.perf_counter()
122-
122+
123123
print('\nComputing Gaussian harmonic-space covariance matrix...')
124124

125125
# signal
@@ -154,15 +154,15 @@ def set_gauss_cov(self, ccl_obj: CCLInterface):
154154
)
155155

156156
if self.cfg['precision']['cov_hs_g_ell_bin_average']:
157-
# unbinned cls and noise; need the edges to compute the number of modes
158-
# (after casting them to int. n_modes is equivalent to delta_ell modulo the
157+
# unbinned cls and noise; need the edges to compute the number of modes
158+
# (after casting them to int. n_modes is equivalent to delta_ell modulo the
159159
# fact that for delta_ell we consider non-integer ell values)
160160
_cl_5d = self.cl_3x2pt_unb_5d
161161
_noise_5d = noise_3x2pt_unb_5d
162162
_ell_values = self.ell_obj.ells_3x2pt_unb
163163
_ell_edges = self.ell_obj.ell_edges_3x2pt
164164
else:
165-
# evaluate the covariance at the center of the ell bin and normalise by
165+
# evaluate the covariance at the center of the ell bin and normalise by
166166
# delta_ell
167167
_cl_5d = cl_3x2pt_5d
168168
_noise_5d = noise_3x2pt_5d

spaceborne/cov_partial_sky.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def cl_22_list(zi, zj, spin0):
178178
probe_ab, probe_cd = sl.split_probe_name(probe_abcd, space='harmonic')
179179
probe_a, probe_b, probe_c, probe_d = list(probe_abcd)
180180

181-
tqdm.write(f'computing probe combination {probe_ab, probe_cd}')
181+
tqdm.write(f'NaMaster G cov: computing probe combination {probe_ab, probe_cd}')
182182

183183
s1 = spin_dict[probe_a]
184184
s2 = spin_dict[probe_b]
@@ -244,7 +244,7 @@ def cl_22_list(zi, zj, spin0):
244244
nbx=nbl,
245245
zbins=zbins,
246246
ind_dict=ind_dict,
247-
msg='',
247+
msg='NaMaster G cov: ',
248248
)
249249

250250
return cov_dict

spaceborne/cov_ssc.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import time
22

3-
import jax
43
import jax.numpy as jnp
54
import matplotlib.pyplot as plt
65
import numpy as np
@@ -167,10 +166,6 @@ def __init__(self, cfg, pvt_cfg, ccl_obj, z_grid):
167166
self.z_grid = z_grid
168167
self.ccl_obj = ccl_obj
169168

170-
# Enable 64-bit precision if required
171-
jax.config.update('jax_enable_x64', cfg['misc']['jax_enable_x64'])
172-
print('JAX devices:', jax.devices())
173-
174169
# set some useful attributes
175170
if self.use_ke_approx:
176171
self.ssc_func = ssc_integral_4D_simps_jax_ke_approx
@@ -339,7 +334,7 @@ def compute_ssc(
339334
nbx=nbl,
340335
zbins=None,
341336
ind_dict=self.ind_dict,
342-
msg='SSC: ',
337+
msg='SSC cov: ',
343338
)
344339

345340
print(f'...done in {(time.perf_counter() - start):.2f} s')

0 commit comments

Comments
 (0)