160 changes: 134 additions & 26 deletions flamingo_tools/segmentation/postprocessing.py
@@ -319,6 +319,41 @@ def downscaled_centroids(
return new_array


def graph_connected_components(coords: dict, min_edge_distance: float, min_component_length: int):
"""Create a list of IDs for each connected component of a graph.

Args:
        coords: Dictionary containing label IDs as keys and their positions as values.
        min_edge_distance: Maximal distance between two nodes for creating an edge between them.
        min_component_length: Minimal number of nodes in a connected component; smaller components are filtered out.

Returns:
        List of connected components, each given as a list of dictionary keys.
"""
graph = nx.Graph()
for num, pos in coords.items():
graph.add_node(num, pos=pos)

    # create edges between pairs of points whose distance is at most min_edge_distance
for num_i, pos_i in coords.items():
for num_j, pos_j in coords.items():
if num_i < num_j:
dist = math.dist(pos_i, pos_j)
if dist <= min_edge_distance:
graph.add_edge(num_i, num_j, weight=dist)

components = list(nx.connected_components(graph))

    # remove connected components with fewer nodes than threshold min_component_length
for component in components:
if len(component) < min_component_length:
for c in component:
graph.remove_node(c)

components = [list(s) for s in nx.connected_components(graph)]
return components
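
A short usage sketch of the new helper on hypothetical coordinates (note that `min_edge_distance`, despite the name, acts as an upper bound on edge length, per the `dist <= min_edge_distance` check):

    # Three labeled points: 1 and 2 lie within 10 micrometers of each other, 3 is isolated.
    coords = {1: (0.0, 0.0, 0.0), 2: (5.0, 0.0, 0.0), 3: (100.0, 0.0, 0.0)}
    components = graph_connected_components(coords, min_edge_distance=10.0, min_component_length=2)
    # components == [[1, 2]] -- the singleton component {3} falls below min_component_length.

Since the pairwise loop is quadratic in the number of nodes, a KD-tree query (e.g. `scipy.spatial.cKDTree.query_pairs`) could be a drop-in optimization for large tables.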


def components_sgn(
table: pd.DataFrame,
keyword: str = "distance_nn100",
@@ -370,27 +405,7 @@ def components_sgn(
for index, element in zip(labels_subset, centroids_subset):
coords[index] = element

graph = nx.Graph()
for num, pos in coords.items():
graph.add_node(num, pos=pos)

# create edges between points whose distance is less than threshold min_edge_distance
for i in coords:
for j in coords:
if i < j:
dist = math.dist(coords[i], coords[j])
if dist <= min_edge_distance:
graph.add_edge(i, j, weight=dist)

components = list(nx.connected_components(graph))

# remove connected components with less nodes than threshold min_component_length
for component in components:
if len(component) < min_component_length:
for c in component:
graph.remove_node(c)

components = [list(s) for s in nx.connected_components(graph)]
components = graph_connected_components(coords, min_edge_distance, min_component_length)

# add original coordinates closer to eroded component than threshold
if postprocess_graph:
@@ -410,15 +425,15 @@ def components_sgn(
return components


def label_components(
def label_components_sgn(
table: pd.DataFrame,
min_size: int = 1000,
threshold_erode: Optional[float] = None,
min_component_length: int = 50,
min_edge_distance: float = 30,
iterations_erode: Optional[int] = None,
) -> List[int]:
"""Label components using graph connected components.
"""Label SGN components using graph connected components.

Args:
table: Dataframe of segmentation table.
@@ -477,9 +492,102 @@ def postprocess_sgn_seg(
Dataframe with component labels.
"""

comp_labels = label_components(table, min_size=min_size, threshold_erode=threshold_erode,
min_component_length=min_component_length,
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
comp_labels = label_components_sgn(table, min_size=min_size, threshold_erode=threshold_erode,
min_component_length=min_component_length,
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)

table.loc[:, "component_labels"] = comp_labels

return table


def components_ihc(
table: pd.DataFrame,
min_component_length: int = 50,
min_edge_distance: float = 30,
):
"""Create connected components for IHC segmentation.

Args:
table: Dataframe of segmentation table.
        min_component_length: Minimal number of nodes in a connected component; smaller components are filtered out.
        min_edge_distance: Maximal distance in micrometers between points for creating edges between graph nodes.

Returns:
        Connected components, each given as a list of label_ids of the dataframe.
"""
    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
    labels = [int(i) for i in table["label_id"]]
    coords = dict(zip(labels, centroids))

components = graph_connected_components(coords, min_edge_distance, min_component_length)
return components


def label_components_ihc(
table: pd.DataFrame,
min_size: int = 1000,
min_component_length: int = 50,
min_edge_distance: float = 30,
) -> List[int]:
"""Label components using graph connected components.

Args:
table: Dataframe of segmentation table.
min_size: Minimal number of pixels for filtering small instances.
        min_component_length: Minimal number of nodes in a connected component; smaller components are filtered out.
        min_edge_distance: Maximal distance in micrometers between points for creating edges between graph nodes.

Returns:
        List of component labels for each point in the dataframe: 0 for background, then component labels assigned in descending order of component size.
"""

# First, apply the size filter.
entries_filtered = table[table.n_pixels < min_size]
table = table[table.n_pixels >= min_size]

components = components_ihc(table, min_component_length=min_component_length,
min_edge_distance=min_edge_distance)

# add size-filtered objects to have same initial length
table = pd.concat([table, entries_filtered], ignore_index=True)
    table = table.sort_values("label_id")

length_components = [len(c) for c in components]
length_components, components = zip(*sorted(zip(length_components, components), reverse=True))

component_labels = [0 for _ in range(len(table))]
    # note that 'label_id' of the dataframe starts at 1
for lab, comp in enumerate(components):
for comp_index in comp:
component_labels[comp_index - 1] = lab + 1

return component_labels


def postprocess_ihc_seg(
table: pd.DataFrame,
min_size: int = 1000,
min_component_length: int = 50,
min_edge_distance: float = 30,
) -> pd.DataFrame:
"""Postprocessing IHC segmentation of cochlea.

Args:
table: Dataframe of segmentation table.
min_size: Minimal number of pixels for filtering small instances.
        min_component_length: Minimal number of nodes in a connected component; smaller components are filtered out.
        min_edge_distance: Maximal distance in micrometers between points for creating edges between graph nodes.

Returns:
Dataframe with component labels.
"""

comp_labels = label_components_ihc(table, min_size=min_size,
min_component_length=min_component_length,
min_edge_distance=min_edge_distance)

table.loc[:, "component_labels"] = comp_labels
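
Assuming `postprocess_ihc_seg` returns the labeled table like its SGN counterpart (its final lines are collapsed in the diff), a minimal end-to-end sketch on a hypothetical toy table:

    import pandas as pd

    # Toy segmentation table with the columns used above; all values are hypothetical.
    toy = pd.DataFrame({
        "label_id": [1, 2, 3],
        "anchor_x": [0.0, 5.0, 200.0],
        "anchor_y": [0.0, 0.0, 0.0],
        "anchor_z": [0.0, 0.0, 0.0],
        "n_pixels": [2000, 2000, 2000],
    })
    labeled = postprocess_ihc_seg(toy, min_size=1000, min_component_length=2, min_edge_distance=30)
    # labeled["component_labels"] -> [1, 1, 0]: the two nearby IHCs form component 1,
    # the distant instance stays background (0).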

62 changes: 48 additions & 14 deletions flamingo_tools/segmentation/unet_prediction.py
@@ -149,7 +149,9 @@ def postprocess(x):
blocking = nt.blocking([0] * ndim, shape, block_shape)
n_blocks = blocking.numberOfBlocks
if prediction_instances != 1:
iteration_ids = [x.tolist() for x in np.array_split(list(range(n_blocks)), prediction_instances)]
        # shuffle indices with a fixed seed to balance the segmentation blocks across Slurm workers
rng = np.random.default_rng(seed=1234)
iteration_ids = [x.tolist() for x in np.array_split(list(rng.permutation(n_blocks)), prediction_instances)]
slurm_iteration = iteration_ids[slurm_task_id]
else:
slurm_iteration = list(range(n_blocks))
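
The effect of the shuffled split can be seen in a minimal standalone sketch (hypothetical block and worker counts):

    import numpy as np

    rng = np.random.default_rng(seed=1234)
    # 10 blocks over 3 workers: sublist sizes differ by at most one, and the fixed
    # seed reproduces the same assignment in every Slurm process.
    iteration_ids = [x.tolist() for x in np.array_split(list(rng.permutation(10)), 3)]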
@@ -175,7 +177,7 @@ def postprocess(x):
return original_shape


def find_mask(input_path: str, input_key: Optional[str], output_folder: str) -> None:
def find_mask(input_path: str, input_key: Optional[str], output_folder: str, seg_class: Optional[str] = "sgn") -> None:
"""Determine the mask for running prediction.

The mask corresponds to data that contains actual signal and not just noise.
@@ -187,10 +189,25 @@ def find_mask(input_path: str, input_key: Optional[str], output_folder: str) ->
input_path: The file path to the image data.
input_key: The key / internal path of the image data.
output_folder: The output folder for storing the mask data.
        seg_class: Specifier for the exclusion criteria used in mask generation.
"""
mask_path = os.path.join(output_folder, "mask.zarr")
f = z5py.File(mask_path, "a")

    # set parameters for excluding chunks during mask generation
if seg_class == "sgn":
upper_percentile = 95
min_intensity = 200
print(f"Calculating mask for segmentation class {seg_class}.")
elif seg_class == "ihc":
upper_percentile = 99
min_intensity = 150
print(f"Calculating mask for segmentation class {seg_class}.")
else:
upper_percentile = 95
min_intensity = 200
print("Calculating mask with default values.")

mask_key = "mask"
if mask_key in f:
return
@@ -209,8 +226,8 @@ def find_mask_block(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
data = raw[bb]
max_ = np.percentile(data, 95)
if max_ > 200:
max_ = np.percentile(data, upper_percentile)
if max_ > min_intensity:
ds_mask[bb] = 1

n_threads = min(16, mp.cpu_count())
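
The per-block test in `find_mask_block` boils down to a percentile check; a minimal sketch with synthetic data:

    import numpy as np

    # A noise-level block drawn around intensity 100: its 95th percentile stays
    # well below min_intensity=200, so the block is excluded from the mask.
    block = np.random.default_rng(0).normal(loc=100, scale=20, size=(64, 64, 64))
    in_mask = np.percentile(block, 95) > 200  # False for this block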
@@ -359,6 +376,7 @@ def run_unet_prediction(
center_distance_threshold: float = 0.4,
boundary_distance_threshold: Optional[float] = None,
fg_threshold: float = 0.5,
seg_class: Optional[str] = None,
) -> None:
"""Run prediction and segmentation with a distance U-Net.

@@ -377,12 +395,12 @@ def run_unet_prediction(
boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
By default this is set to 'None', in which case the boundary distances are not used for the seeds.
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
        seg_class: Specifier for the exclusion criteria used in mask generation.
"""
os.makedirs(output_folder, exist_ok=True)

if use_mask:
find_mask(input_path, input_key, output_folder)

find_mask(input_path, input_key, output_folder, seg_class=seg_class)
original_shape = prediction_impl(
input_path, input_key, output_folder, model_path, scale, block_shape, halo
)
@@ -403,12 +421,13 @@

def run_unet_prediction_preprocess_slurm(
input_path: str,
input_key: Optional[str],
output_folder: str,
input_key: Optional[str] = None,
s3: Optional[str] = None,
s3_bucket_name: Optional[str] = None,
s3_service_endpoint: Optional[str] = None,
s3_credentials: Optional[str] = None,
seg_class: Optional[str] = None,
) -> None:
"""Pre-processing for the parallel prediction with U-Net models.
Masks are stored in mask.zarr in the output folder.
@@ -417,29 +436,31 @@

Args:
input_path: The path to the input data.
input_key: The key / internal path of the image data.
output_folder: The output folder for storing the segmentation related data.
input_key: The key / internal path of the image data.
        s3: Flag for treating input_path as a path within an S3 bucket.
s3_bucket_name: S3 bucket name.
s3_service_endpoint: S3 service endpoint.
s3_credentials: File path to credentials for S3 bucket.
        seg_class: Specifier for the exclusion criteria used in mask generation.
"""
if s3 is not None:
input_path, fs = s3_utils.get_s3_path(
input_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials
)

if not os.path.isdir(os.path.join(output_folder, "mask.zarr")):
find_mask(input_path, input_key, output_folder)
find_mask(input_path, input_key, output_folder, seg_class=seg_class)

calc_mean_and_std(input_path, input_key, output_folder)
if not os.path.isfile(os.path.join(output_folder, "mean_std.json")):
calc_mean_and_std(input_path, input_key, output_folder)


def run_unet_prediction_slurm(
input_path: str,
input_key: Optional[str],
output_folder: str,
model_path: str,
input_key: Optional[str] = None,
scale: Optional[float] = None,
block_shape: Optional[Tuple[int, int, int]] = None,
halo: Optional[Tuple[int, int, int]] = None,
@@ -453,9 +474,9 @@ def run_unet_prediction_slurm(

Args:
input_path: The path to the input data.
input_key: The key / internal path of the image data.
output_folder: The output folder for storing the segmentation related data.
model_path: The path to the model to use for segmentation.
input_key: The key / internal path of the image data.
scale: A factor to rescale the data before prediction.
By default the data will not be rescaled.
block_shape: The block-shape for running the prediction.
@@ -501,13 +522,26 @@


# does NOT need GPU, FIXME: only run on CPU
def run_unet_segmentation_slurm(output_folder: str, min_size: int) -> None:
def run_unet_segmentation_slurm(
output_folder: str,
min_size: int,
center_distance_threshold: float = 0.4,
boundary_distance_threshold: float = 0.5,
[Review comment, Contributor]: We should also expose the center_distance_threshold here.
fg_threshold: float = 0.5,
) -> None:
"""Create segmentation from prediction.

Args:
output_folder: The output folder for storing the segmentation related data.
min_size: The minimal size of segmented objects in the output.
center_distance_threshold: The threshold applied to the distance center predictions to derive seeds.
        boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
"""
min_size = int(min_size)
pmap_out = os.path.join(output_folder, "predictions.zarr")
distance_watershed_implementation(pmap_out, output_folder, min_size=min_size)
distance_watershed_implementation(pmap_out, output_folder, center_distance_threshold=center_distance_threshold,
boundary_distance_threshold=boundary_distance_threshold,
fg_threshold=fg_threshold,
min_size=min_size)
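
With the watershed thresholds now exposed, the segmentation step can be tuned from the caller; a hypothetical invocation (path and values are placeholders):

    run_unet_segmentation_slurm(
        output_folder="/scratch/cochlea_run",
        min_size=1000,
        center_distance_threshold=0.4,
        boundary_distance_threshold=0.5,
        fg_threshold=0.5,
    )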