3030
3131import numpy as np
3232
33- from cubic .cuda import asnumpy , to_same_device , get_array_module , check_same_device
33+ from cubic .cuda import asnumpy , to_same_device , check_same_device
3434from cubic .image_utils import tukey_window , hamming_window
3535
3636from .spectral .dcr import dcr_resolution
@@ -432,7 +432,7 @@ def smooth_spectral_weights(
432432 np.ndarray
433433 Weights in [0, 1].
434434 """
435- from scipy . signal import savgol_filter
435+ from cubic . scipy import signal as csignal
436436
437437 log_p = np .log (np .maximum (power , 1e-30 ))
438438 n = len (log_p )
@@ -446,7 +446,7 @@ def smooth_spectral_weights(
446446 if wlen % 2 == 0 :
447447 wlen -= 1
448448 wlen = max (wlen , min_wlen )
449- log_p_smooth = savgol_filter (log_p , wlen , sg_polyorder )
449+ log_p_smooth = csignal . savgol_filter (log_p , wlen , sg_polyorder )
450450 p_smooth = np .exp (log_p_smooth )
451451
452452 n2 = noise_floor ** 2
@@ -525,7 +525,7 @@ def _estimate_noise_baseline(
525525 noise_baseline : np.ndarray
526526 Frequency-dependent noise floor N(k), same length as *power*.
527527 """
528- from scipy . signal import savgol_filter
528+ from cubic . scipy import signal as csignal
529529
530530 n = len (power )
531531
@@ -541,7 +541,7 @@ def _estimate_noise_baseline(
541541 wlen = max (wlen , min_wlen )
542542
543543 log_p = np .log (np .maximum (power , 1e-30 ))
544- log_p_smooth = savgol_filter (log_p , wlen , sg_polyorder )
544+ log_p_smooth = csignal . savgol_filter (log_p , wlen , sg_polyorder )
545545
546546 # Running low-quantile of smoothed log-power
547547 log_n = _running_quantile_1d (log_p_smooth , window = quantile_window , q = quantile )
@@ -713,14 +713,13 @@ def _spectral_pcc_baseline(
713713 edges = to_same_device (edges_cpu , prediction )
714714 bid = radial_bin_id (prediction .shape , edges , spacing = spacing_seq )
715715
716- xp = get_array_module (prediction )
717716 n_bins_needed = int (asnumpy (bid [bid >= 0 ].max ())) + 1 if np .any (bid >= 0 ) else 0
718717 if len (w_bins ) < n_bins_needed :
719718 raise ValueError (
720719 f"frozen_weights has { len (w_bins )} bins but binning requires "
721720 f"{ n_bins_needed } ; check that bin_delta and image shape match."
722721 )
723- w_bins_dev = xp .asarray (w_bins ) if xp is not np else w_bins
722+ w_bins_dev = to_same_device ( np .asarray (w_bins , dtype = np . float32 ), prediction )
724723
725724 W = np .zeros_like (bid , dtype = np .float32 )
726725 valid = bid >= 0
@@ -751,13 +750,14 @@ def frc_weights(
751750 alpha : float = 2.0 ,
752751 nbins_low : int = 3 ,
753752 smooth_window : int = 5 ,
753+ split_type : str = "binomial" ,
754+ n_repeats : int = 3 ,
755+ rng : np .random .Generator | int | None = 42 ,
754756) -> np .ndarray :
755757 """Per-bin weights derived from single-image FRC reproducibility.
756758
757- Splits the GT image via checkerboard, computes ring-wise FRC,
758- and converts the FRC curve to monotone-decreasing weights that
759- indicate per-frequency reliability. Operates entirely in
760- index-frequency units.
759+ Uses binomial splitting (same-shape, full frequency coverage) by
760+ default, falling back to checkerboard if requested.
761761
762762 Parameters
763763 ----------
@@ -766,13 +766,21 @@ def frc_weights(
766766 bin_delta : int
767767 Radial-bin width in index units.
768768 threshold : float
769- FRC threshold (default 0.143 = 1/7 ).
769+ FRC threshold (default 0.143 = 1-bit ).
770770 alpha : float
771771 Weight exponent (default 2.0 — sharpens the transition).
772772 nbins_low : int
773773 Number of lowest bins to zero (DC / background exclusion).
774774 smooth_window : int
775775 Median-filter window (clamped to odd, >= 3).
776+ split_type : str
777+ ``"binomial"`` (default, same-shape, Rieger et al. 2024) or
778+ ``"checkerboard"`` (subsampled halves, Koho et al. 2019).
779+ n_repeats : int
780+ Number of independent binomial splits to average (default 3).
781+ Ignored for checkerboard.
782+ rng : Generator, int, or None
783+ Random seed for binomial split reproducibility.
776784
777785 Returns
778786 -------
@@ -781,7 +789,7 @@ def frc_weights(
781789 binning matching ``radial_edges(image.shape, bin_delta,
782790 spacing=None)``).
783791 """
784- from scipy . ndimage import median_filter
792+ from cubic . scipy import ndimage as cndimage
785793
786794 # --- lazy import to avoid circular deps ---
787795 from .spectral .frc import calculate_frc as _calculate_frc
@@ -791,16 +799,26 @@ def frc_weights(
791799 if image .shape [0 ] != image .shape [1 ]:
792800 raise ValueError (f"frc_weights requires square images, got { image .shape } ." )
793801
794- # 1. FRC curve via existing public API (no spacing → index units)
795- result = _calculate_frc (
796- image ,
802+ # 1. FRC curve via public API (no spacing → index units)
803+ frc_kwargs = dict (
797804 image2 = None ,
798805 backend = "hist" ,
799806 bin_delta = bin_delta ,
800807 zero_padding = False ,
801808 disable_hamming = False ,
802- average = True ,
809+ split_type = split_type ,
803810 )
811+ if split_type == "binomial" :
812+ frc_kwargs .update (
813+ counts_mode = "poisson_thinning" ,
814+ n_repeats = n_repeats ,
815+ rng = rng ,
816+ average = False , # no checkerboard reverse averaging
817+ )
818+ else :
819+ frc_kwargs ["average" ] = True # checkerboard forward+reverse
820+
821+ result = _calculate_frc (image , ** frc_kwargs )
804822 frc_curve = np .clip (
805823 np .asarray (result .correlation ["correlation" ], dtype = np .float64 ),
806824 - 1.0 ,
@@ -820,16 +838,18 @@ def frc_weights(
820838 freq_nyq_full = float (np .floor (image .shape [0 ] / 2.0 ))
821839 freq_full_norm = radii_full_idx / freq_nyq_full
822840
823- # FRC was computed on checkerboard halves (shape//2), so its [0,1]
824- # normalised axis covers only the half-image Nyquist. Derive the
825- # frequency scaling factor from actual image dimensions.
826- freq_nyq_half = float (np .floor (image .shape [0 ] // 2 / 2.0 ))
827- freq_scale = freq_nyq_full / freq_nyq_half # typically 2.0
828- freq_full_in_half = np .clip (freq_full_norm * freq_scale , 0.0 , 1.0 )
841+ # 3. Map FRC frequency axis onto full-resolution bins
842+ if split_type == "checkerboard" :
843+ # Checkerboard halves are shape//2 → FRC [0,1] covers half Nyquist
844+ freq_nyq_half = float (np .floor (image .shape [0 ] // 2 / 2.0 ))
845+ freq_scale = freq_nyq_full / freq_nyq_half
846+ interp_x = np .clip (freq_full_norm * freq_scale , 0.0 , 1.0 )
847+ else :
848+ # Binomial split: same shape → FRC and full bins share the same axis
849+ interp_x = freq_full_norm
829850
830- # 3. Interpolate FRC onto full-resolution bins
831851 frc_full = np .interp (
832- freq_full_in_half ,
852+ interp_x ,
833853 freq_norm ,
834854 frc_curve ,
835855 left = float (frc_curve [0 ]),
@@ -849,7 +869,7 @@ def frc_weights(
849869 # 5. Smooth + monotone non-increasing envelope
850870 sw = smooth_window | 1 # clamp to odd
851871 sw = max (3 , min (sw , len (w ) | 1 ))
852- w = median_filter (w , size = sw ).astype (np .float64 )
872+ w = cndimage . median_filter (w , size = sw ).astype (np .float64 )
853873 w = np .maximum .accumulate (w [::- 1 ])[::- 1 ].copy ()
854874
855875 # 6. Low-k exclusion
@@ -871,6 +891,9 @@ def spectral_pcc_frcw(
871891 alpha : float = 2.0 ,
872892 nbins_low : int = 3 ,
873893 smooth_window : int = 5 ,
894+ split_type : str = "binomial" ,
895+ n_repeats : int = 3 ,
896+ rng : np .random .Generator | int | None = 42 ,
874897 frozen_weights : np .ndarray | None = None ,
875898) -> float :
876899 """Spectral PCC weighted by single-image FRC reproducibility.
@@ -894,6 +917,12 @@ def spectral_pcc_frcw(
894917 Number of lowest bins to exclude.
895918 smooth_window : int
896919 Median-filter window for weight smoothing.
920+ split_type : str
921+ ``"binomial"`` or ``"checkerboard"`` — passed to :func:`frc_weights`.
922+ n_repeats : int
923+ Number of binomial splits to average.
924+ rng : Generator, int, or None
925+ Random seed for binomial split.
897926 frozen_weights : np.ndarray, optional
898927 Pre-computed 1-D radial-bin weights. If given, skip FRC weight
899928 estimation. Must match index-unit binning for the target shape.
@@ -920,6 +949,9 @@ def spectral_pcc_frcw(
920949 alpha = alpha ,
921950 nbins_low = nbins_low ,
922951 smooth_window = smooth_window ,
952+ split_type = split_type ,
953+ n_repeats = n_repeats ,
954+ rng = rng ,
923955 )
924956
925957 # 2. Zero-weight-mass guard
@@ -950,14 +982,13 @@ def spectral_pcc_frcw(
950982 edges = to_same_device (edges_cpu , prediction )
951983 bid = radial_bin_id (prediction .shape , edges , spacing = None )
952984
953- xp = get_array_module (prediction )
954985 n_bins_needed = int (asnumpy (bid [bid >= 0 ].max ())) + 1 if np .any (bid >= 0 ) else 0
955986 if len (w_bins ) < n_bins_needed :
956987 raise ValueError (
957988 f"frozen_weights has { len (w_bins )} bins but binning requires "
958989 f"{ n_bins_needed } ; check that bin_delta and image shape match."
959990 )
960- w_bins_dev = xp .asarray (w_bins ) if xp is not np else w_bins
991+ w_bins_dev = to_same_device ( np .asarray (w_bins , dtype = np . float32 ), prediction )
961992
962993 W = np .zeros_like (bid , dtype = np .float32 )
963994 valid = bid >= 0
@@ -1332,14 +1363,13 @@ def spectral_pcc(
13321363 edges = to_same_device (edges_cpu , prediction )
13331364 bid = radial_bin_id (prediction .shape , edges , spacing = spacing_seq )
13341365
1335- xp = get_array_module (prediction )
13361366 n_bins_needed = int (asnumpy (bid [bid >= 0 ].max ())) + 1 if np .any (bid >= 0 ) else 0
13371367 if len (w_bins ) < n_bins_needed :
13381368 raise ValueError (
13391369 f"Weight vector has { len (w_bins )} bins but binning requires "
13401370 f"{ n_bins_needed } ; check that bin_delta and image shape match."
13411371 )
1342- w_bins_dev = xp .asarray (w_bins ) if xp is not np else w_bins
1372+ w_bins_dev = to_same_device ( np .asarray (w_bins , dtype = np . float32 ), prediction )
13431373
13441374 # Build weight volume: map bin weights through bin_id
13451375 W = np .zeros_like (bid , dtype = np .float32 )
0 commit comments