Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 117 additions & 25 deletions flamingo_tools/segmentation/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,31 @@ def downscaled_centroids(
return new_array


def graph_connected_components(coords, min_edge_distance, min_component_length):
    """Compute connected components of a proximity graph built from point coordinates.

    Nodes are the keys of ``coords``. An edge is created between two points whose
    Euclidean distance is at most ``min_edge_distance``. Components with fewer
    than ``min_component_length`` nodes are discarded.

    Args:
        coords: Dictionary mapping a node id to its coordinate tuple.
        min_edge_distance: Maximal distance between two points to create an edge.
        min_component_length: Minimal number of nodes for keeping a component.

    Returns:
        List of connected components, each a list of node ids.
    """
    from itertools import combinations

    graph = nx.Graph()
    for num, pos in coords.items():
        graph.add_node(num, pos=pos)

    # Create edges between points whose distance is less than threshold min_edge_distance.
    # combinations visits each unordered pair exactly once (the original looped over all
    # ordered pairs and filtered with i < j, doing twice the work).
    for i, j in combinations(coords, 2):
        dist = math.dist(coords[i], coords[j])
        if dist <= min_edge_distance:
            graph.add_edge(i, j, weight=dist)

    # Remove connected components with fewer nodes than threshold min_component_length.
    for component in list(nx.connected_components(graph)):
        if len(component) < min_component_length:
            graph.remove_nodes_from(component)

    return [list(s) for s in nx.connected_components(graph)]


def components_sgn(
table: pd.DataFrame,
keyword: str = "distance_nn100",
Expand Down Expand Up @@ -370,27 +395,7 @@ def components_sgn(
for index, element in zip(labels_subset, centroids_subset):
coords[index] = element

graph = nx.Graph()
for num, pos in coords.items():
graph.add_node(num, pos=pos)

# create edges between points whose distance is less than threshold min_edge_distance
for i in coords:
for j in coords:
if i < j:
dist = math.dist(coords[i], coords[j])
if dist <= min_edge_distance:
graph.add_edge(i, j, weight=dist)

components = list(nx.connected_components(graph))

# remove connected components with less nodes than threshold min_component_length
for component in components:
if len(component) < min_component_length:
for c in component:
graph.remove_node(c)

components = [list(s) for s in nx.connected_components(graph)]
components = graph_connected_components(coords, min_edge_distance, min_component_length)

# add original coordinates closer to eroded component than threshold
if postprocess_graph:
Expand All @@ -410,7 +415,7 @@ def components_sgn(
return components


def label_components(
def label_components_sgn(
table: pd.DataFrame,
min_size: int = 1000,
threshold_erode: Optional[float] = None,
Expand Down Expand Up @@ -477,9 +482,96 @@ def postprocess_sgn_seg(
Dataframe with component labels.
"""

comp_labels = label_components(table, min_size=min_size, threshold_erode=threshold_erode,
min_component_length=min_component_length,
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
comp_labels = label_components_sgn(table, min_size=min_size, threshold_erode=threshold_erode,
min_component_length=min_component_length,
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)

table.loc[:, "component_labels"] = comp_labels

return table


def components_ihc(
    table: pd.DataFrame,
    min_component_length: int = 50,
    min_edge_distance: float = 30,
) -> List[List[int]]:
    """Compute connected components of IHC centroids using a proximity graph.

    Args:
        table: Dataframe of the segmentation table with columns
            'anchor_x', 'anchor_y', 'anchor_z' and 'label_id'.
        min_component_length: Minimal length for filtering out connected components.
        min_edge_distance: Minimal distance in micrometer between points to create edges
            for connected components.

    Returns:
        List of connected components, each a list of label ids.
    """
    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
    labels = [int(i) for i in table["label_id"]]
    # Map each label id to its centroid coordinate.
    coords = dict(zip(labels, centroids))

    components = graph_connected_components(coords, min_edge_distance, min_component_length)
    return components


def label_components_ihc(
    table: pd.DataFrame,
    min_size: int = 1000,
    min_component_length: int = 50,
    min_edge_distance: float = 30,
) -> List[int]:
    """Label IHC components using graph connected components.

    Args:
        table: Dataframe of segmentation table.
        min_size: Minimal number of pixels for filtering small instances.
        min_component_length: Minimal length for filtering out connected components.
        min_edge_distance: Minimal distance in micrometer between points to create edges
            for connected components.

    Returns:
        List of component label for each point in dataframe. 0 - background, then in
        descending order of size.
    """
    # First, apply the size filter.
    entries_filtered = table[table.n_pixels < min_size]
    table = table[table.n_pixels >= min_size]

    components = components_ihc(table, min_component_length=min_component_length,
                                min_edge_distance=min_edge_distance)

    # Add size-filtered objects back so the output has one entry per input row.
    table = pd.concat([table, entries_filtered], ignore_index=True)
    # BUGFIX: sort_values returns a new dataframe; the previous code discarded the
    # result, so rows were not actually ordered by label_id as assumed below.
    table = table.sort_values("label_id")

    # Sort components by size, largest first, so that label 1 is the biggest component.
    # (Also handles an empty component list, which the zip(*sorted(...)) idiom did not.)
    components = sorted(components, key=len, reverse=True)

    component_labels = [0] * len(table)
    # be aware of 'label_id' of dataframe starting at 1
    for lab, comp in enumerate(components):
        for comp_index in comp:
            component_labels[comp_index - 1] = lab + 1

    return component_labels


def postprocess_ihc_seg(
table: pd.DataFrame,
min_size: int = 1000,
min_component_length: int = 50,
min_edge_distance: float = 30,
) -> pd.DataFrame:
"""Postprocessing SGN segmentation of cochlea.

Args:
table: Dataframe of segmentation table.
min_size: Minimal number of pixels for filtering small instances.
threshold_erode: Threshold of column value after erosion step with spatial statistics.
min_component_length: Minimal length for filtering out connected components.
min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
iterations_erode: Number of iterations for erosion, normally determined automatically.

Returns:
Dataframe with component labels.
"""

comp_labels = label_components_ihc(table, min_size=min_size,
min_component_length=min_component_length,
min_edge_distance=min_edge_distance)

table.loc[:, "component_labels"] = comp_labels

Expand Down
17 changes: 11 additions & 6 deletions flamingo_tools/segmentation/unet_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,8 @@ def run_unet_prediction(

def run_unet_prediction_preprocess_slurm(
input_path: str,
input_key: Optional[str],
output_folder: str,
input_key: Optional[str] = None,
s3: Optional[str] = None,
s3_bucket_name: Optional[str] = None,
s3_service_endpoint: Optional[str] = None,
Expand All @@ -417,8 +417,8 @@ def run_unet_prediction_preprocess_slurm(

Args:
input_path: The path to the input data.
input_key: The key / internal path of the image data.
output_folder: The output folder for storing the segmentation related data.
input_key: The key / internal path of the image data.
        s3: Flag for considering input_path for S3 bucket.
s3_bucket_name: S3 bucket name.
s3_service_endpoint: S3 service endpoint.
Expand All @@ -437,9 +437,9 @@ def run_unet_prediction_preprocess_slurm(

def run_unet_prediction_slurm(
input_path: str,
input_key: Optional[str],
output_folder: str,
model_path: str,
input_key: Optional[str] = None,
scale: Optional[float] = None,
block_shape: Optional[Tuple[int, int, int]] = None,
halo: Optional[Tuple[int, int, int]] = None,
Expand All @@ -453,9 +453,9 @@ def run_unet_prediction_slurm(

Args:
input_path: The path to the input data.
input_key: The key / internal path of the image data.
output_folder: The output folder for storing the segmentation related data.
model_path: The path to the model to use for segmentation.
input_key: The key / internal path of the image data.
scale: A factor to rescale the data before prediction.
By default the data will not be rescaled.
block_shape: The block-shape for running the prediction.
Expand Down Expand Up @@ -501,7 +501,11 @@ def run_unet_prediction_slurm(


# does NOT need GPU, FIXME: only run on CPU
def run_unet_segmentation_slurm(output_folder: str, min_size: int) -> None:
def run_unet_segmentation_slurm(
    output_folder: str,
    min_size: int,
    boundary_distance_threshold: float = 0.5,
    center_distance_threshold: Optional[float] = None,
) -> None:
    """Create segmentation from prediction.

    Args:
        output_folder: The output folder for storing the segmentation related data.
        min_size: Minimal size of the segmented objects.
        boundary_distance_threshold: Threshold for the boundary distances in the
            distance-based watershed.
        center_distance_threshold: Optional threshold for the center distances in the
            distance-based watershed. When None, the default of
            distance_watershed_implementation is used.
    """
    # min_size may arrive as a string when passed through SLURM job arguments.
    min_size = int(min_size)
    pmap_out = os.path.join(output_folder, "predictions.zarr")
    kwargs = {"boundary_distance_threshold": boundary_distance_threshold, "min_size": min_size}
    # Only forward center_distance_threshold when given, so that the implementation
    # default applies otherwise.
    if center_distance_threshold is not None:
        kwargs["center_distance_threshold"] = center_distance_threshold
    distance_watershed_implementation(pmap_out, output_folder, **kwargs)
4 changes: 4 additions & 0 deletions scripts/extract_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def main(
basename = input_content[0] + resized_suffix
else:
basename = "".join(input_content[-1].split(".")[:-1])
image_prefix = basename.split("_")[-1]

input_dir = input_path.split(basename)[0]
input_dir = os.path.abspath(input_dir)
Expand Down Expand Up @@ -93,6 +94,9 @@ def main(
with zarr.open(s3_path, mode="r") as f:
raw = f[input_key][roi]

elif ".tif" in input_path:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would change this to if input_key is None. That's more consistent to how we handle this otherwise in the codebase.

raw = read_tif(input_path)[roi]

else:
with zarr.open(input_path, mode="r") as f:
raw = f[input_key][roi]
Expand Down