44
55from mrinufft .density .utils import flat_traj
66from mrinufft ._utils import get_array_module
7+ from mrinufft ._array_compat import with_numpy_cupy
78from .utils import register_smaps
89import numpy as np
10+ from numpy .typing import NDArray
11+
12+ from collections .abc import Callable
913
1014
1115def _extract_kspace_center (
12- kspace_data ,
13- kspace_loc ,
14- threshold = None ,
15- density = None ,
16- window_fun = "ellipse" ,
17- ):
16+ kspace_data : NDArray ,
17+ kspace_loc : NDArray ,
18+ threshold : float | tuple [ float , ...] = None ,
19+ density : NDArray | None = None ,
20+ window_fun : str | Callable [[ NDArray ], NDArray ] = "ellipse" ,
21+ ) -> tuple [ NDArray , NDArray , NDArray | None ] :
1822 r"""Extract k-space center and corresponding sampling locations.
1923
2024 The extracted center of the k-space, i.e. both the kspace locations and
@@ -81,7 +85,7 @@ def _extract_kspace_center(
8185 return data_thresholded , center_locations , dc
8286 else :
8387 if callable (window_fun ):
84- window = window_fun (center_locations )
88+ window = window_fun (kspace_loc )
8589 else :
8690 if window_fun in ["hann" , "hanning" , "hamming" ]:
8791 radius = xp .linalg .norm (kspace_loc , axis = 1 )
@@ -99,16 +103,16 @@ def _extract_kspace_center(
99103@register_smaps
100104@flat_traj
101105def low_frequency (
102- traj ,
103- shape ,
104- kspace_data ,
105- backend ,
106+ traj : NDArray ,
107+ shape : tuple [ int , ...] ,
108+ kspace_data : NDArray ,
109+ backend : str ,
106110 threshold : float | tuple [float , ...] = 0.1 ,
107- density = None ,
108- window_fun : str = "ellipse" ,
111+ density : NDArray | None = None ,
112+ window_fun : str | Callable [[ NDArray ], NDArray ] = "ellipse" ,
109113 blurr_factor : int | float | tuple [float , ...] = 0.0 ,
110114 mask : bool = False ,
111- ):
115+ ) -> tuple [ NDArray , NDArray ] :
112116 """
113117 Calculate low-frequency sensitivity maps.
114118
@@ -190,3 +194,61 @@ def low_frequency(
190194 SOS = np .linalg .norm (Smaps , axis = 0 ) + 1e-10
191195 Smaps = Smaps / SOS
192196 return Smaps , SOS
197+
198+
199+ @with_numpy_cupy
200+ def coil_compression (
201+ kspace_data : NDArray ,
202+ K : int | float ,
203+ traj : NDArray | None = None ,
204+ krad_thresh : float | None = None ,
205+ ) -> NDArray :
206+ """
207+ Coil compression using principal component analysis on k-space data.
208+
209+ Parameters
210+ ----------
211+ kspace_data : NDArray
212+ Multi-coil k-space data. Shape: (n_coils, n_samples).
213+ K : int or float
214+ Number of virtual coils to retain (if int), or energy threshold (if
215+ float between 0 and 1).
216+ traj : NDArray, optional
217+ Sampling trajectory. Shape: (n_samples, n_dims).
218+ krad_thresh : float, optional
219+ Relative k-space radius (as a fraction of maximum) to use for selecting
220+ the calibration region for principal component analysis. If None, use
221+ all k-space samples.
222+
223+ Returns
224+ -------
225+ NDArray
226+ Coil-compressed data. Shape: (K, n_samples) if K is int, number of
227+ retained components otherwise.
228+ """
229+ xp = get_array_module (kspace_data )
230+
231+ if krad_thresh is not None and traj is not None :
232+ traj_rad = xp .sqrt (xp .sum (traj ** 2 , axis = - 1 ))
233+ center_data = kspace_data [:, traj_rad < krad_thresh * xp .max (traj )]
234+ elif krad_thresh is None :
235+ center_data = kspace_data
236+ else :
237+ raise ValueError ("traj and krad_thresh must be specified." )
238+
239+ # Compute the covar matrix of selected data
240+ cov = center_data @ center_data .T .conj ()
241+ w , v = xp .linalg .eigh (cov )
242+ # sort eigenvalues largest to smallest
243+ si = xp .argsort (w )[::- 1 ]
244+ w_sorted = w [si ]
245+ v_sorted = v [si ]
246+ if isinstance (K , float ):
247+ # retain enough components to reach energy K
248+ w_cumsum = xp .cumsum (w_sorted ) # from largest to smallest
249+ total_energy = xp .sum (w_sorted )
250+ K = int (xp .searchsorted (w_cumsum / total_energy , K , side = "left" ) + 1 )
251+ K = min (K , w_sorted .size )
252+ V = v_sorted [:K ] # use top K component
253+ compress_data = V @ kspace_data
254+ return compress_data
0 commit comments