Skip to content

Commit 59c2b0d

Browse files
authored
gh-380: Inv mat ops (#379)
1 parent ffdd20f commit 59c2b0d

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

heracles/twopoint.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import logging
2424
import time
25+
from collections.abc import Mapping
2526
from datetime import timedelta
2627
from itertools import combinations_with_replacement, product
2728
from typing import TYPE_CHECKING, Any
@@ -39,7 +40,7 @@
3940
from dataclasses import replace
4041

4142
if TYPE_CHECKING:
42-
from collections.abc import Mapping, MutableMapping
43+
from collections.abc import MutableMapping
4344

4445
from numpy.typing import ArrayLike, NDArray
4546

@@ -404,15 +405,15 @@ def mixing_matrices(
404405

405406
def invert_mixing_matrix(
406407
M,
407-
rtol: float = 1e-5,
408+
rcond: float = 1e-5,
408409
progress: Progress | None = None,
409410
):
410411
"""
411412
Inversion model for the unmixing E/B modes.
412413
413414
Args:
414415
M: Mixing matrix (mapping of keys -> Result objects)
415-
rtol: relative tolerance for pseudo-inverse
416+
rcond: relative tolerance for pseudo-inverse
416417
progress: optional progress reporter
417418
418419
Returns:
@@ -432,14 +433,21 @@ def invert_mixing_matrix(
432433
s1, s2 = value.spin
433434
*_, _n, _m = _M.shape
434435

436+
if isinstance(rcond, Mapping):
437+
if key not in rcond:
438+
raise KeyError(f"Missing rcond value for wm key: {key}")
439+
_rcond = rcond[key]
440+
else:
441+
_rcond = rcond
442+
435443
with progress.task(f"invert {key}"):
436444
if (s1 != 0) and (s2 != 0):
437445
# Cl^EE+Cl^BB and Cl^EE-Cl^BB transformation
438446
# makes the mixing matrix block-diagonal
439447
M_p = _M[0] + _M[1]
440448
M_m = _M[0] - _M[1]
441-
inv_M_p = np.linalg.pinv(M_p, rcond=rtol)
442-
inv_M_m = np.linalg.pinv(M_m, rcond=rtol)
449+
inv_M_p = np.linalg.pinv(M_p, rcond=_rcond)
450+
inv_M_m = np.linalg.pinv(M_m, rcond=_rcond)
443451
_inv_m = np.vstack(
444452
(
445453
np.hstack(((inv_M_p + inv_M_m) / 2, (inv_M_p - inv_M_m) / 2)),
@@ -448,10 +456,10 @@ def invert_mixing_matrix(
448456
)
449457
_inv_M_EEEE = _inv_m[:_m, :_n]
450458
_inv_M_EEBB = _inv_m[_m:, :_n]
451-
_inv_M_EBEB = np.linalg.pinv(_M[2], rcond=rtol)
459+
_inv_M_EBEB = np.linalg.pinv(_M[2], rcond=_rcond)
452460
_inv_M = np.array([_inv_M_EEEE, _inv_M_EEBB, _inv_M_EBEB])
453461
else:
454-
_inv_M = np.linalg.pinv(_M, rcond=rtol)
462+
_inv_M = np.linalg.pinv(_M, rcond=_rcond)
455463

456464
inv_M[key] = replace(M[key], array=_inv_M)
457465
return inv_M

tests/test_twopoint.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def test_inverting_mixing_matrices():
368368
("WHT", "WHT", 0, 0): Result(cl, spin=(0, 0), axis=(0,)),
369369
}
370370
mms = mixing_matrices(fields, cls2, l1max=10, l2max=20)
371-
inv_mms = invert_mixing_matrix(mms)
371+
inv_mms = invert_mixing_matrix(mms, rcond=1e-2)
372372

373373
# test for correct shape
374374
for key in mms.keys():
@@ -387,7 +387,14 @@ def test_inverting_mixing_matrices():
387387
_m = np.ones_like(mms[key].array)
388388
mms[key] = Result(_m, spin=mms[key].spin, axis=mms[key].axis, ell=mms[key].ell)
389389

390-
inv_mms = invert_mixing_matrix(mms)
390+
inv_mms = invert_mixing_matrix(
391+
mms,
392+
rcond={
393+
("POS", "POS", 0, 0): 1e-2,
394+
("POS", "SHE", 0, 0): 1e-3,
395+
("SHE", "SHE", 0, 0): 1e-4,
396+
},
397+
)
391398
assert inv_mms.keys() == mms.keys()
392399
for key in mms:
393400
inv_mm = inv_mms[key].array

0 commit comments

Comments
 (0)