Skip to content

Commit 74476e3

Browse files
committed
overwrite covs at leaf levels in the OC case, sorted -> set for lists comparison
1 parent af5cd59 commit 74476e3

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

spaceborne/cov_harmonic_space.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -394,29 +394,35 @@ def combine_and_reshape_covs(
394394
assert _term in cov_oc_obj.cov_dict, '_term not in cov_oc_obj.cov_dict'
395395

396396
# check probes
397-
probe_list_sb = sorted(self.cov_dict[_term].keys())
398-
probe_list_oc = sorted(cov_oc_obj.cov_dict[_term].keys())
397+
probe_list_sb = set(self.cov_dict[_term].keys())
398+
probe_list_oc = set(cov_oc_obj.cov_dict[_term].keys())
399399
assert probe_list_sb == probe_list_oc, (
400400
f'probe_list_sb: {probe_list_sb}, probe_list_oc: {probe_list_oc}'
401401
)
402402

403403
# check dims
404404
for _probe_2tpl in probe_list_sb:
405-
dim_list_sb = sorted(self.cov_dict[_term][_probe_2tpl].keys())
406-
dim_list_oc = sorted(cov_oc_obj.cov_dict[_term][_probe_2tpl].keys())
405+
dim_list_sb = set(self.cov_dict[_term][_probe_2tpl].keys())
406+
dim_list_oc = set(cov_oc_obj.cov_dict[_term][_probe_2tpl].keys())
407407
assert dim_list_sb == dim_list_oc, (
408408
f'dim_list_sb: {dim_list_sb}, dim_list_oc: {dim_list_oc}'
409409
)
410-
assert dim_list_sb == ['6d'], (
411-
'the dict should only contain 6d arrays for the moment'
412-
)
413410

411+
# TODO delete this
414412
# having checked the covs, overwrite the relevand dict items
415-
self.cov_dict['g'] = deepcopy(cov_oc_obj.cov_dict['g'])
416-
if split_gaussian_cov:
417-
self.cov_dict['sva'] = deepcopy(cov_oc_obj.cov_dict['sva'])
418-
self.cov_dict['sn'] = deepcopy(cov_oc_obj.cov_dict['sn'])
419-
self.cov_dict['mix'] = deepcopy(cov_oc_obj.cov_dict['mix'])
413+
# self.cov_dict['g'] = deepcopy(cov_oc_obj.cov_dict['g'])
414+
# if split_gaussian_cov:
415+
# self.cov_dict['sva'] = deepcopy(cov_oc_obj.cov_dict['sva'])
416+
# self.cov_dict['sn'] = deepcopy(cov_oc_obj.cov_dict['sn'])
417+
# self.cov_dict['mix'] = deepcopy(cov_oc_obj.cov_dict['mix'])
418+
419+
for term in self.cov_dict:
420+
for probe_2tpl in self.cov_dict[term]:
421+
for dim in self.cov_dict[term][probe_2tpl]:
422+
if self.cov_dict[term][probe_2tpl][dim] is not None:
423+
self.cov_dict[term][probe_2tpl][dim] = deepcopy(
424+
cov_oc_obj.cov_dict[term][probe_2tpl][dim]
425+
)
420426

421427
# ! reshape and set SSC and cNG - the "if include SSC/cNG"
422428
# ! are inside the function

0 commit comments

Comments
 (0)