|
51 | 51 | DEFAULT_LOWB_THRESHOLD = 50 |
52 | 52 | """The lower bound for the b-value so that the orientation is considered a DW volume.""" |
53 | 53 |
|
| 54 | +DEFAULT_GRADIENT_EPS = 1e-8 |
| 55 | +"""Epsilon value for b-vector normalization.""" |
| 56 | + |
54 | 57 | DEFAULT_HIGHB_THRESHOLD = 8000 |
55 | 58 | """A b-value cap for DWI data.""" |
56 | 59 |
|
@@ -180,7 +183,18 @@ def format_gradients( |
180 | 183 | formatted = formatted.astype(int) |
181 | 184 |
|
182 | 185 | # Transpose if column-major |
183 | | - return formatted.T if formatted.shape[0] == 4 and formatted.shape[1] != 4 else formatted |
| 186 | + formatted = formatted.T if formatted.shape[0] == 4 and formatted.shape[1] != 4 else formatted |
| 187 | + |
| 188 | + # Normalize b-vectors in-place |
| 189 | + bvecs = formatted[:, :3] |
| 190 | + norms = np.linalg.norm(bvecs, axis=1) |
| 191 | + mask = norms > DEFAULT_GRADIENT_EPS |
| 192 | + if np.any(mask): |
| 193 | + formatted[mask, :3] = bvecs[mask] / norms[mask, None] # Norm b-vectors |
| 194 | + formatted[mask, 3] *= norms[mask] # Scale b-values by norm |
| 195 | + formatted[~mask, :] = 0.0 # Zero-out small b-vectors |
| 196 | + |
| 197 | + return formatted |
184 | 198 |
|
185 | 199 |
|
186 | 200 | def validate_gradients( |
@@ -652,3 +666,81 @@ def transform_fsl_bvec( |
652 | 666 | ijk2ijk_xfm = np.linalg.inv(imaffine) @ xfm @ imaffine |
653 | 667 |
|
654 | 668 | return ijk2ijk_xfm[:3, :3] @ b_ijk[:3] |
| 669 | + |
| 670 | + |
| 671 | +def normalize_gradients(value: np.ndarray, eps: float = 1e-8, copy: bool = True) -> np.ndarray: |
| 672 | + """Normalize b-vectors in arrays of common shapes. |
| 673 | +
|
| 674 | + Parameters |
| 675 | + ---------- |
| 676 | + value : :obj:`~numpy.ndarray` |
| 677 | + Input array with shape one of: |
| 678 | + - (N, 3) : rows are b-vector components (e.g., [gx gy gz]) |
| 679 | + - (N, 4) : first 3 columns are b-vector components (e.g., [gx gy gz b]) |
| 680 | + - (3, N) : columns are b-vector components (e.g., [gx gy gz].T) |
| 681 | + - (4, N) : first 3 rows are b-vector components (e.g., [gx gy gz b].T) |
| 682 | + - (3,) or (1,3) or (3,1) : single b-vector |
| 683 | + Columns are checked first to disambiguate Nx3/Nx4 cases. |
| 684 | + eps : float, optional |
| 685 | + Threshold below which a vector is considered zero and left unchanged. |
| 686 | + copy : bool, optional |
| 687 | + If ``True``, returns a new array; modify in-place otherwise. |
| 688 | +
|
| 689 | + Returns |
| 690 | + ------- |
| 691 | + out : :obj:`~numpy.ndarray` |
| 692 | + Array with the same shape as ``value`` with each 3-component b-vector |
| 693 | + normalized. |
| 694 | + """ |
| 695 | + arr = np.asarray(value, dtype=float) |
| 696 | + |
| 697 | + # 1D single vector |
| 698 | + if arr.ndim == 1: |
| 699 | + if arr.size != 3: |
| 700 | + raise ValueError(GRADIENT_NORMALIZATION_LENGTH_ERROR_MSG) |
| 701 | + norm = np.linalg.norm(arr) |
| 702 | + if norm > eps: |
| 703 | + if copy: |
| 704 | + return arr / norm |
| 705 | + else: |
| 706 | + # Perform in-place normalization on the array view |
| 707 | + arr[:] = arr / norm |
| 708 | + return arr |
| 709 | + else: |
| 710 | + return arr.copy() if copy else arr |
| 711 | + |
| 712 | + if arr.ndim != 2: |
| 713 | + raise ValueError(GRADIENT_NORMALIZATION_SHAPE_ERROR_MSG) |
| 714 | + |
| 715 | + rows, cols = arr.shape |
| 716 | + |
| 717 | + # Prepare output (copy or in-place) |
| 718 | + normalized_arr = arr.copy() if copy else arr |
| 719 | + |
| 720 | + # Determine where the 3-component vectors live and create a (N, 3) view |
| 721 | + # Check columns first to make Nx3/Nx4 deterministic |
| 722 | + if cols == 4: |
| 723 | + # Nx4: first 3 columns are b-vectors components, last are b-values |
| 724 | + vecs = normalized_arr[:, :3] # shape (N, 3) |
| 725 | + elif cols == 3: |
| 726 | + # Nx3: rows are vectors |
| 727 | + vecs = normalized_arr # shape (N, 3) |
| 728 | + elif rows == 4: |
| 729 | + # 4xN: first 3 rows are b-vector components, last row are b-values |
| 730 | + # Create a (N, 3) view by transposing first 3 rows |
| 731 | + vecs = normalized_arr[:3, :].T # shape (N, 3) |
| 732 | + elif rows == 3: |
| 733 | + # 3xN: columns are vectors: normalize per-column |
| 734 | + vecs = normalized_arr.T # shape (N, 3) |
| 735 | + else: |
| 736 | + raise ValueError( |
| 737 | + GRADIENT_NORMALIZATION_UNRECOGNIZED_SHAPE_ERROR_MSG.format(shape=arr.shape) |
| 738 | + ) |
| 739 | + |
| 740 | + # Normalize in-place on vecs (which is a view into output) |
| 741 | + norms = np.linalg.norm(vecs, axis=1) |
| 742 | + mask = norms > eps |
| 743 | + if np.any(mask): |
| 744 | + vecs[mask] = vecs[mask] / norms[mask, None] |
| 745 | + |
| 746 | + return normalized_arr |
0 commit comments