Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
bfbf404
fields for unmixing
JaimeRZP Nov 27, 2025
2006f71
first commit
JaimeRZP Nov 28, 2025
d49bdf8
working
JaimeRZP Nov 28, 2025
8911f6a
ruff
JaimeRZP Nov 28, 2025
8b97633
more tests
JaimeRZP Nov 28, 2025
fcaf5a1
tuning
JaimeRZP Nov 28, 2025
77f2ee3
tuning
JaimeRZP Nov 28, 2025
505efc0
tuning
JaimeRZP Nov 28, 2025
2b1cced
print statemnt
JaimeRZP Nov 28, 2025
75f0a5e
upd before tackling mms
JaimeRZP Nov 28, 2025
55c8325
tuning for mms workig
JaimeRZP Nov 28, 2025
c4d9276
riff
JaimeRZP Nov 28, 2025
ef3f475
ready to integrate options into mms
JaimeRZP Nov 28, 2025
12b315e
ruff
JaimeRZP Dec 1, 2025
f868c3c
Merge remote-tracking branch 'origin' into unmixing_refactor
JaimeRZP Dec 1, 2025
d501dea
bug
JaimeRZP Dec 1, 2025
1b1a54e
bug
JaimeRZP Dec 1, 2025
9cff968
bug
JaimeRZP Dec 1, 2025
c13b4e7
bug
JaimeRZP Dec 1, 2025
f800553
bug
JaimeRZP Dec 1, 2025
73934f0
bug
JaimeRZP Dec 1, 2025
dbe9915
bug
JaimeRZP Dec 1, 2025
44402e2
bug
JaimeRZP Dec 1, 2025
f8d950b
bug
JaimeRZP Dec 1, 2025
0cf5cb5
bug
JaimeRZP Dec 1, 2025
df54ff5
bug
JaimeRZP Dec 1, 2025
5131d92
bug
JaimeRZP Dec 1, 2025
b1ff76d
bug
JaimeRZP Dec 1, 2025
ac60057
bug
JaimeRZP Dec 1, 2025
808324e
try bug solution
JaimeRZP Dec 1, 2025
3b3bc10
bug
JaimeRZP Dec 1, 2025
1fc4f2b
MEMORY BUG
JaimeRZP Dec 1, 2025
f8ac4d9
bug
JaimeRZP Dec 1, 2025
2d188cb
bug
JaimeRZP Dec 1, 2025
661f28b
bug
JaimeRZP Dec 1, 2025
256f345
bug
JaimeRZP Dec 3, 2025
b0590a3
no tuning for now
JaimeRZP Dec 17, 2025
c7d7558
Merge branch 'main' into unmixing_refactor
JaimeRZP Dec 17, 2025
6515691
no tests for tuning
JaimeRZP Dec 18, 2025
d24631a
ruff
JaimeRZP Dec 18, 2025
5ba0ad0
small fix
JaimeRZP Dec 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions examples/unmixing.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions heracles/dices/jackknife.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..mapping import transform
from ..twopoint import angular_power_spectra
from ..unmixing import _natural_unmixing, logistic
from ..transforms import cl2corr
from ..transforms import cl2corr, transform_cls

try:
from copy import replace
Expand Down Expand Up @@ -58,7 +58,8 @@ def jackknife_cls(data_maps, vis_maps, jk_maps, fields, nd=1):
_cls_mm = get_cls(vis_maps, jk_maps, fields, *regions)
# Mask correction
alphas = mask_correction(_cls_mm, mls0)
_cls = _natural_unmixing(_cls, alphas, fields)
_wcls = transform_cls(_cls)
_cls = _natural_unmixing(_wcls, alphas, fields)
# Bias correction
_cls = correct_bias(_cls, jk_maps, fields, *regions)
cls[regions] = _cls
Expand Down
160 changes: 160 additions & 0 deletions heracles/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from scipy.special import lpn as legendrep
from heracles.twopoint import truncated
import numpy as np


try:
from copy import replace
except ImportError:
# Python < 3.13
from dataclasses import replace


gauss_legendre = None
_gauss_legendre_cache = {}

