Skip to content

Commit ef6761c

Browse files
committed
Update segmentation crops
1 parent b047d91 commit ef6761c

File tree

2 files changed

+71
-69
lines changed

2 files changed

+71
-69
lines changed

scripts/figures/plot_fig2.py

Lines changed: 67 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -6,112 +6,112 @@
66
import matplotlib.pyplot as plt
77
import tifffile
88
from matplotlib import colors
9+
from skimage.segmentation import find_boundaries
910

1011
from util import literature_reference_values
1112

1213
png_dpi = 300
1314

1415

16+
def scramble_instance_labels(arr):
17+
"""Scramble indexes of instance segmentation to avoid neighboring colors.
18+
"""
19+
unique = list(np.unique(arr)[1:])
20+
rng = np.random.default_rng(seed=42)
21+
new_list = rng.uniform(1, len(unique) + 1, size=(len(unique)))
22+
new_arr = np.zeros(arr.shape)
23+
for old_id, new_id in zip(unique, new_list):
24+
new_arr[arr == old_id] = new_id
25+
return new_arr
26+
27+
28+
def plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba=[0, 0, 0, 0.5], plot=False):
29+
seg = tifffile.imread(seg_path)
30+
if len(seg.shape) == 3:
31+
seg = seg[10, xlim1:xlim2, ylim1:ylim2]
32+
else:
33+
seg = seg[xlim1:xlim2, ylim1:ylim2]
34+
35+
img = tifffile.imread(img_path)
36+
img = img[10, xlim1:xlim2, ylim1:ylim2]
37+
38+
# create color map with random distribution for coloring instance segmentation
39+
unique = list(np.unique(seg)[1:])
40+
n_instances = len(unique)
41+
42+
seg = scramble_instance_labels(seg)
43+
44+
rng = np.random.default_rng(seed=42) # fixed seed for reproducibility
45+
colors_array = rng.uniform(0, 1, size=(n_instances, 4)) # RGBA values in [0,1]
46+
colors_array[:, 3] = 1.0 # full alpha
47+
colors_array[0, 3] = 0.0 # make label 0 transparent (background)
48+
cmap = colors.ListedColormap(colors_array)
49+
50+
boundaries = find_boundaries(seg, mode="inner")
51+
boundary_overlay = np.zeros((*boundaries.shape, 4))
52+
53+
boundary_overlay[boundaries] = boundary_rgba # RGBA = black
54+
55+
fig, ax = plt.subplots(figsize=(6, 6))
56+
ax.imshow(img, cmap="gray")
57+
ax.imshow(seg, cmap=cmap, alpha=0.5, interpolation="nearest")
58+
ax.imshow(boundary_overlay)
59+
ax.axis("off")
60+
plt.tight_layout()
61+
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
62+
63+
if plot:
64+
plt.show()
65+
else:
66+
plt.close()
67+
68+
1569
def fig_02b_sgn(save_dir, plot=False):
70+
"""Plot crops of SGN segmentation of CochleaNet, Cellpose and micro-sam.
71+
"""
1672
cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
1773
val_sgn_dir = f"{cochlea_dir}/predictions/val_sgn"
1874
image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationSGNs/for_consensus_annotation"
1975

2076
crop_name = "MLR169R_PV_z3420_allturns_full"
21-
image_name = os.path.join(image_dir, f"{crop_name}.tif")
77+
img_path = os.path.join(image_dir, f"{crop_name}.tif")
2278

2379
xlim1 = 2000
2480
xlim2 = 2500
2581
ylim1 = 3100
2682
ylim2 = 3600
83+
boundary_rgba = [1, 1, 1, 0.5]
2784

2885
for seg_net in ["distance_unet", "cellpose-sam", "micro-sam"]:
2986
save_path = os.path.join(save_dir, f"fig_02b_sgn_{seg_net}.png")
30-
3187
seg_dir = os.path.join(val_sgn_dir, seg_net)
88+
seg_path = os.path.join(seg_dir, f"{crop_name}_seg.tif")
3289

33-
in_path = os.path.join(seg_dir, f"{crop_name}_seg.tif")
34-
seg = tifffile.imread(in_path)
35-
if len(seg.shape) == 3:
36-
seg = seg[10, xlim1:xlim2, ylim1:ylim2]
37-
else:
38-
seg = seg[xlim1:xlim2, ylim1:ylim2]
39-
40-
in_path = os.path.join(seg_dir, image_name)
41-
img = tifffile.imread(in_path)
42-
img = img[10, xlim1:xlim2, ylim1:ylim2]
43-
44-
# create color map with random distribution for coloring instance segmentation
45-
unique = list(np.unique(seg)[1:])
46-
n_instances = len(unique)
47-
rng = np.random.default_rng(seed=42) # fixed seed for reproducibility
48-
colors_array = rng.uniform(0, 1, size=(n_instances, 4)) # RGBA values in [0,1]
49-
colors_array[:, 3] = 1.0 # full alpha
50-
colors_array[0, 3] = 0.0 # make label 0 transparent (background)
51-
rand_cmap = colors.ListedColormap(colors_array)
52-
53-
fig, ax = plt.subplots(figsize=(6, 6))
54-
ax.imshow(img, cmap="gray")
55-
ax.imshow(seg, cmap=rand_cmap, alpha=0.5, interpolation="nearest")
56-
ax.axis("off")
57-
plt.tight_layout()
58-
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
59-
60-
if plot:
61-
plt.show()
62-
else:
63-
plt.close()
90+
plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot)
6491

