Skip to content

Commit 65ccdd7

Browse files
committed
Filter cochlea with segmentation table and export lower resolution
1 parent 301e018 commit 65ccdd7

File tree

2 files changed

+319
-22
lines changed

2 files changed

+319
-22
lines changed

flamingo_tools/segmentation/postprocessing.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pandas as pd
1111

1212
from elf.io import open_file
13+
from scipy.ndimage import distance_transform_edt, binary_dilation, binary_closing
1314
from scipy.sparse import csr_matrix
1415
from scipy.spatial import distance
1516
from scipy.spatial import cKDTree, ConvexHull
@@ -615,3 +616,177 @@ def postprocess_ihc_seg(
615616
table.loc[:, "component_labels"] = comp_labels
616617

617618
return table
619+
620+
621+
def dilate_and_trim(
622+
arr_orig: np.ndarray,
623+
edt: np.ndarray,
624+
iterations: int = 15,
625+
offset: float = 0.4,
626+
) -> np.ndarray:
627+
"""Dilate and trim original binary array according to a
628+
Euclidean Distance Trasform computed for a separate target array.
629+
630+
Args:
631+
arr_orig: Original 3D binary array
632+
edt: 3D array containing Euclidean Distance transform for guiding dilation
633+
iterations: Number of iterations for dilations
634+
offset: Offset for regulating dilation. value should be in range(0, 0.45)
635+
636+
Returns:
637+
Dilated binary array
638+
"""
639+
border_coords = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
640+
for _ in range(iterations):
641+
arr_dilated = binary_dilation(arr_orig)
642+
for x in range(arr_dilated.shape[0]):
643+
for y in range(arr_dilated.shape[1]):
644+
for z in range(arr_dilated.shape[2]):
645+
if arr_dilated[x, y, z] != 0:
646+
if arr_orig[x, y, z] == 0:
647+
min_dist = float('inf')
648+
for dx, dy, dz in border_coords:
649+
nx, ny, nz = x+dx, y+dy, z+dz
650+
if arr_orig[nx, ny, nz] == 1:
651+
min_dist = min([min_dist, edt[nx, ny, nz]])
652+
if edt[x, y, z] >= min_dist - offset:
653+
arr_dilated[x, y, z] = 0
654+
arr_orig = arr_dilated
655+
return arr_dilated
656+
657+
658+
def filter_cochlea_volume_single(
659+
table: pd.DataFrame,
660+
components: Optional[List[int]] = [1],
661+
scale_factor: int = 48,
662+
resolution: float = 0.38,
663+
dilation_iterations: int = 12,
664+
padding: int = 1200,
665+
) -> np.ndarray:
666+
"""Filter cochlea volume based on a segmentation table.
667+
Centroids contained in the segmentation table are used to create a down-scaled binary array.
668+
The array can be dilated.
669+
670+
Args:
671+
table: Segmentation table.
672+
components: Component labels for filtering segmentation table.
673+
scale_factor: Down-sampling factor for filtering.
674+
resolution: Resolution of pixel in µm.
675+
dilation_iterations: Iterations for dilating binary segmentation mask. A negative value omits binary closing.
676+
padding: Padding in pixel to apply to guessed dimensions based on centroid coordinates.
677+
678+
Returns:
679+
Binary 3D array of filtered cochlea.
680+
"""
681+
# filter components
682+
if components is not None:
683+
table = table[table["component_labels"].isin(components)]
684+
685+
# identify approximate input dimensions for down-scaling
686+
centroids = list(zip(table["anchor_x"] / resolution,
687+
table["anchor_y"] / resolution,
688+
table["anchor_z"] / resolution))
689+
690+
# padding the array allows for dilation without worrying about array borders
691+
max_x = table["anchor_x"].max() / resolution + padding
692+
max_y = table["anchor_y"].max() / resolution + padding
693+
max_z = table["anchor_z"].max() / resolution + padding
694+
ref_dimensions = (max_x, max_y, max_z)
695+
696+
# down-scale arrays
697+
array_downscaled = downscaled_centroids(centroids, ref_dimensions=ref_dimensions,
698+
scale_factor=scale_factor, downsample_mode="capped")
699+
700+
if dilation_iterations > 0:
701+
array_dilated = binary_dilation(array_downscaled, np.ones((3, 3, 3)), iterations=dilation_iterations)
702+
return binary_closing(array_dilated, np.ones((3, 3, 3)), iterations=1)
703+
704+
elif dilation_iterations == 0:
705+
return binary_closing(array_downscaled, np.ones((3, 3, 3)), iterations=1)
706+
707+
else:
708+
return array_downscaled
709+
710+
711+
def filter_cochlea_volume(
712+
sgn_table: pd.DataFrame,
713+
ihc_table: pd.DataFrame,
714+
sgn_components: Optional[List[int]] = [1],
715+
ihc_components: Optional[List[int]] = [1],
716+
scale_factor: int = 48,
717+
resolution: float = 0.38,
718+
dilation_iterations: int = 12,
719+
padding: int = 1200,
720+
dilation_method = "individual",
721+
) -> np.ndarray:
722+
"""Filter cochlea volume with SGN and IHC segmentation.
723+
Centroids contained in the segmentation tables are used to create down-scaled binary arrays.
724+
The arrays are then dilated using guided dilation to fill the section inbetween SGNs and IHCs.
725+
726+
Args:
727+
sgn_table: SGN segmentation table.
728+
ihc_table: IHC segmentation table.
729+
sgn_components: Component labels for filtering SGN segmentation table.
730+
ihc_components: Component labels for filtering IHC segmentation table.
731+
scale_factor: Down-sampling factor for filtering.
732+
resolution: Resolution of pixel in µm.
733+
dilation_iterations: Iterations for dilating binary segmentation mask.
734+
padding: Padding in pixel to apply to guessed dimensions based on centroid coordinates.
735+
dilation_method: Dilation style for SGN and IHC segmentation, either 'individual', 'combined' or no dilation.
736+
737+
Returns:
738+
Binary 3D array of filtered cochlea.
739+
"""
740+
# filter components
741+
if sgn_components is not None:
742+
sgn_table = sgn_table[sgn_table["component_labels"].isin(sgn_components)]
743+
if ihc_components is not None:
744+
ihc_table = ihc_table[ihc_table["component_labels"].isin(ihc_components)]
745+
746+
# identify approximate input dimensions for down-scaling
747+
centroids_sgn = list(zip(sgn_table["anchor_x"] / resolution,
748+
sgn_table["anchor_y"] / resolution,
749+
sgn_table["anchor_z"] / resolution))
750+
centroids_ihc = list(zip(ihc_table["anchor_x"] / resolution,
751+
ihc_table["anchor_y"] / resolution,
752+
ihc_table["anchor_z"] / resolution))
753+
754+
# padding the array allows for dilation without worrying about array borders
755+
max_x = max([sgn_table["anchor_x"].max(), ihc_table["anchor_x"].max()]) / resolution + padding
756+
max_y = max([sgn_table["anchor_y"].max(), ihc_table["anchor_y"].max()]) / resolution + padding
757+
max_z = max([sgn_table["anchor_z"].max(), ihc_table["anchor_z"].max()]) / resolution + padding
758+
ref_dimensions = (max_x, max_y, max_z)
759+
760+
# down-scale arrays
761+
array_downscaled_sgn = downscaled_centroids(centroids_sgn, ref_dimensions=ref_dimensions,
762+
scale_factor=scale_factor, downsample_mode="capped")
763+
764+
array_downscaled_ihc = downscaled_centroids(centroids_ihc, ref_dimensions=ref_dimensions,
765+
scale_factor=scale_factor, downsample_mode="capped")
766+
767+
# dilate down-scaled SGN array in direction of IHC segmentation
768+
distance_from_sgn = distance_transform_edt(~array_downscaled_sgn.astype(bool))
769+
iterations = 20
770+
arr_dilated = dilate_and_trim(array_downscaled_ihc.copy(), distance_from_sgn, iterations=iterations, offset=0.4)
771+
772+
# dilate single structures first
773+
if dilation_method == "individual":
774+
ihc_dilated = binary_dilation(array_downscaled_ihc, np.ones((3, 3, 3)), iterations=dilation_iterations)
775+
sgn_dilated = binary_dilation(array_downscaled_sgn, np.ones((3, 3, 3)), iterations=dilation_iterations)
776+
combined_dilated = arr_dilated + ihc_dilated + sgn_dilated
777+
combined_dilated[combined_dilated > 0] = 1
778+
combined_dilated = binary_dilation(combined_dilated, np.ones((3, 3, 3)), iterations=1)
779+
780+
# dilate combined structure
781+
elif dilation_method == "combined":
782+
# combine SGN, IHC, and region between both to form output mask
783+
combined_structure = arr_dilated + array_downscaled_ihc + array_downscaled_sgn
784+
combined_structure[combined_structure > 0] = 1
785+
combined_dilated = binary_dilation(combined_structure, np.ones((3, 3, 3)), iterations=dilation_iterations)
786+
787+
# no dilation of combined structure
788+
else:
789+
combined_dilated = arr_dilated + ihc_dilated + sgn_dilated
790+
combined_dilated[combined_dilated > 0] = 1
791+
792+
return combined_dilated

scripts/export_lower_resolution.py

Lines changed: 144 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import argparse
22
import os
3+
from typing import List, Optional
34

45
import numpy as np
56
import pandas as pd
67
import tifffile
78
import zarr
89

910
from flamingo_tools.s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT
11+
from flamingo_tools.segmentation.postprocessing import filter_cochlea_volume, filter_cochlea_volume_single
1012
# from skimage.segmentation import relabel_sequential
1113

1214

@@ -26,37 +28,157 @@ def filter_component(fs, segmentation, cochlea, seg_name, components):
2628
return segmentation
2729

2830

31+
def filter_cochlea(
32+
cochlea: str,
33+
filter_cochlea_channels: str,
34+
sgn_components: Optional[List[int]] = None,
35+
ihc_components: Optional[List[int]] = None,
36+
ds_factor: int = 24,
37+
dilation_iterations: int = 8,
38+
) -> np.ndarray:
39+
"""Pre-process information for filtering cochlea volume based on segmentation table.
40+
Differentiates between the input of a single channel of either IHC or SGN or if both are supplied.
41+
If a single channel is given, the filtered volume contains
42+
a down-sampled segmentation area, which has been dilated.
43+
If both IHC and SGN segmentation are supplied, a more specialized dilation
44+
is applied to ensure that the connecting volume is not filtered.
45+
46+
Args:
47+
cochlea: Name of cochlea.
48+
filter_cochlea_channels: Segmentation table(s) used for filtering.
49+
sgn_components: Component labels for filtering SGN segmentation table.
50+
ihc_components: Component labels for filtering IHC segmentation table.
51+
ds_factor: Down-sampling factor for filtering.
52+
dilation_iterations: Iterations for dilating binary segmentation mask.
53+
54+
Returns:
55+
Binary 3D array of filtered cochlea
56+
"""
57+
# we check if the supplied channels contain an SGN and IHC channel
58+
sgn_channels = [ch for ch in filter_cochlea_channels if "SGN" in ch]
59+
sgn_channel = None if len(sgn_channels) == 0 else sgn_channels[0]
60+
61+
ihc_channels = [ch for ch in filter_cochlea_channels if "IHC" in ch]
62+
ihc_channel = None if len(ihc_channels) == 0 else ihc_channels[0]
63+
64+
if ihc_channel is None and sgn_channel is None:
65+
raise ValueError("Channels supplied for filtering cochlea volume do not contain an IHC or SGN segmentation.")
66+
67+
if sgn_channel is not None:
68+
internal_path = os.path.join(cochlea, "tables", sgn_channel, "default.tsv")
69+
tsv_path, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)
70+
with fs.open(tsv_path, "r") as f:
71+
table_sgn = pd.read_csv(f, sep="\t")
72+
73+
if ihc_channel is not None:
74+
internal_path = os.path.join(cochlea, "tables", ihc_channel, "default.tsv")
75+
tsv_path, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)
76+
with fs.open(tsv_path, "r") as f:
77+
table_ihc = pd.read_csv(f, sep="\t")
78+
79+
if sgn_channel is None:
80+
# filter based in IHC segmentation
81+
return filter_cochlea_volume_single(table_ihc, components=ihc_components,
82+
scale_factor=ds_factor, dilation_iterations=dilation_iterations)
83+
elif ihc_channel is None:
84+
# filter based on SGN segmentation
85+
return filter_cochlea_volume_single(table_sgn, components=sgn_components,
86+
scale_factor=ds_factor, dilation_iterations=dilation_iterations)
87+
else:
88+
# filter based on SGN and IHC segmentation with a specialized function
89+
return filter_cochlea_volume(table_sgn, table_ihc,
90+
sgn_components=sgn_components,
91+
ihc_components=ihc_components,
92+
scale_factor=ds_factor,
93+
dilation_iterations=dilation_iterations)
94+
95+
96+
def upscale_volume(
97+
target_data: np.ndarray,
98+
downscaled_volume: np.ndarray,
99+
upscale_factor: int,
100+
) -> np.ndarray:
101+
"""Up-scale binary 3D mask to dimensions of target data.
102+
After an initial up-scaling, the dimensions are cropped or zero-padded to fit the target shape.
103+
104+
Args:
105+
target_data: Reference data for up-scaling.
106+
downscaled_volume: Down-scaled binary 3D array.
107+
upscale_factor: Initial factor for up-scaling binary array.
108+
109+
Returns:
110+
Resized binary array.
111+
"""
112+
target_shape = target_data.shape
113+
upscaled_filter = np.repeat(
114+
np.repeat(
115+
np.repeat(downscaled_volume, upscale_factor, axis=0),
116+
upscale_factor, axis=1),
117+
upscale_factor, axis=2)
118+
resized = np.zeros(target_shape, dtype=target_data.dtype)
119+
min_x, min_y, min_z = tuple(min(upscaled_filter.shape[i], target_shape[i]) for i in range(3))
120+
resized[:min_x, :min_y, :min_z] = upscaled_filter[:min_x, :min_y, :min_z]
121+
return resized
122+
123+
29124
def export_lower_resolution(args):
30-
output_folder = os.path.join(args.output_folder, args.cochlea, f"scale{args.scale}")
31-
os.makedirs(output_folder, exist_ok=True)
32-
33-
input_key = f"s{args.scale}"
34-
for channel in args.channels:
35-
out_path = os.path.join(output_folder, f"{channel}.tif")
36-
if os.path.exists(out_path):
37-
continue
38-
39-
print("Exporting channel", channel)
40-
internal_path = os.path.join(args.cochlea, "images", "ome-zarr", f"{channel}.ome.zarr")
41-
s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)
42-
with zarr.open(s3_store, mode="r") as f:
43-
data = f[input_key][:]
44-
print(data.shape)
45-
if args.filter_by_components is not None:
46-
data = filter_component(fs, data, args.cochlea, channel, args.filter_by_components)
47-
if args.binarize:
48-
data = (data > 0).astype("uint16")
49-
tifffile.imwrite(out_path, data, bigtiff=True, compression="zlib")
125+
# calculate single filter mask for all lower resolutions
126+
if args.filter_cochlea_channels is not None:
127+
ds_factor = 48
128+
filter_volume = filter_cochlea(args.cochlea, args.filter_cochlea_channels,
129+
sgn_components=args.filter_sgn_components,
130+
ihc_components=args.filter_ihc_components,
131+
dilation_iterations=args.filter_dilation_iterations, ds_factor=ds_factor)
132+
filter_volume = np.transpose(filter_volume, (2,1,0))
133+
134+
# iterate through exporting lower resolutions
135+
for scale in args.scale:
136+
if args.filter_cochlea_channels is not None:
137+
output_folder = os.path.join(args.output_folder, args.cochlea,
138+
f"scale{scale}_dilation{args.filter_dilation_iterations}")
139+
else:
140+
output_folder = os.path.join(args.output_folder, args.cochlea, f"scale{scale}")
141+
os.makedirs(output_folder, exist_ok=True)
142+
143+
input_key = f"s{scale}"
144+
for channel in args.channels:
145+
out_path = os.path.join(output_folder, f"{channel}.tif")
146+
if os.path.exists(out_path):
147+
continue
148+
149+
print("Exporting channel", channel)
150+
internal_path = os.path.join(args.cochlea, "images", "ome-zarr", f"{channel}.ome.zarr")
151+
s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)
152+
with zarr.open(s3_store, mode="r") as f:
153+
data = f[input_key][:]
154+
print("Data shape", data.shape)
155+
if args.filter_by_components is not None:
156+
data = filter_component(fs, data, args.cochlea, channel, args.filter_by_components)
157+
if args.filter_cochlea_channels is not None:
158+
us_factor = ds_factor // (2 ** scale)
159+
upscaled_filter = upscale_volume(data, filter_volume, upscale_factor=us_factor)
160+
data[upscaled_filter == 0] = 0
161+
if "PV" in channel:
162+
max_intensity = 1400
163+
data[data > max_intensity] = 0
164+
165+
if args.binarize:
166+
data = (data > 0).astype("uint16")
167+
tifffile.imwrite(out_path, data, bigtiff=True, compression="zlib")
50168

51169

52170
def main():
53171
parser = argparse.ArgumentParser()
54172
parser.add_argument("--cochlea", "-c", required=True)
55-
parser.add_argument("--scale", "-s", type=int, required=True)
173+
parser.add_argument("--scale", "-s", nargs="+", type=int, required=True)
56174
parser.add_argument("--output_folder", "-o", required=True)
57-
parser.add_argument("--channels", nargs="+", default=["PV", "VGlut3", "CTBP2"])
175+
parser.add_argument("--channels", nargs="+", type=str, default=["PV", "VGlut3", "CTBP2"])
58176
parser.add_argument("--filter_by_components", nargs="+", type=int, default=None)
177+
parser.add_argument("--filter_sgn_components", nargs="+", type=int, default=[1])
178+
parser.add_argument("--filter_ihc_components", nargs="+", type=int, default=[1])
59179
parser.add_argument("--binarize", action="store_true")
180+
parser.add_argument("--filter_cochlea_channels", nargs="+", type=str, default=None)
181+
parser.add_argument("--filter_dilation_iterations", type=int, default=8)
60182
args = parser.parse_args()
61183

62184
export_lower_resolution(args)

0 commit comments

Comments
 (0)