Skip to content

Commit 98a2091

Browse files
alxndrkalininclaude
andcommitted
fix(metrics): address code review findings scored >50
- Replace xp/get_array_module pattern with to_same_device in all three spectral PCC functions (CLAUDE.md: prefer cubic.cuda abstractions) - Use cubic.scipy proxy for savgol_filter and median_filter instead of direct scipy imports (CLAUDE.md: device-agnostic wrappers) - Add xy_std/z_std keys to fsc_resolution mask backend return path to match the documented return contract - Warn when n_repeats>1 is silently ignored (checkerboard or two-image) - Clarify n_repeats docstring: requires single-image + binomial mode - Add tests for spectral_pcc(smooth=True) and nbins_low parameter Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a651ef1 commit 98a2091

File tree

3 files changed

+117
-33
lines changed

3 files changed

+117
-33
lines changed

cubic/metrics/bandlimited.py

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
import numpy as np
3232

33-
from cubic.cuda import asnumpy, to_same_device, get_array_module, check_same_device
33+
from cubic.cuda import asnumpy, to_same_device, check_same_device
3434
from cubic.image_utils import tukey_window, hamming_window
3535

3636
from .spectral.dcr import dcr_resolution
@@ -432,7 +432,7 @@ def smooth_spectral_weights(
432432
np.ndarray
433433
Weights in [0, 1].
434434
"""
435-
from scipy.signal import savgol_filter
435+
from cubic.scipy import signal as csignal
436436

437437
log_p = np.log(np.maximum(power, 1e-30))
438438
n = len(log_p)
@@ -446,7 +446,7 @@ def smooth_spectral_weights(
446446
if wlen % 2 == 0:
447447
wlen -= 1
448448
wlen = max(wlen, min_wlen)
449-
log_p_smooth = savgol_filter(log_p, wlen, sg_polyorder)
449+
log_p_smooth = csignal.savgol_filter(log_p, wlen, sg_polyorder)
450450
p_smooth = np.exp(log_p_smooth)
451451

452452
n2 = noise_floor**2
@@ -525,7 +525,7 @@ def _estimate_noise_baseline(
525525
noise_baseline : np.ndarray
526526
Frequency-dependent noise floor N(k), same length as *power*.
527527
"""
528-
from scipy.signal import savgol_filter
528+
from cubic.scipy import signal as csignal
529529

530530
n = len(power)
531531

@@ -541,7 +541,7 @@ def _estimate_noise_baseline(
541541
wlen = max(wlen, min_wlen)
542542

543543
log_p = np.log(np.maximum(power, 1e-30))
544-
log_p_smooth = savgol_filter(log_p, wlen, sg_polyorder)
544+
log_p_smooth = csignal.savgol_filter(log_p, wlen, sg_polyorder)
545545

546546
# Running low-quantile of smoothed log-power
547547
log_n = _running_quantile_1d(log_p_smooth, window=quantile_window, q=quantile)
@@ -713,14 +713,13 @@ def _spectral_pcc_baseline(
713713
edges = to_same_device(edges_cpu, prediction)
714714
bid = radial_bin_id(prediction.shape, edges, spacing=spacing_seq)
715715

716-
xp = get_array_module(prediction)
717716
n_bins_needed = int(asnumpy(bid[bid >= 0].max())) + 1 if np.any(bid >= 0) else 0
718717
if len(w_bins) < n_bins_needed:
719718
raise ValueError(
720719
f"frozen_weights has {len(w_bins)} bins but binning requires "
721720
f"{n_bins_needed}; check that bin_delta and image shape match."
722721
)
723-
w_bins_dev = xp.asarray(w_bins) if xp is not np else w_bins
722+
w_bins_dev = to_same_device(np.asarray(w_bins, dtype=np.float32), prediction)
724723

725724
W = np.zeros_like(bid, dtype=np.float32)
726725
valid = bid >= 0
@@ -751,13 +750,14 @@ def frc_weights(
751750
alpha: float = 2.0,
752751
nbins_low: int = 3,
753752
smooth_window: int = 5,
753+
split_type: str = "binomial",
754+
n_repeats: int = 3,
755+
rng: np.random.Generator | int | None = 42,
754756
) -> np.ndarray:
755757
"""Per-bin weights derived from single-image FRC reproducibility.
756758
757-
Splits the GT image via checkerboard, computes ring-wise FRC,
758-
and converts the FRC curve to monotone-decreasing weights that
759-
indicate per-frequency reliability. Operates entirely in
760-
index-frequency units.
759+
Uses binomial splitting (same-shape, full frequency coverage) by
760+
default, falling back to checkerboard if requested.
761761
762762
Parameters
763763
----------
@@ -766,13 +766,21 @@ def frc_weights(
766766
bin_delta : int
767767
Radial-bin width in index units.
768768
threshold : float
769-
FRC threshold (default 0.143 = 1/7).
769+
FRC threshold (default 0.143 = 1-bit).
770770
alpha : float
771771
Weight exponent (default 2.0 — sharpens the transition).
772772
nbins_low : int
773773
Number of lowest bins to zero (DC / background exclusion).
774774
smooth_window : int
775775
Median-filter window (clamped to odd, >= 3).
776+
split_type : str
777+
``"binomial"`` (default, same-shape, Rieger et al. 2024) or
778+
``"checkerboard"`` (subsampled halves, Koho et al. 2019).
779+
n_repeats : int
780+
Number of independent binomial splits to average (default 3).
781+
Ignored for checkerboard.
782+
rng : Generator, int, or None
783+
Random seed for binomial split reproducibility.
776784
777785
Returns
778786
-------
@@ -781,7 +789,7 @@ def frc_weights(
781789
binning matching ``radial_edges(image.shape, bin_delta,
782790
spacing=None)``).
783791
"""
784-
from scipy.ndimage import median_filter
792+
from cubic.scipy import ndimage as cndimage
785793

786794
# --- lazy import to avoid circular deps ---
787795
from .spectral.frc import calculate_frc as _calculate_frc
@@ -791,16 +799,26 @@ def frc_weights(
791799
if image.shape[0] != image.shape[1]:
792800
raise ValueError(f"frc_weights requires square images, got {image.shape}.")
793801

794-
# 1. FRC curve via existing public API (no spacing → index units)
795-
result = _calculate_frc(
796-
image,
802+
# 1. FRC curve via public API (no spacing → index units)
803+
frc_kwargs = dict(
797804
image2=None,
798805
backend="hist",
799806
bin_delta=bin_delta,
800807
zero_padding=False,
801808
disable_hamming=False,
802-
average=True,
809+
split_type=split_type,
803810
)
811+
if split_type == "binomial":
812+
frc_kwargs.update(
813+
counts_mode="poisson_thinning",
814+
n_repeats=n_repeats,
815+
rng=rng,
816+
average=False, # no checkerboard reverse averaging
817+
)
818+
else:
819+
frc_kwargs["average"] = True # checkerboard forward+reverse
820+
821+
result = _calculate_frc(image, **frc_kwargs)
804822
frc_curve = np.clip(
805823
np.asarray(result.correlation["correlation"], dtype=np.float64),
806824
-1.0,
@@ -820,16 +838,18 @@ def frc_weights(
820838
freq_nyq_full = float(np.floor(image.shape[0] / 2.0))
821839
freq_full_norm = radii_full_idx / freq_nyq_full
822840

823-
# FRC was computed on checkerboard halves (shape//2), so its [0,1]
824-
# normalised axis covers only the half-image Nyquist. Derive the
825-
# frequency scaling factor from actual image dimensions.
826-
freq_nyq_half = float(np.floor(image.shape[0] // 2 / 2.0))
827-
freq_scale = freq_nyq_full / freq_nyq_half # typically 2.0
828-
freq_full_in_half = np.clip(freq_full_norm * freq_scale, 0.0, 1.0)
841+
# 3. Map FRC frequency axis onto full-resolution bins
842+
if split_type == "checkerboard":
843+
# Checkerboard halves are shape//2 → FRC [0,1] covers half Nyquist
844+
freq_nyq_half = float(np.floor(image.shape[0] // 2 / 2.0))
845+
freq_scale = freq_nyq_full / freq_nyq_half
846+
interp_x = np.clip(freq_full_norm * freq_scale, 0.0, 1.0)
847+
else:
848+
# Binomial split: same shape → FRC and full bins share the same axis
849+
interp_x = freq_full_norm
829850

830-
# 3. Interpolate FRC onto full-resolution bins
831851
frc_full = np.interp(
832-
freq_full_in_half,
852+
interp_x,
833853
freq_norm,
834854
frc_curve,
835855
left=float(frc_curve[0]),
@@ -849,7 +869,7 @@ def frc_weights(
849869
# 5. Smooth + monotone non-increasing envelope
850870
sw = smooth_window | 1 # clamp to odd
851871
sw = max(3, min(sw, len(w) | 1))
852-
w = median_filter(w, size=sw).astype(np.float64)
872+
w = cndimage.median_filter(w, size=sw).astype(np.float64)
853873
w = np.maximum.accumulate(w[::-1])[::-1].copy()
854874

855875
# 6. Low-k exclusion
@@ -871,6 +891,9 @@ def spectral_pcc_frcw(
871891
alpha: float = 2.0,
872892
nbins_low: int = 3,
873893
smooth_window: int = 5,
894+
split_type: str = "binomial",
895+
n_repeats: int = 3,
896+
rng: np.random.Generator | int | None = 42,
874897
frozen_weights: np.ndarray | None = None,
875898
) -> float:
876899
"""Spectral PCC weighted by single-image FRC reproducibility.
@@ -894,6 +917,12 @@ def spectral_pcc_frcw(
894917
Number of lowest bins to exclude.
895918
smooth_window : int
896919
Median-filter window for weight smoothing.
920+
split_type : str
921+
``"binomial"`` or ``"checkerboard"`` — passed to :func:`frc_weights`.
922+
n_repeats : int
923+
Number of binomial splits to average.
924+
rng : Generator, int, or None
925+
Random seed for binomial split.
897926
frozen_weights : np.ndarray, optional
898927
Pre-computed 1-D radial-bin weights. If given, skip FRC weight
899928
estimation. Must match index-unit binning for the target shape.
@@ -920,6 +949,9 @@ def spectral_pcc_frcw(
920949
alpha=alpha,
921950
nbins_low=nbins_low,
922951
smooth_window=smooth_window,
952+
split_type=split_type,
953+
n_repeats=n_repeats,
954+
rng=rng,
923955
)
924956

925957
# 2. Zero-weight-mass guard
@@ -950,14 +982,13 @@ def spectral_pcc_frcw(
950982
edges = to_same_device(edges_cpu, prediction)
951983
bid = radial_bin_id(prediction.shape, edges, spacing=None)
952984

953-
xp = get_array_module(prediction)
954985
n_bins_needed = int(asnumpy(bid[bid >= 0].max())) + 1 if np.any(bid >= 0) else 0
955986
if len(w_bins) < n_bins_needed:
956987
raise ValueError(
957988
f"frozen_weights has {len(w_bins)} bins but binning requires "
958989
f"{n_bins_needed}; check that bin_delta and image shape match."
959990
)
960-
w_bins_dev = xp.asarray(w_bins) if xp is not np else w_bins
991+
w_bins_dev = to_same_device(np.asarray(w_bins, dtype=np.float32), prediction)
961992

962993
W = np.zeros_like(bid, dtype=np.float32)
963994
valid = bid >= 0
@@ -1332,14 +1363,13 @@ def spectral_pcc(
13321363
edges = to_same_device(edges_cpu, prediction)
13331364
bid = radial_bin_id(prediction.shape, edges, spacing=spacing_seq)
13341365

1335-
xp = get_array_module(prediction)
13361366
n_bins_needed = int(asnumpy(bid[bid >= 0].max())) + 1 if np.any(bid >= 0) else 0
13371367
if len(w_bins) < n_bins_needed:
13381368
raise ValueError(
13391369
f"Weight vector has {len(w_bins)} bins but binning requires "
13401370
f"{n_bins_needed}; check that bin_delta and image shape match."
13411371
)
1342-
w_bins_dev = xp.asarray(w_bins) if xp is not np else w_bins
1372+
w_bins_dev = to_same_device(np.asarray(w_bins, dtype=np.float32), prediction)
13431373

13441374
# Build weight volume: map bin weights through bin_id
13451375
W = np.zeros_like(bid, dtype=np.float32)

cubic/metrics/spectral/frc.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,8 @@ def calculate_frc(
450450
offset: Camera offset (ADU) for binomial counts mode.
451451
readout_noise_rms: Read-noise std in electrons for binomial counts mode.
452452
n_repeats: Number of independent binomial splits to average. Only used
453-
when split_type="binomial". Produces correlation-std and
453+
when split_type="binomial" and single-image mode (image2 is
454+
None). Ignored otherwise. Produces correlation-std and
454455
resolution-std in the result. Default: 1.
455456
rng: Random number generator, seed, or None for binomial split.
456457
"""
@@ -459,6 +460,14 @@ def calculate_frc(
459460

460461
use_binomial = split_type == "binomial" and single_image
461462

463+
if n_repeats > 1 and not use_binomial:
464+
warnings.warn(
465+
f"n_repeats={n_repeats} ignored: only used with "
466+
f"split_type='binomial' and single-image mode.",
467+
UserWarning,
468+
stacklevel=2,
469+
)
470+
462471
if use_binomial and n_repeats == 1:
463472
logger.info(_BINOMIAL_SINGLE_REPEAT_MSG)
464473

@@ -1274,7 +1283,8 @@ def fsc_resolution(
12741283
offset: Camera offset (ADU) for binomial counts mode.
12751284
readout_noise_rms: Read-noise std in electrons for binomial counts mode.
12761285
n_repeats: Number of independent binomial splits to average. Only used
1277-
when split_type="binomial". Default: 1.
1286+
when split_type="binomial" and single-image mode (image2 is
1287+
None). Ignored otherwise. Default: 1.
12781288
rng: Random number generator, seed, or None for binomial split.
12791289
12801290
Returns
@@ -1296,6 +1306,14 @@ def fsc_resolution(
12961306
single_image = image2 is None
12971307
use_binomial = split_type == "binomial" and single_image
12981308

1309+
if n_repeats > 1 and not use_binomial:
1310+
warnings.warn(
1311+
f"n_repeats={n_repeats} ignored: only used with "
1312+
f"split_type='binomial' and single-image mode.",
1313+
UserWarning,
1314+
stacklevel=2,
1315+
)
1316+
12991317
# --- Isotropic resampling (optional) ---
13001318
original_spacing_z = None
13011319
z_factor = 1.0
@@ -1345,7 +1363,11 @@ def fsc_resolution(
13451363
xy_res = 0.5 * (
13461364
angle_to_resolution.get(90, np.nan) + angle_to_resolution.get(270, np.nan)
13471365
)
1348-
return {"xy": xy_res, "z": z_res}
1366+
result = {"xy": xy_res, "z": z_res}
1367+
if use_binomial:
1368+
result["xy_std"] = 0.0
1369+
result["z_std"] = 0.0
1370+
return result
13491371

13501372
# --- Hist backend ---
13511373
spacing_list = _normalize_spacing(spacing, image1.ndim)

tests/metrics/test_bandlimited.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,38 @@ def test_spectral_pcc_shape_mismatch_raises() -> None:
417417
spectral_pcc(a, b, spacing=0.065)
418418

419419

420+
def test_spectral_pcc_smooth() -> None:
421+
"""Spectral PCC with SG-smoothed weights returns valid result."""
422+
pred, tgt, _ = _make_synthetic_pair(noise_sigma=0.5, seed=7)
423+
r = spectral_pcc(pred, tgt, spacing=0.065, smooth=True)
424+
assert -1.0 <= r <= 1.0
425+
426+
427+
def test_spectral_pcc_smooth_identical() -> None:
428+
"""Smooth-weighted spectral PCC of identical images is ~1."""
429+
rng = np.random.default_rng(0)
430+
img = rng.standard_normal((64, 64)).astype(np.float32)
431+
r = spectral_pcc(img, img, spacing=0.065, smooth=True)
432+
assert r == pytest.approx(1.0, abs=1e-3)
433+
434+
435+
def test_spectral_pcc_nbins_low() -> None:
436+
"""nbins_low excludes DC bins without crashing."""
437+
pred, tgt, _ = _make_synthetic_pair(noise_sigma=0.5, seed=8)
438+
r0 = spectral_pcc(pred, tgt, spacing=0.065, nbins_low=0)
439+
r3 = spectral_pcc(pred, tgt, spacing=0.065, nbins_low=3)
440+
# Both should be valid floats; excluding DC may change the value
441+
assert -1.0 <= r0 <= 1.0
442+
assert -1.0 <= r3 <= 1.0
443+
444+
445+
def test_spectral_pcc_smooth_with_nbins_low() -> None:
446+
"""Smooth + nbins_low combined work correctly."""
447+
pred, tgt, _ = _make_synthetic_pair(noise_sigma=0.5, seed=9)
448+
r = spectral_pcc(pred, tgt, spacing=0.065, smooth=True, nbins_low=3)
449+
assert -1.0 <= r <= 1.0
450+
451+
420452
# ===================================================================
421453
# GPU / CPU parity tests
422454
# ===================================================================

0 commit comments

Comments
 (0)