6592

6693
def fig_02b_ihc(save_dir, plot=False):
94+
"""Plot crops of IHC segmentation of CochleaNet, Cellpose and micro-sam.
95+
"""
6796
cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
6897
val_sgn_dir = f"{cochlea_dir}/predictions/val_ihc"
6998
image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationIHCs"
7099

71100
crop_name = "MLR226L_VGlut3_z1200_3turns_full"
72-
image_name = os.path.join(image_dir, f"{crop_name}.tif")
101+
img_path = os.path.join(image_dir, f"{crop_name}.tif")
73102

74103
xlim1 = 1900
75104
xlim2 = 2400
76105
ylim1 = 2000
77106
ylim2 = 2500
107+
boundary_rgba = [1, 1, 1, 0.5]
78108

79-
for seg_net in ["distance_unet_v3", "cellpose-sam", "micro-sam"]:
109+
for seg_net in ["distance_unet_v4b", "cellpose-sam", "micro-sam"]:
80110
save_path = os.path.join(save_dir, f"fig_02b_ihc_{seg_net}.png")
81-
82111
seg_dir = os.path.join(val_sgn_dir, seg_net)
112+
seg_path = os.path.join(seg_dir, f"{crop_name}_seg.tif")
83113

84-
in_path = os.path.join(seg_dir, f"{crop_name}_seg.tif")
85-
seg = tifffile.imread(in_path)
86-
if len(seg.shape) == 3:
87-
seg = seg[10, xlim1:xlim2, ylim1:ylim2]
88-
else:
89-
seg = seg[xlim1:xlim2, ylim1:ylim2]
90-
91-
in_path = os.path.join(seg_dir, image_name)
92-
img = tifffile.imread(in_path)
93-
img = img[10, xlim1:xlim2, ylim1:ylim2]
94-
95-
# create color map with random distribution for coloring instance segmentation
96-
unique = list(np.unique(seg)[1:])
97-
n_instances = len(unique)
98-
rng = np.random.default_rng(seed=42) # fixed seed for reproducibility
99-
colors_array = rng.uniform(0, 1, size=(n_instances, 4)) # RGBA values in [0,1]
100-
colors_array[:, 3] = 1.0 # full alpha
101-
colors_array[0, 3] = 0.0 # make label 0 transparent (background)
102-
rand_cmap = colors.ListedColormap(colors_array)
103-
104-
fig, ax = plt.subplots(figsize=(6, 6))
105-
ax.imshow(img, cmap="gray")
106-
ax.imshow(seg, cmap=rand_cmap, alpha=0.5, interpolation="nearest")
107-
ax.axis("off")
108-
plt.tight_layout()
109-
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
110-
111-
if plot:
112-
plt.show()
113-
else:
114-
plt.close()
114+
plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot)
115115

116116

117117
def fig_02c(save_path, plot=False, all_versions=False):

scripts/prediction/run_prediction_distance_unet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ def main():
3737
in which case the boundary distances are not used for the seeds.")
3838
parser.add_argument("--fg_threshold", default=0.5, type=float,
3939
help="The threshold applied to the foreground prediction for deriving the watershed mask.")
40+
parser.add_argument("--distance_smoothing", default=0, type=float,
41+
help="The sigma value for smoothing the distance predictions with a gaussian kernel.")
4042

4143
args = parser.parse_args()
4244

@@ -78,7 +80,7 @@ def main():
7880
seg_class=args.seg_class,
7981
center_distance_threshold=args.center_distance_threshold,
8082
boundary_distance_threshold=args.boundary_distance_threshold,
81-
fg_threshold=args.fg_threshold,
83+
fg_threshold=args.fg_threshold, distance_smoothing=args.distance_smoothing,
8284
)
8385

8486
abs_path = os.path.abspath(args.input)
@@ -95,7 +97,7 @@ def main():
9597
seg_class=args.seg_class,
9698
center_distance_threshold=args.center_distance_threshold,
9799
boundary_distance_threshold=args.boundary_distance_threshold,
98-
fg_threshold=args.fg_threshold,
100+
fg_threshold=args.fg_threshold, distance_smoothing=args.distance_smoothing,
99101
)
100102
timer_output = os.path.join(args.output_folder, "timer.json")
101103

0 commit comments

Comments
 (0)