diff --git a/src/open_dive/scripts/run.py b/src/open_dive/scripts/run.py index 0e3289f..deb6588 100644 --- a/src/open_dive/scripts/run.py +++ b/src/open_dive/scripts/run.py @@ -20,6 +20,7 @@ def main(): "Diffusion glyph options (tensors and ODFs)" ) window_group = parser.add_argument_group("Window options") + roi_mask_group = parser.add_argument_group("ROI mask options") scalar_group.add_argument( "-n", @@ -65,6 +66,12 @@ def main(): type=Path, help="Path to binary mask to generate a glass brain.", ) + scalar_group.add_argument( + "--glass_brain_opacity", + type=float, + default=0.33, + help="Opacity of the glass brain in range (0, 1]. Default is 0.33.", + ) tractography_group.add_argument( ## plot tractogram with slices @@ -101,6 +108,58 @@ def main(): action="store_true", help="Whether to show a tractography values colorbar. Default is False.", ) + tractography_group.add_argument( + "--tractography_is_categorical_values", + action="store_true", + help="If provided, the tractography values are treated as categorical values. Default is False.", + ) + tractography_group.add_argument( + "--tractography_categorical_values", + type=str, + nargs="+", + help="List of categorical values for tractography. Must match number of tractography files.", + ) + tractography_group.add_argument( + "--tractography_categorical_reference_values", + type=str, + nargs="+", + help="List of reference categorical values for tractography.", + ) + + roi_mask_group.add_argument( + "--roi_mask_path", + type=Path, + nargs="+", # Accept one or more arguments + help="Path to binary mask(s) to plot as ROI mask(s).", + ) + roi_mask_group.add_argument( + "--roi_mask_values", + type=float, + nargs="+", + help="Values to use for coloring each ROI mask (must match number of ROI mask files)", + ) + roi_mask_group.add_argument( + "--roi_mask_cmap", + help='Matplotlib or cmcrameri colormap to use for ROI mask. Default is "plasma" if --roi_mask_values is provided, otherwise "Set1".', + ) + roi_mask_group.add_argument( + "--roi_mask_cmap_range", + type=float, + nargs=2, + help="Range to use for the colormap. Default is (0, 1).", + ) + roi_mask_group.add_argument( + "--roi_mask_opacity", + type=float, + nargs="+", + default=[0.33], + help="Value to use for the ROI mask opacity in range (0, 1]. If a list, each value corresponds to a ROI mask in --roi_mask_path. Default is 0.33.", + ) + roi_mask_group.add_argument( + "--roi_mask_colorbar", + action="store_true", + help="Whether to show a ROI mask values colorbar. Default is False.", + ) glyph_group.add_argument( "--tensor_path", @@ -177,6 +236,9 @@ def main(): tractography_cmap=args.tractography_cmap, tractography_cmap_range=args.tractography_cmap_range, tractography_colorbar=args.tractography_colorbar, + tractography_is_categorical_values=args.tractography_is_categorical_values, + tractography_categorical_values=args.tractography_categorical_values, + tractography_categorical_reference_values=args.tractography_categorical_reference_values, tensor_path=args.tensor_path, odf_path=args.odf_path, sh_basis=args.sh_basis, @@ -184,4 +246,11 @@ def main(): azimuth=args.azimuth, elevation=args.elevation, glass_brain_path=args.glass_brain, + glass_brain_opacity=args.glass_brain_opacity, + roi_mask_path=args.roi_mask_path, + roi_mask_values=args.roi_mask_values, + roi_mask_cmap=args.roi_mask_cmap, + roi_mask_cmap_range=args.roi_mask_cmap_range, + roi_mask_opacity=args.roi_mask_opacity, + roi_mask_colorbar=args.roi_mask_colorbar ) diff --git a/src/open_dive/viz.py b/src/open_dive/viz.py index fac22ff..5f31bf3 100644 --- a/src/open_dive/viz.py +++ b/src/open_dive/viz.py @@ -43,12 +43,22 @@ def plot_nifti( tractography_cmap: str | None = None, tractography_cmap_range: tuple[int, int] | None = None, tractography_colorbar: bool = False, + tractography_is_categorical_values: bool = False, + tractography_categorical_values: list[str] | None = None, + tractography_categorical_reference_values: list[str] | None = None, volume_idx: int | None = None, tensor_path: os.PathLike | None = None, odf_path: os.PathLike | None = None, sh_basis: str = "descoteaux07", scale: int = 1, glass_brain_path: os.PathLike | None = None, + glass_brain_opacity: float | None = 0.33, + roi_mask_cmap: str | None = None, + roi_mask_cmap_range: tuple[int, int] | None = None, + roi_mask_values: list[float] | None = None, + roi_mask_colorbar: bool = False, + roi_mask_path: list[os.PathLike] | None = None, + roi_mask_opacity: list[float] | None = [0.33], **kwargs, ) -> None: """Create a 2D rendering of a NIFTI slice. @@ -87,6 +97,12 @@ def plot_nifti( Optional range to use for the colormap tractography_colorbar : bool, default False Whether to show a colorbar for the tractography + tractography_is_categorical_values : bool, default False + Whether the tractography values are categorical (discrete) or continuous + tractography_categorical_values : list of str, optional + Optional categorical values to color the tractography with + tractography_categorical_reference_values : list of str, optional + Optional reference values for categorical tractography values, used to map colors volume_idx : int, optional Index of the volume to display if the image is 4D tensor_path : os.PathLike, optional @@ -99,7 +115,20 @@ def plot_nifti( Scale of the tensor glyphs or ODF glyphs glass_brain_path : os.PathLike, optional Optional glass brain mask to overlay - + glass_brain_opacity : float, default 0.33 + Opacity of the glass brain mask + roi_mask_cmap : str, default "Set1" or "plasma" + Optional colormap to use for the ROI mask, by default "Set1" if roi_mask_values, otherwise "plasma" + roi_mask_cmap_range : tuple of float, default (0, 1) if roi_mask_values is not None + Optional range to use for the colormap + roi_mask_values : list of float, optional + Optional values to color the ROI mask with + roi_mask_colorbar : bool, default False + Whether to show a colorbar for the ROI mask + roi_mask_path : list of os.PathLike, optional + Optional ROI mask(s) to plot with slices. Can provide multiple files + roi_mask_opacity : list of float, default [0.33] + Optional opacity value for the ROI masks between (0, 1) **kwargs Additional keyword arguments to pass to fury.actor.slicer """ @@ -122,10 +151,30 @@ def plot_nifti( if tractography_cmap is None: tractography_cmap = "Set1" if tractography_values is None else "plasma" if tractography_cmap_range is None: - tractography_cmap_range = ( - (0, 1) if tractography_values is None else (min(tractography_values), max(tractography_values)) - ) - tractography_cbar_labels = tractography_values is not None + if not tractography_is_categorical_values: + tractography_cmap_range = ( + (0, 1) if tractography_values is None else (min(tractography_values), max(tractography_values)) + ) + tractography_cbar_labels = tractography_values is not None + else: + if tractography_categorical_reference_values is not None: + tractography_unique_values = list(dict.fromkeys(tractography_categorical_reference_values)) + print(tractography_categorical_reference_values) + assert all(val in tractography_unique_values for val in tractography_categorical_values), f"All categorical values must be in the reference values: reference: {tractography_categorical_reference_values} \nvalues: {tractography_categorical_values}" + else: + tractography_unique_values = np.unique(tractography_categorical_values) if tractography_categorical_values is not None else None + tractography_cmap_range = (0, len(tractography_unique_values) - 1) if tractography_unique_values is not None else (0, 1) + tractography_cbar_labels = tractography_categorical_reference_values is not None + print(tractography_unique_values) + + #same for ROI mask values + if roi_mask_cmap is None: + roi_mask_cmap = "Set1" if roi_mask_values is None else "plasma" + if roi_mask_cmap_range is None: + roi_mask_cmap_range = ( + (0, 1) if roi_mask_values is None else (min(roi_mask_values), max(roi_mask_values)) + ) + roi_mask_cbar_labels = roi_mask_values is not None # Set up scene and bounds scene = window.Scene() @@ -177,14 +226,58 @@ def plot_nifti( ) scene.add(scalar_bar) + #add roi masks + if roi_mask_path is not None: + cmap = plt.get_cmap(roi_mask_cmap) + + # Set to range + if roi_mask_values is not None: + norm = plt.Normalize(vmin=roi_mask_cmap_range[0], vmax=roi_mask_cmap_range[1]) + colors = [cmap(norm(val)) for val in roi_mask_values] + else: + colors = [cmap(i) for i in range(len(roi_mask_path))] + + # Apply colorbar + if roi_mask_colorbar: + roi_bar = _create_colorbar_actor( + value_range=roi_mask_cmap_range, + colorbar_position=(0.1, 0.1), + colorbar_height=0.5, + colorbar_width=0.1, + cmap=cmap, + labels=roi_mask_cbar_labels, + ) + scene.add(roi_bar) + + # Add each ROI mask with its corresponding color + roi_actors = _create_roi_mask_actor( + mask_nifti=roi_mask_path, + colors=colors, + mask_opacities=roi_mask_opacity + ) + for roi_actor in roi_actors: + scene.add(roi_actor) + # Add tractography if tractography_path is not None: cmap = plt.get_cmap(tractography_cmap) # Set to range if tractography_values is not None: - norm = plt.Normalize(vmin=tractography_cmap_range[0], vmax=tractography_cmap_range[1]) - colors = [cmap(norm(val)) for val in tractography_values] + if not tractography_is_categorical_values: + norm = plt.Normalize(vmin=tractography_cmap_range[0], vmax=tractography_cmap_range[1]) + colors = [cmap(norm(val)) for val in tractography_values] + elif tractography_categorical_values is not None: + colors_lst = plt.cm.jet(np.linspace(0, 1, len(tractography_unique_values))) + colors_idx_map = {idx:color for idx,color in enumerate(colors_lst)} + tractography_unique_values_idx_map = {val:idx for idx, val in enumerate(tractography_unique_values)} + print(tractography_unique_values) + print(tractography_unique_values_idx_map) + colors = [colors_idx_map[tractography_unique_values_idx_map[val]] for val in tractography_categorical_values] + #print the color mapping for the values + print("Color mapping being used:") + for k,v in tractography_unique_values_idx_map.items(): + print(f"{k}: {colors_idx_map[v]}") else: colors = [cmap(i) for i in range(len(tractography_path))] @@ -231,7 +324,8 @@ def plot_nifti( odf_actor.display_extent(*extent) if glass_brain_path: - glass_brain_actor = _create_glass_brain_actor(glass_brain_path) + print(glass_brain_opacity) + glass_brain_actor = _create_glass_brain_actor(glass_brain_path, opacity=glass_brain_opacity) scene.add(glass_brain_actor) if scene_bound_data is None: @@ -301,11 +395,74 @@ def _create_glass_brain_actor( # Step 4: Dilate the thresholded mask with 2 passes mask_dilated = binary_dilation(mask_thres, iterations=dilation_iters).astype(np.uint8) + mask_final = mask_dilated - mask_thres + # Create a surface actor - glass_brain_actor = contour_from_roi(mask_dilated, affine=new_affine, opacity=opacity, color=(0.5, 0.5, 0.5)) + glass_brain_actor = contour_from_roi(mask_final, affine=new_affine, opacity=opacity, color=(0.5, 0.5, 0.5)) return glass_brain_actor +def _create_roi_mask_actor( + mask_nifti: list[os.PathLike], + colors: list[tuple[float, float, float]], + mask_opacities: list[float] = [0.33], + resample_factor: int = 2, + smooth_sigma: float = 2, + dilation_iters: int = 2, +) -> Actor: + """Create "glass ROI" visualizations from a binary masks. + + Parameters + ---------- + mask_nifti : os.PathLike + Path to binary mask NIFTI image + resample_factor : int, default 3 + Factor to upsample the mask by + smooth_sigma : float, default 2 + Standard deviation for Gaussian smoothing + dilation_iters : int, default 2 + Number of iterations for binary dilation + mask_opacities : list[float], default [0.33] + Opacities of the ROI masks + colors : list[tuple[float, float, float]], default [(0.5, 0.5, 0.5)] + Colors of the ROI masks + + Returns + ------- + glass_brain : fury.actor.surface + ROI mask surface actor + """ + + roi_actors = [] + roi_opacities = mask_opacities * len(mask_nifti) if len(mask_opacities) == 1 else mask_opacities + for i,(mask_file, color) in enumerate(zip(mask_nifti, colors)): + #load the mask + mask_nifti = nib.load(mask_file) + mask = mask_nifti.get_fdata() + affine = mask_nifti.affine + zooms = mask_nifti.header.get_zooms()[:3] + + # Step 1: Upsample (regrid) the mask by a factor of 5 + new_zooms = tuple(z / resample_factor for z in zooms) + mask_up, new_affine = reslice(mask, affine, zooms, new_zooms) + + # Step 2: Apply Gaussian smoothing with standard deviation 2 + mask_smooth = gaussian_filter(mask_up, sigma=smooth_sigma) + + # Step 3: Threshold the smoothed mask at 0.5 + mask_thres = (mask_smooth > 0.5).astype(np.uint8) + + # Step 4: Dilate the thresholded mask with 2 passes + mask_dilated = binary_dilation(mask_thres, iterations=dilation_iters).astype(np.uint8) + + # Create a surface actor + roi_mask_actor = contour_from_roi(mask_dilated, affine=new_affine, opacity=roi_opacities[i], color=color) + roi_actors.append(roi_mask_actor) + return roi_actors + + + + def _create_nifti_actor( nifti_path: os.PathLike, volume_idx: int | None = None, @@ -347,10 +504,17 @@ def _create_colorbar_actor( colorbar_height: float = 0.5, colorbar_width: float = 0.1, cmap: Colormap | None = None, - labels: bool = True, + labels: list[str] | bool = True, ) -> vtk.vtkScalarBarActor: """Create a colorbar actor for the scene.""" + # if isinstance(labels, list): + # n_labels = len(labels) + # lut = vtk.vtkLookupTable() + # lut.SetNumberOfTableValues(n_labels) + # lut.Build() + # for i, label in enumerate(labels): + # Create a grayscale colormap (from black to white) lut = vtk.vtkLookupTable() lut.SetNumberOfTableValues(256) # Full grayscale (256 levels)