Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion synaptic_reconstruction/napari.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: synaptic_reconstruction
display_name: Synaptic Reconstruction
display_name: SynapseNet
# see https://napari.org/stable/plugins/manifest.html for valid categories
categories: ["Image Processing", "Annotation"]
contributions:
Expand All @@ -16,6 +16,9 @@ contributions:
- id: synaptic_reconstruction.morphology
python_name: synaptic_reconstruction.tools.morphology_widget:MorphologyWidget
title: Morphology Analysis
- id: synaptic_reconstruction.vesicle_pooling
python_name: synaptic_reconstruction.tools.vesicle_pool_widget:VesiclePoolWidget
title: Vesicle Pooling

readers:
- command: synaptic_reconstruction.file_reader
Expand All @@ -32,3 +35,5 @@ contributions:
display_name: Distance Measurement
- command: synaptic_reconstruction.morphology
display_name: Morphology Analysis
- command: synaptic_reconstruction.vesicle_pooling
display_name: Vesicle Pooling
69 changes: 66 additions & 3 deletions synaptic_reconstruction/tools/base_widget.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import os
from pathlib import Path

import napari
import qtpy.QtWidgets as QtWidgets

from napari.utils.notifications import show_info
from qtpy.QtWidgets import (
QWidget, QVBoxLayout, QHBoxLayout, QLabel, QSpinBox, QComboBox, QCheckBox
)
from superqt import QCollapsible

try:
from napari_skimage_regionprops import add_table
except ImportError:
add_table = None


class BaseWidget(QWidget):
def __init__(self):
Expand All @@ -31,12 +38,14 @@ def _create_layer_selector(self, selector_name, layer_type="Image"):
layer_filter = napari.layers.Image
elif layer_type == "Labels":
layer_filter = napari.layers.Labels
elif layer_type == "Shapes":
layer_filter = napari.layers.Shapes
else:
raise ValueError("layer_type must be either 'Image' or 'Labels'.")

selector_widget = QtWidgets.QWidget()
image_selector = QtWidgets.QComboBox()
layer_label = QtWidgets.QLabel(f"{selector_name} Layer:")
layer_label = QtWidgets.QLabel(f"{selector_name}:")

# Populate initial options
self._update_selector(selector=image_selector, layer_filter=layer_filter)
Expand All @@ -58,9 +67,23 @@ def _create_layer_selector(self, selector_name, layer_type="Image"):
def _update_selector(self, selector, layer_filter):
"""Update a single selector with the current image layers in the viewer."""
selector.clear()
image_layers = [layer.name for layer in self.viewer.layers if isinstance(layer, layer_filter)] # if isinstance(layer, napari.layers.Image)
image_layers = [layer.name for layer in self.viewer.layers if isinstance(layer, layer_filter)]
selector.addItems(image_layers)

def _get_layer_selector_layer(self, selector_name):
"""Return the layer currently selected in a given selector."""
if selector_name in self.layer_selectors:
selector_widget = self.layer_selectors[selector_name]

# Retrieve the QComboBox from the QWidget's layout
image_selector = selector_widget.layout().itemAt(1).widget()

if isinstance(image_selector, QComboBox):
selected_layer_name = image_selector.currentText()
if selected_layer_name in self.viewer.layers:
return self.viewer.layers[selected_layer_name]
return None # Return None if layer not found

def _get_layer_selector_data(self, selector_name, return_metadata=False):
"""Return the data for the layer currently selected in a given selector."""
if selector_name in self.layer_selectors:
Expand Down Expand Up @@ -172,7 +195,7 @@ def _add_shape_param(self, names, values, min_val, max_val, step=1, title=None,
title=title[1] if title is not None else title, tooltip=tooltip
)
layout.addLayout(y_layout)

