Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions examples/jackknife-covariance.ipynb

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions heracles/dices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
"jackknife_fsky",
"jackknife_bias",
"correct_bias",
"jackknife_maps",
"get_mask_correlation_ratio",
"correct_footprint_reduction",
"jackknife_covariance",
Expand All @@ -44,7 +43,6 @@

from .jackknife import (
jackknife_cls,
jackknife_maps,
jackknife_fsky,
jackknife_bias,
correct_bias,
Expand Down
87 changes: 42 additions & 45 deletions heracles/dices/jackknife.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def jackknife_cls(
inputs:
data_maps (dict): Dictionary of data maps
vis_maps (dict): Dictionary of visibility maps
jkmaps (dict): Dictionary of mask maps
jk_maps (dict): Dictionary of mask maps
fields (dict): Dictionary of fields
mask_correction (str): Type of mask correction to apply ("Fast" or "Full")
nd (int): Number of Jackknife regions
Expand All @@ -54,16 +54,28 @@ def jackknife_cls(
if nd < 0 or nd > 2:
raise ValueError("number of deletions must be 0, 1, or 2")
cls = {}
mls0 = get_cls(vis_maps, jk_maps, fields)
jkmap = jk_maps[list(jk_maps.keys())[0]]
njk = len(np.unique(jkmap)[np.unique(jkmap) != 0])

data_alms_regions = {}
vis_alms_regions = {}
for k in range(1, njk + 1):
print(f" - Computing ALMs for region {k}", end="\r", flush=True)
data_alms_regions[k] = transform(
fields, _get_region_maps(data_maps, jk_maps, k)
)
vis_alms_regions[k] = transform(fields, _get_region_maps(vis_maps, jk_maps, k))

mls0 = angular_power_spectra(_sum_alms_except(vis_alms_regions, ()))

for regions in combinations(range(1, njk + 1), nd):
_cls = get_cls(data_maps, jk_maps, fields, *regions)
_cls_mm = get_cls(vis_maps, jk_maps, fields, *regions)
# Bias correction
print(f" - Computing Cls for regions {regions}", end="\r", flush=True)
alms_jk = _sum_alms_except(data_alms_regions, regions)
_cls = angular_power_spectra(alms_jk)
_cls = correct_bias(_cls, jk_maps, fields, *regions)
# Mask correction
if mask_correction == "Full":
vis_alms_jk = _sum_alms_except(vis_alms_regions, regions)
_cls_mm = angular_power_spectra(vis_alms_jk)
alphas = get_mask_correlation_ratio(_cls_mm, mls0, unmixed=unmixed)
_cls = _naturalspice(_cls, alphas, fields)
elif mask_correction == "Fast":
Expand All @@ -76,57 +88,42 @@ def jackknife_cls(
return cls


def get_cls(maps, jkmaps, fields, jk=0, jk2=0):
def _get_region_maps(maps, jkmaps, jk):
"""
Internal method to compute the Cls of removing 2 Jackknife.
inputs:
maps (dict): Dictionary of data maps
jkmaps (dict): Dictionary of mask maps
fields (dict): Dictionary of fields
jk (int): Jackknife region to remove
jk2 (int): Jackknife region to remove
returns:
cls (dict): Dictionary of data Cls
"""
print(f" - Computing Cls for regions ({jk},{jk2})", end="\r", flush=True)
# remove the region from the maps
_maps = jackknife_maps(maps, jkmaps, jk=jk, jk2=jk2)
# compute alms
alms = transform(fields, _maps)
# compute cls
cls = angular_power_spectra(alms)
return cls


def jackknife_maps(maps, jkmaps, jk=0, jk2=0):
"""
Internal method to remove a region from the maps.
inputs:
maps (dict): Dictionary of data maps
jkmaps (dict): Dictionary of mask maps
jk (int): Jackknife region to remove
jk2 (int): Jackknife region to remove
returns:
maps (dict): Dictionary of data maps
Returns maps with only the pixels belonging to jackknife region *jk* active.
All other pixels are set to zero.
"""
_maps = deepcopy(maps)
for key_data, key_mask in zip(maps.keys(), jkmaps.keys()):
_map = _maps[key_data]
_jkmap = jkmaps[key_mask]

if _jkmap is None:
continue

_mask = np.copy(_jkmap)
_mask = (_mask > 0).astype(int)
# Remove jk 2 regions
cond = np.where((_jkmap == float(jk)) | (_jkmap == float(jk2)))[0]
_mask[cond] = 0.0
# Apply mask
_mask = (_jkmap == float(jk)).astype(int)
_map *= _mask
return _maps


def _sum_alms_except(alms_regions, exclude=()):
"""
Returns the sum of all region alms except those whose key is in *exclude*.

Metadata (including bias) is taken from the first included region's alms,
consistent with the mapper copying the full-footprint bias to every region.
Passing an empty *exclude* gives the full-sky alms; passing the deleted
region keys gives the delete-k or delete-k1k2 alms directly.
"""
included = [alms for k, alms in alms_regions.items() if k not in exclude]
first = included[0]
result = {}
for key in first:
arr = first[key].copy()
for alms_k in included[1:]:
arr += alms_k[key]
result[key] = arr
return result


def bias(cls):
"""
Internal method to compute the bias.
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,10 @@ def cls0(fields, data_maps):


@pytest.fixture(scope="session")
def mls0(fields, vis_maps, jk_maps):
from heracles.dices.jackknife import get_cls
def mls0(fields, vis_maps):
from heracles import transform, angular_power_spectra

return get_cls(vis_maps, jk_maps, fields)
return angular_power_spectra(transform(fields, vis_maps))


@pytest.fixture(scope="session")
Expand Down
79 changes: 42 additions & 37 deletions tests/test_dices.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,48 @@ def test_jkmap(jk_maps, njk):
assert np.all(np.unique(jk_maps[key]) == np.arange(1, njk + 1))


def test_jackknife_maps(data_maps, jk_maps, njk):
# multiply maps by jk footprint
vmap = np.copy(jk_maps[("VIS", 1)])
vmap[vmap > 0] = vmap[vmap > 0] / vmap[vmap > 0]
for key in list(data_maps.keys()):
data_maps[key] *= vmap
# test null case
_data_maps = dices.jackknife.jackknife_maps(data_maps, jk_maps)
for key in list(_data_maps.keys()):
np.testing.assert_allclose(_data_maps[key], data_maps[key])
# test delete1 case
__data_maps = np.array(
[
dices.jackknife.jackknife_maps(data_maps, jk_maps, jk=i, jk2=i)[("POS", 1)]
for i in range(1, njk + 1)
]
)
__data_map = np.sum(__data_maps, axis=0) / (njk - 1)
np.testing.assert_allclose(__data_map, data_maps[("POS", 1)])
___data_map = np.prod(__data_maps, axis=0)
np.testing.assert_allclose(___data_map, np.zeros_like(data_maps[("POS", 1)]))

# Copy data map and add systematic map which should not be jackknifed
data_maps_nojk = data_maps.copy()
data_maps_nojk[("SYS", 1)] = np.arange(1, 11, dtype=float)

# Copy Jackknife maps and add None map, output jackknifed maps
jk_maps_nojk = jk_maps.copy()
jk_maps_nojk[("SYS", 1)] = None
out_maps = dices.jackknife.jackknife_maps(data_maps_nojk, jk_maps_nojk, jk=1)

# Assert that the SYS map is unchanged
np.testing.assert_allclose(out_maps[("SYS", 1)], data_maps_nojk[("SYS", 1)])

# Check that a sample key WAS jackknifed
sample_key = ("POS", 1)
assert not np.allclose(out_maps[sample_key], data_maps_nojk[sample_key])
def _remove_regions(maps, jk_maps, regions):
"""Reference: explicitly zero out the given regions in each map."""
from copy import deepcopy

_maps = deepcopy(maps)
for key_data, key_mask in zip(maps.keys(), jk_maps.keys()):
_jkmap = jk_maps[key_mask]
if _jkmap is None:
continue
mask = (_jkmap > 0).astype(int)
for r in regions:
mask[_jkmap == float(r)] = 0
_maps[key_data] *= mask
return _maps


def test_region_alm_cls(fields, data_maps, jk_maps, njk):
"""ALM-subtraction and map-masking must give identical Cls."""
from itertools import combinations

from heracles import angular_power_spectra, transform
from heracles.dices.jackknife import _get_region_maps, _sum_alms_except

alms_regions = {
k: transform(fields, _get_region_maps(data_maps, jk_maps, k))
for k in range(1, njk + 1)
}

for nd in (0, 1, 2):
for regions in combinations(range(1, njk + 1), nd):
cls_new = angular_power_spectra(_sum_alms_except(alms_regions, regions))
cls_ref = angular_power_spectra(
transform(fields, _remove_regions(data_maps, jk_maps, regions))
)
for key in cls_ref:
np.testing.assert_allclose(
cls_new[key].array,
cls_ref[key].array,
rtol=1e-7,
atol=1e-10,
err_msg=f"nd={nd}, regions={regions}, key={key}",
)


def test_cls(nside, cls0, fields, data_maps, vis_maps, jk_maps):
Expand Down
Loading