Expand Down Expand Up @@ -154,6 +163,10 @@ def corr2cl(corrs, lmax=None, sampling_factor=1):
:return: array of power spectra, cl[L, ix], where L starts at zero and ix=0,1,2,3 in order TT, EE, BB, TE.
They include :math:`\ell(\ell+1)/2\pi` factors.
"""
if corrs.ndim == 1:
corrs = np.array(
[corrs, np.zeros_like(corrs), np.zeros_like(corrs), np.zeros_like(corrs)]
).T

if lmax is None:
lmax = corrs.shape[0] - 1
Expand All @@ -180,3 +193,150 @@ def corr2cl(corrs, lmax=None, sampling_factor=1):
cls[1, :] *= 2
cls[2:, :] = cls[2:, :]
return 2 * np.pi * cls


def transform_cls(cls, lmax_out=None):
"""
Natural unmixing of the data Cl.
Args:
cls: Data Cl
Returns:
corr: correlation function
"""
wds = {}
for key in cls.keys():
cl = cls[key]
s1, s2 = cl.spin
lmax = cl.shape[-1]
if lmax_out is None:
lmax_out = lmax
# Grab metadata
dtype = cl.array.dtype
# pad cls
cl = np.atleast_2d(cl.array)
if lmax_out > lmax:
pad_width = [(0, 0)] * cl.ndim # no padding for other dims
pad_width[-1] = (0, lmax_out - lmax) # pad only last dim
cl = np.pad(cl, pad_width, mode="constant", constant_values=0)
lmax = lmax_out
if (s1 != 0) and (s2 != 0):
_cl = np.array(
[
np.zeros_like(cl[0, 0]),
cl[0, 0], # EE like spin-2
cl[1, 1], # BB like spin-2
np.zeros_like(cl[0, 0]),
]
)
_icl = np.array(
[
np.zeros_like(cl[0, 0]),
-cl[0, 1], # EB like spin-0
cl[1, 0], # EB like spin-0
np.zeros_like(cl[0, 0]),
]
)
# transform to corr
_wd = cl2corr(_cl.T).T + 1j * cl2corr(_icl.T).T
_iwd = _wd.imag
_wd = _wd.real
# reorganize
wd = np.zeros_like(cl)
wd[0, 0] = _wd[1] # EE like spin-2
wd[1, 1] = _wd[2] # BB like spin-2
wd[0, 1] = _iwd[1] # EB like spin-0
wd[1, 0] = _iwd[2] # EB like spin-0
else:
# Treat everything as spin-0
wd = []
for _cl in cl:
_wd = cl2corr(_cl).T
wd.append(_wd[0])
# remove extra axis
wd = np.squeeze(wd)
# Add metadata back
wd = np.array(list(wd), dtype=dtype)
wds[key] = replace(
cls[key],
ell=np.arange(lmax),
lower=np.arange(lmax)[:-1],
upper=np.arange(lmax)[1:],
weight=np.ones(lmax),
array=wd,
)
# truncate to lmax
wds = truncated(wds, lmax_out)
return wds


def transform_corrs(wds, lmax_out=None):
"""
Natural unmixing of the data Cl.
Args:
corrs: data corrs
Returns:
corr: correlation function
"""
cls = {}
for key in wds.keys():
wd = wds[key]
s1, s2 = wd.spin
lmax = wd.shape[-1]
if lmax_out is None:
lmax_out = lmax
# Grab metadata
dtype = wd.array.dtype
# pad cls
wd = np.atleast_2d(wd.array)
if lmax_out > lmax:
pad_width = [(0, 0)] * wd.ndim # no padding for other dims
pad_width[-1] = (0, lmax_out - lmax) # pad only last dim
wd = np.pad(wd, pad_width, mode="constant", constant_values=0)
lmax = lmax_out
if (s1 != 0) and (s2 != 0):
_wd = np.array(
[
np.zeros_like(wd[0, 0]),
wd[0, 0], # EE like spin-2
wd[1, 1], # BB like spin-2
np.zeros_like(wd[0, 0]),
]
)
_iwd = np.array(
[
np.zeros_like(wd[0, 0]),
wd[0, 1], # EB like spin-0
wd[1, 0], # EB like spin-0
np.zeros_like(wd[0, 0]),
]
)
# transform to cls
_wd = corr2cl(_wd.T).T
_iwd = corr2cl(_iwd.T).T
# reorganize
cl = np.zeros_like(wd)
cl[0, 0] = _wd[1] # EE like spin-2
cl[1, 1] = _wd[2] # BB like spin-2
cl[0, 1] = _iwd[1] # EB like spin-0
cl[1, 0] = _iwd[2] # EB like spin-0
else:
# Treat everything as spin-0
cl = []
for _wd in wd:
_wd = corr2cl(_wd).T
cl.append(_wd[0])
# remove extra axis
cl = np.squeeze(cl)
# Add metadata back
cl = np.array(list(cl), dtype=dtype)
cls[key] = replace(
wds[key],
ell=np.arange(lmax),
lower=np.arange(lmax)[:-1],
upper=np.arange(lmax)[1:],
weight=np.ones(lmax),
array=cl,
)
# truncate to lmax
cls = truncated(cls, lmax_out)
return cls
6 changes: 6 additions & 0 deletions heracles/twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def mixing_matrices(

def invert_mixing_matrix(
M,
options={},
rtol: float = 1e-5,
progress: Progress | None = None,
):
Expand Down Expand Up @@ -432,6 +433,11 @@ def invert_mixing_matrix(
s1, s2 = value.spin
*_, _n, _m = _M.shape

if key in options:
rtol = options[key]
else:
rtol = rtol

with progress.task(f"invert {key}"):
if (s1 != 0) and (s2 != 0):
_inv_m = np.linalg.pinv(
Expand Down
124 changes: 43 additions & 81 deletions heracles/unmixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
# You should have received a copy of the GNU Lesser General Public
# License along with Heracles. If not, see <https://www.gnu.org/licenses/>.
import numpy as np
from .result import truncated
from .transforms import cl2corr, corr2cl
from .transforms import transform_cls, transform_corrs
from .utils import get_cl

try:
Expand All @@ -28,108 +27,71 @@
from dataclasses import replace


def natural_unmixing(d, m, fields, x0=-2, k=50, patch_hole=True, lmax=None):
def natural_unmixing(cls, mls, fields, options={}, rtol=0.3):
"""
Natural unmixing of the data Cl.
Args:
d: Data Cl
m: mask Cl
cls: Data Cl
mls: mask Cl
fields: list of fields
patch_hole: If True, apply the patch hole correction
Returns:
corr_d: Corrected Cl
corr_cls: Corrected Cl
"""
wm = {}
m_keys = list(m.keys())
for m_key in m_keys:
_m = m[m_key].array
_wm = cl2corr(_m).T[0]
if patch_hole:
_wm *= logistic(np.log10(abs(_wm)), x0=x0, k=k)
wm[m_key] = replace(m[m_key], array=_wm)
return _natural_unmixing(d, wm, fields, lmax=lmax)
mask_lmax = mls[list(mls.keys())[0]].shape[-1]
lmax = cls[list(cls.keys())[0]].shape[-1] - 1
wmls = transform_cls(mls)
wmls = correct_correlation(wmls, options=options, rtol=rtol)
wcls = transform_cls(cls, lmax_out=mask_lmax)
return _natural_unmixing(wcls, wmls, fields, lmax=lmax)


def _natural_unmixing(d, wm, fields, lmax=None):
def _natural_unmixing(wcls, wmls, fields, lmax=None):
"""
Natural unmixing of the data Cl.
Args:
d: Data Cl
wm: mask correlation function
wcls: data correlation function
wmls: mask correlation function
fields: list of fields
patch_hole: If True, apply the patch hole correction
Returns:
corr_d: Corrected Cl
corr_cls: Corrected Cl
"""
corr_d = {}
corr_wcls = {}
masks = {}
for key, field in fields.items():
if field.mask is not None:
masks[key] = field.mask

for key in d.keys():
for key in wcls.keys():
a, b, i, j = key
m_key = (masks[a], masks[b], i, j)
_wm = get_cl(m_key, wm)
_d = d[key]
s1, s2 = _d.spin
if lmax is None:
*_, lmax = _d.shape
lmax_mask = len(_wm.array)
# Grab metadata
dtype = _d.array.dtype
# pad cls
_d = np.atleast_2d(_d.array)
pad_width = [(0, 0)] * _d.ndim # no padding for other dims
pad_width[-1] = (0, lmax_mask - lmax) # pad only last dim
_d = np.pad(_d, pad_width, mode="constant", constant_values=0)
if (s1 != 0) and (s2 != 0):
__d = np.array(
[
np.zeros_like(_d[0, 0]),
_d[0, 0], # EE like spin-2
_d[1, 1], # BB like spin-2
np.zeros_like(_d[0, 0]),
]
)
__id = np.array(
[
np.zeros_like(_d[0, 0]),
-_d[0, 1], # EB like spin-0
_d[1, 0], # EB like spin-0
np.zeros_like(_d[0, 0]),
]
)
# Correct by alpha
wd = cl2corr(__d.T).T + 1j * cl2corr(__id.T).T
corr_wd = (wd / _wm).real
icorr_wd = (wd / _wm).imag
# Transform back to Cl
__corr_d = corr2cl(corr_wd.T).T
__icorr_d = corr2cl(icorr_wd.T).T
# reorder
_corr_d = np.zeros_like(_d)
_corr_d[0, 0] = __corr_d[1] # EE like spin-2
_corr_d[1, 1] = __corr_d[2] # BB like spin-2
_corr_d[0, 1] = -__icorr_d[1] # EB like spin-0
_corr_d[1, 0] = __icorr_d[2] # EB like spin-0
wml = get_cl(m_key, wmls).array
corr_wcls[key] = replace(wcls[key], array=wcls[key].array / wml)

corr_cls = transform_corrs(corr_wcls, lmax_out=lmax)
return corr_cls


def correct_correlation(wms, options={}, rtol=0.3):
"""
Correct correlation functions using a logistic function.
Args:
wms: mask correlation functions
rtol: relative tolerance for the cutoff
Returns:
corrected_wms: corrected mask correlation functions
"""
corrected_wms = {}
for key, wm in wms.items():
if key in options:
rtol = options[key]
else:
# Treat everything as spin-0
_corr_d = []
for cl in _d:
wd = cl2corr(cl).T
corr_wd = wd / _wm
# Transform back to Cl
__corr_d = corr2cl(corr_wd.T).T
_corr_d.append(__corr_d[0])
# remove extra axis
_corr_d = np.squeeze(_corr_d)
# Add metadata back
_corr_d = np.array(list(_corr_d), dtype=dtype)
corr_d[key] = replace(d[key], array=_corr_d)
# truncate to lmax
corr_d = truncated(corr_d, lmax)
return corr_d
rtol = rtol
wm = wm.array
cutoff = rtol * np.max(np.abs(wm))
_wm = wm * logistic(np.log10(abs(wm)), x0=np.log10(cutoff))
corrected_wms[key] = replace(wms[key], array=_wm)
return corrected_wms


def logistic(x, x0=-5, k=50):
Expand Down
5 changes: 4 additions & 1 deletion heracles/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def get_cl(key, cls):
arr = cls[key_sym].array
s1, s2 = cls[key_sym].spin
if s1 != 0 and s2 != 0:
arr = np.transpose(arr, axes=(1, 0, 2))
# check this is not a mixing matrix
n, _, _ = arr.shape
if n != 3:
arr = np.transpose(arr, axes=(1, 0, 2))
# always transpose spins
s1, s2 = s2, s1
else:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_dices.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def test_get_delete2_fsky(jk_maps, njk):

def test_mask_correction(cls0, mls0, fields):
alphas = dices.mask_correction(mls0, mls0)
_cls = heracles.unmixing._natural_unmixing(cls0, alphas, fields)
wcls0 = heracles.transforms.transform_cls(cls0)
_cls = heracles.unmixing._natural_unmixing(wcls0, alphas, fields)
for key in list(cls0.keys()):
cl = cls0[key].array
_cl = _cls[key].array
Expand Down
Loading