if len(names) == 3:
z_layout = QVBoxLayout()
z_param, _ = self._add_int_param(
Expand Down Expand Up @@ -262,3 +285,43 @@ def _get_file_path(self, name, textbox, tooltip=None):
else:
# Handle the case where the selected path is not a file
print("Invalid file selected. Please try again.")

def _handle_resolution(self, metadata, voxel_size_param, ndim):
# Get the resolution / voxel size from the layer metadata if available.
resolution = metadata.get("voxel_size", None)
if resolution is not None:
resolution = [resolution[ax] for ax in ("zyx" if ndim == 3 else "yx")]

# If user input was given then override resolution from metadata.
if voxel_size_param.value() != 0.0: # Changed from default.
resolution = ndim * [voxel_size_param.value()]

assert len(resolution) == ndim
return resolution

def _save_table(self, save_path, data):
ext = os.path.splitext(save_path)[1]
if ext == "": # No file extension given, By default we save to CSV.
file_path = f"{save_path}.csv"
data.to_csv(file_path, index=False)
elif ext == ".csv": # Extension was specified as csv
file_path = save_path
data.to_csv(file_path, index=False)
elif ext == ".xlsx": # We also support excel.
file_path = save_path
data.to_excel(file_path, index=False)
else:
raise ValueError("Invalid extension for table: {ext}. We support .csv or .xlsx.")
return file_path

def _add_properties_and_table(self, layer, table_data, save_path=""):
if layer.properties:
layer.properties = layer.properties.update(table_data)
else:
layer.properties = table_data
if add_table is not None:
add_table(layer, self.viewer)
# Save table to file if save path is provided.
if save_path != "":
file_path = self._save_table(self.save_path.text(), table_data)
show_info(f"INFO: Added table and saved file to {file_path}.")
66 changes: 12 additions & 54 deletions synaptic_reconstruction/tools/distance_measure_widget.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os

import napari
import napari.layers
import numpy as np
Expand All @@ -11,27 +9,6 @@
from .base_widget import BaseWidget
from .. import distance_measurements

try:
from napari_skimage_regionprops import add_table
except ImportError:
add_table = None


def _save_distance_table(save_path, data):
ext = os.path.splitext(save_path)[1]
if ext == "": # No file extension given, By default we save to CSV.
file_path = f"{save_path}.csv"
data.to_csv(file_path, index=False)
elif ext == ".csv": # Extension was specified as csv
file_path = save_path
data.to_csv(file_path, index=False)
elif ext == ".xlsx": # We also support excel.
file_path = save_path
data.to_excel(file_path, index=False)
else:
raise ValueError("Invalid extension for table: {ext}. We support .csv or .xlsx.")
return file_path


class DistanceMeasureWidget(BaseWidget):
def __init__(self):
Expand Down Expand Up @@ -68,47 +45,33 @@ def __init__(self):
def _to_table_data(self, distances, seg_ids, endpoints1=None, endpoints2=None):
assert len(distances) == len(seg_ids), f"{distances.shape}, {seg_ids.shape}"
if seg_ids.ndim == 2:
table_data = {"label1": seg_ids[:, 0], "label2": seg_ids[:, 1], "distance": distances}
table_data = {"label_id1": seg_ids[:, 0], "label_id2": seg_ids[:, 1], "distance": distances}
else:
table_data = {"label": seg_ids, "distance": distances}
table_data = {"label_id": seg_ids, "distance": distances}
if endpoints1 is not None:
axis_names = "zyx" if endpoints1.shape[1] == 3 else "yx"
table_data.update({f"begin-{ax}": endpoints1[:, i] for i, ax in enumerate(axis_names)})
table_data.update({f"end-{ax}": endpoints2[:, i] for i, ax in enumerate(axis_names)})
return pd.DataFrame(table_data)

def _add_lines_and_table(self, lines, properties, table_data, name):
def _add_lines_and_table(self, lines, table_data, name):
line_layer = self.viewer.add_shapes(
lines,
name=name,
shape_type="line",
edge_width=2,
edge_color="red",
blending="additive",
properties=properties,
)
if add_table is not None:
add_table(line_layer, self.viewer)

if self.save_path.text() != "":
file_path = _save_distance_table(self.save_path.text(), table_data)

if self.save_path.text() != "":
show_info(f"Added distance lines and saved file to {file_path}.")
else:
show_info("Added distance lines.")
self._add_properties_and_table(line_layer, table_data, self.save_path.text())

def on_measure_seg_to_object(self):
segmentation = self._get_layer_selector_data(self.image_selector_name1)
object_data = self._get_layer_selector_data(self.image_selector_name2)
# get metadata from layer if available

# Get the resolution / voxel size.
metadata = self._get_layer_selector_data(self.image_selector_name1, return_metadata=True)
resolution = metadata.get("voxel_size", None)
if resolution is not None:
resolution = [v for v in resolution.values()]
# if user input is present override metadata
if self.voxel_size_param.value() != 0.0: # changed from default
resolution = segmentation.ndim * [self.voxel_size_param.value()]
resolution = self._handle_resolution(metadata, self.voxel_size_param, segmentation.ndim)

(distances,
endpoints1,
Expand All @@ -117,28 +80,23 @@ def on_measure_seg_to_object(self):
segmentation=segmentation, segmented_object=object_data, distance_type="boundary",
resolution=resolution
)
lines, properties = distance_measurements.create_object_distance_lines(
lines, _ = distance_measurements.create_object_distance_lines(
distances=distances,
endpoints1=endpoints1,
endpoints2=endpoints2,
seg_ids=seg_ids,
)
table_data = self._to_table_data(distances, seg_ids, endpoints1, endpoints2)
self._add_lines_and_table(lines, properties, table_data, name="distances")
self._add_lines_and_table(lines, table_data, name="distances")

def on_measure_pairwise(self):
segmentation = self._get_layer_selector_data(self.image_selector_name1)
if segmentation is None:
show_info("Please choose a segmentation.")
return
# get metadata from layer if available

metadata = self._get_layer_selector_data(self.image_selector_name1, return_metadata=True)
resolution = metadata.get("voxel_size", None)
if resolution is not None:
resolution = [v for v in resolution.values()]
# if user input is present override metadata
if self.voxel_size_param.value() != 0.0: # changed from default
resolution = segmentation.ndim * [self.voxel_size_param.value()]
resolution = self._handle_resolution(metadata, self.voxel_size_param, segmentation.ndim)

(distances,
endpoints1,
Expand All @@ -153,7 +111,7 @@ def on_measure_pairwise(self):
distances=properties["distance"],
seg_ids=np.concatenate([properties["id_a"][:, None], properties["id_b"][:, None]], axis=1)
)
self._add_lines_and_table(lines, properties, table_data, name="pairwise-distances")
self._add_lines_and_table(lines, table_data, name="pairwise-distances")

def _create_settings_widget(self):
setting_values = QWidget()
Expand Down
Loading
Loading