Skip to content

Commit d952f19

Browse files
committed
introduce cov_terms_and_codes
1 parent c7f3543 commit d952f19

File tree

1 file changed

+71
-25
lines changed

1 file changed

+71
-25
lines changed

main.py

Lines changed: 71 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,12 @@ def plot_cls():
481481
if cfg['covariance']['cNG'] and cfg['covariance']['cNG_code'] == 'PyCCL':
482482
compute_ccl_cng = True
483483

484+
cov_terms_and_codes = {
485+
'G': cfg['covariance']['G_code'] if cfg['covariance']['G'] else False,
486+
'SSC': cfg['covariance']['SSC_code'] if cfg['covariance']['SSC'] else False,
487+
'cNG': cfg['covariance']['cNG_code'] if cfg['covariance']['cNG'] else False,
488+
}
489+
484490
_condition = 'GLGL' in req_probe_combs_hs_2d or 'gtgt' in req_probe_combs_rs_2d
485491
if compute_ccl_cng and _condition:
486492
warnings.warn(
@@ -1057,7 +1063,7 @@ def plot_cls():
10571063
cl_ll_3d = bnt.cl_bnt_transform(ccl_obj.cl_ll_3d, bnt_matrix, 'L', 'L')
10581064
cl_3x2pt_5d = bnt.cl_bnt_transform_3x2pt(ccl_obj.cl_3x2pt_5d, bnt_matrix)
10591065
warnings.warn('you should probably BNT-transform the responses too!', stacklevel=2)
1060-
if compute_oc_g or compute_oc_ssc or compute_oc_cng:
1066+
if 'OneCovariance' in cov_terms_and_codes.values():
10611067
raise NotImplementedError('You should cut also the OC Cls')
10621068

10631069

@@ -1118,7 +1124,7 @@ def plot_cls():
11181124
# cl_3x2pt_5d = cl_utils.cl_ell_cut_3x2pt(
11191125
# cl_3x2pt_5d, ell_cuts_dict, ell_dict['ell_3x2pt']
11201126
# )
1121-
# if compute_oc_g or compute_oc_ssc or compute_oc_cng:
1127+
# if 'OneCovariance' in cov_terms_and_codes.values():
11221128
# raise NotImplementedError('You should cut also the OC Cls')
11231129

11241130
# re-set cls in the ccl_obj after BNT transform and/or ell cuts
@@ -1248,7 +1254,7 @@ def plot_cls():
12481254
# ! =================================== OneCovariance ================================
12491255
# initialize object
12501256
cov_oc_obj = None
1251-
if compute_oc_g or compute_oc_ssc or compute_oc_cng:
1257+
if 'OneCovariance' in cov_terms_and_codes.values():
12521258
if cfg['ell_cuts']['cl_ell_cuts']:
12531259
raise NotImplementedError(
12541260
'TODO double check inputs in this case. This case is untested'
@@ -1493,7 +1499,8 @@ def plot_cls():
14931499
print(f'Time taken to compute OC: {(time.perf_counter() - start_time) / 60:.2f} m')
14941500

14951501
cov_ssc_obj = None
1496-
if compute_sb_ssc:
1502+
if cov_terms_and_codes['SSC'] == 'Spaceborne':
1503+
# TODO most of this should go in the cov_ssc class
14971504
# ! ================================= Probe responses ==============================
14981505
resp_obj = responses.SpaceborneResponses(
14991506
cfg=cfg, k_grid=k_grid, z_grid=z_grid_trisp, ccl_obj=ccl_obj
@@ -1741,7 +1748,7 @@ def plot_cls():
17411748
)
17421749

17431750

1744-
if obs_space == 'real':
1751+
if obs_space == 'real' and 'Spaceborne' in cov_terms_and_codes.values():
17451752
print('\nComputing real-space covariance...')
17461753
start_rs = time.perf_counter()
17471754

@@ -1792,7 +1799,7 @@ def plot_cls():
17921799

17931800
# TODO this code block is almost identical to the real-space one above, probably
17941801
# TODO can be unified
1795-
if obs_space == 'cosebis':
1802+
if obs_space == 'cosebis' and 'Spaceborne' in cov_terms_and_codes.values():
17961803
print('\nComputing COSEBIs covariance...')
17971804
start_rs = time.perf_counter()
17981805

@@ -1843,21 +1850,60 @@ def plot_cls():
18431850

18441851
print(f'...done in {time.perf_counter() - start_rs:.2f} s')
18451852

1853+
# def copy_cov_dict_leaf_level(original, new):
1854+
# for term in original:
1855+
# for probe_2tpl in original[term]:
1856+
# for dim
1857+
# new[term][probe_2tpl] = deepcopy(original[term][probe_2tpl])
1858+
18461859

18471860
if obs_space == 'harmonic':
1848-
_cov_obj = cov_hs_obj
1861+
_cov_dict = cov_hs_obj.cov_dict
18491862
_probes = unique_probe_combs_hs
18501863
elif obs_space == 'real':
1851-
_cov_obj = cov_rs_obj
1864+
_cov_dict = cov_rs_obj.cov_dict
18521865
_probes = unique_probe_combs_rs
18531866
elif obs_space == 'cosebis':
1854-
_cov_obj = cov_cs_obj
1867+
_cov_dict = cov_cs_obj.cov_dict
18551868
_probes = unique_probe_combs_cs
18561869
else:
18571870
raise ValueError(
18581871
f'Unknown cfg["probe_selection"]["space"]: {cfg["probe_selection"]["space"]}'
18591872
)
18601873

1874+
# in the harmonic case, this is handled by the cov_harmonic_space class
1875+
if obs_space != 'harmonic':
1876+
if cfg['covariance']['G_code'] == 'OneCovariance':
1877+
# Copy arrays at leaf level for 'g' term
1878+
for probe_pair in _cov_dict['g']:
1879+
for dim in _cov_dict['g'][probe_pair]:
1880+
_cov_dict['g'][probe_pair][dim] = cov_oc_obj.cov_dict['g'][probe_pair][
1881+
dim
1882+
]
1883+
1884+
# Copy split Gaussian terms if they exist
1885+
if cfg['covariance']['split_gaussian_cov']:
1886+
for term in ['sva', 'sn', 'mix']:
1887+
for probe_pair in _cov_dict[term]:
1888+
for dim in _cov_dict[term][probe_pair]:
1889+
_cov_dict[term][probe_pair][dim] = cov_oc_obj.cov_dict[term][
1890+
probe_pair
1891+
][dim]
1892+
1893+
if cfg['covariance']['SSC_code'] == 'OneCovariance':
1894+
for probe_pair in _cov_dict['ssc']:
1895+
for dim in _cov_dict['ssc'][probe_pair]:
1896+
_cov_dict['ssc'][probe_pair][dim] = cov_oc_obj.cov_dict['ssc'][
1897+
probe_pair
1898+
][dim]
1899+
1900+
if cfg['covariance']['cNG_code'] == 'OneCovariance':
1901+
for probe_pair in _cov_dict['cng']:
1902+
for dim in _cov_dict['cng'][probe_pair]:
1903+
_cov_dict['cng'][probe_pair][dim] = cov_oc_obj.cov_dict['cng'][
1904+
probe_pair
1905+
][dim]
1906+
18611907

18621908
# # ! important note: for OC RS, list fmt seems to be missing some blocks (problem common to HS, solve it)
18631909
# # ! moreover, some of the sub-blocks are transposed.
@@ -1904,19 +1950,19 @@ def plot_cls():
19041950
# ! save 2D covs (for each term) in npz archive
19051951
cov_dict_tosave_2d = {}
19061952
if cfg['covariance']['G']:
1907-
cov_dict_tosave_2d['Gauss'] = _cov_obj.cov_dict['g']['3x2pt']['2d']
1953+
cov_dict_tosave_2d['Gauss'] = _cov_dict['g']['3x2pt']['2d']
19081954
if cfg['covariance']['SSC']:
1909-
cov_dict_tosave_2d['SSC'] = _cov_obj.cov_dict['ssc']['3x2pt']['2d']
1955+
cov_dict_tosave_2d['SSC'] = _cov_dict['ssc']['3x2pt']['2d']
19101956
if cfg['covariance']['cNG']:
1911-
cov_dict_tosave_2d['cNG'] = _cov_obj.cov_dict['cng']['3x2pt']['2d']
1957+
cov_dict_tosave_2d['cNG'] = _cov_dict['cng']['3x2pt']['2d']
19121958
if cfg['covariance']['split_gaussian_cov']:
1913-
cov_dict_tosave_2d['SVA'] = _cov_obj.cov_dict['sva']['3x2pt']['2d']
1914-
cov_dict_tosave_2d['SN'] = _cov_obj.cov_dict['sn']['3x2pt']['2d']
1915-
cov_dict_tosave_2d['MIX'] = _cov_obj.cov_dict['mix']['3x2pt']['2d']
1959+
cov_dict_tosave_2d['SVA'] = _cov_dict['sva']['3x2pt']['2d']
1960+
cov_dict_tosave_2d['SN'] = _cov_dict['sn']['3x2pt']['2d']
1961+
cov_dict_tosave_2d['MIX'] = _cov_dict['mix']['3x2pt']['2d']
19161962
# the total covariance is equal to the Gaussian one if neither SSC nor cNG are computed,
19171963
# so only save it if at least one of the two is computed
19181964
if cfg['covariance']['cNG'] or cfg['covariance']['SSC']:
1919-
cov_dict_tosave_2d['TOT'] = _cov_obj.cov_dict['tot']['3x2pt']['2d']
1965+
cov_dict_tosave_2d['TOT'] = _cov_dict['tot']['3x2pt']['2d']
19201966

19211967
cov_filename = cfg['covariance']['cov_filename']
19221968
np.savez_compressed(f'{output_path}/{cov_filename}_2D.npz', **cov_dict_tosave_2d)
@@ -1926,7 +1972,7 @@ def plot_cls():
19261972
# ! i.e. there is no cov_3x2pt_{term}_6d
19271973
if cfg['covariance']['save_full_cov']:
19281974
cov_dict_tosave_6d = {}
1929-
_cd = _cov_obj.cov_dict # just to make the code more readable
1975+
_cd = _cov_dict # just to make the code more readable
19301976
for _probe in _probes:
19311977
probe_ab, probe_cd = sl.split_probe_name(_probe, obs_space)
19321978
probe_2tpl = (probe_ab, probe_cd) # just to make the code more readable
@@ -1946,7 +1992,7 @@ def plot_cls():
19461992
np.savez_compressed(f'{output_path}/{cov_filename}_6D.npz', **cov_dict_tosave_6d)
19471993

19481994
if cfg['covariance']['save_cov_fits'] and obs_space == 'harmonic':
1949-
io_obj.save_cov_euclidlib(cov_hs_obj=_cov_obj)
1995+
io_obj.save_cov_euclidlib(cov_hs_obj=_cov_dict)
19501996
if cfg['covariance']['save_cov_fits'] and obs_space != 'harmonic':
19511997
raise ValueError(
19521998
'Official Euclid .fits format is only supported for harmonic space '
@@ -2224,13 +2270,13 @@ def plot_cls():
22242270

22252271
# ! old
22262272
covs_arrays_dict = {}
2227-
for term in _cov_obj.cov_dict:
2228-
for _probe_abcd in _cov_obj.cov_dict[term]:
2229-
for dim in _cov_obj.cov_dict[term][_probe_abcd]:
2230-
if _cov_obj.cov_dict[term][_probe_abcd][dim] is None:
2273+
for term in _cov_dict:
2274+
for _probe_abcd in _cov_dict[term]:
2275+
for dim in _cov_dict[term][_probe_abcd]:
2276+
if _cov_dict[term][_probe_abcd][dim] is None:
22312277
value = 0
22322278
else:
2233-
value = _cov_obj.cov_dict[term][_probe_abcd][dim]
2279+
value = _cov_dict[term][_probe_abcd][dim]
22342280

22352281
if _probe_abcd == ('LL', 'LL'):
22362282
probe_abcd = 'WL'
@@ -2248,12 +2294,12 @@ def plot_cls():
22482294
covs_arrays_dict[name] = value
22492295

22502296
# TODO this should be removed after merge with develop
2251-
if 'ssc' not in _cov_obj.cov_dict:
2297+
if 'ssc' not in _cov_dict:
22522298
for probe in ['WL', 'GC', '3x2pt']:
22532299
covs_arrays_dict[f'cov_{probe}_ssc_6d'] = 0
22542300
covs_arrays_dict[f'cov_{probe}_ssc_4d'] = 0
22552301
covs_arrays_dict[f'cov_{probe}_ssc_2d'] = 0
2256-
if 'cng' not in _cov_obj.cov_dict:
2302+
if 'cng' not in _cov_dict:
22572303
for probe in ['WL', 'GC', 'XC', '3x2pt']:
22582304
covs_arrays_dict[f'cov_{probe}_cng_6d'] = 0
22592305
covs_arrays_dict[f'cov_{probe}_cng_4d'] = 0

0 commit comments

Comments
 (0)