Skip to content

Commit 90bd669

Browse files
authored
fix: detect nuclei inside roi if provided (#28)
1 parent aeac294 commit 90bd669

File tree

4 files changed

+124
-6
lines changed

4 files changed

+124
-6
lines changed

descriptor.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ inputs:
3131
max_height: 4000
3232
description: The input image in which nuclei should be segmented
3333

34+
roi:
35+
display_name: Input ROI
36+
description: The ROI in the image in which the nuclei should be segmented
37+
type: geometry
38+
optional: true
39+
3440
stardist_prob_t:
3541
display_name: Probability threshold
3642
description: Probability Threshold in range [0.0, 1.0] - higher values lead to fewer segmented objects, but will likely avoid false positives

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies = [
1010
"imageio==2.34.1",
1111
"numpy==1.26.4",
1212
"pyyaml==6.0.1",
13+
"shapely==2.1.2",
1314
"stardist==0.9.1",
1415
"tensorflow-cpu==2.16.1",
1516
]

scripts/run.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import yaml
77
from csbdeep.utils import normalize
88
from imageio import imread
9+
from shapely.geometry import Point, Polygon
910
from stardist.models import StarDist2D
1011

1112

@@ -14,6 +15,30 @@
1415
MODEL_DATA_DIR = "/models/"
1516

1617

18+
def is_nucleus_inside_roi(nucleus_coords: np.ndarray, roi_polygon: Polygon, image_height: int) -> bool:
19+
"""
20+
Check if a nucleus is inside the ROI polygon.
21+
22+
Args:
23+
nucleus_coords (np.ndarray): The nucleus coordinates in stardist format (2, N) where first row is x, second is y.
24+
roi_polygon (Polygon): The ROI polygon from shapely.
25+
image_height (int): The image height for coordinate transformation.
26+
27+
Returns:
28+
bool: True if the nucleus centroid is inside the ROI, False otherwise.
29+
"""
30+
31+
# Reverse (y,x) → (x,y) and flip Y coordinates to match ROI coordinate system
32+
x_coords = nucleus_coords[0]
33+
y_coords = image_height - nucleus_coords[1]
34+
35+
# Calculate centroid
36+
centroid_x = np.mean(x_coords)
37+
centroid_y = np.mean(y_coords)
38+
39+
# Check if centroid is inside ROI
40+
return roi_polygon.contains(Point(centroid_x, centroid_y))
41+
1742

1843
def from_stardist_to_geojson_string(stardist_polygroup: np.ndarray, image_height: int):
1944
"""Converts a polygon coordinates ndarray generated by the stardist algorithm into a geojson string.
@@ -89,13 +114,22 @@ def write_array(array_path: str, array_data: Iterable[Any], format_fn: Callable[
89114

90115

91116
def main():
92-
# Red inputs
117+
# Read inputs
93118
stardist_norm_perc_low = read_parameter(os.path.join(INPUT_DIR, "stardist_norm_perc_low"), cast_fn=float, default=1.0)
94119
stardist_norm_perc_high = read_parameter(os.path.join(INPUT_DIR, "stardist_norm_perc_high"), cast_fn=float, default=99.0)
95120
stardist_prob_t = read_parameter(os.path.join(INPUT_DIR, "stardist_prob_t"), cast_fn=float, default=0.5)
96121
stardist_nms_t = read_parameter(os.path.join(INPUT_DIR, "stardist_nms_t"), cast_fn=float, default=0.5)
97122
image_path = os.path.join(INPUT_DIR, "image")
98123

124+
# Read ROI if it exists
125+
roi_path = os.path.join(INPUT_DIR, "roi")
126+
roi_polygon = None
127+
if os.path.isfile(roi_path):
128+
with open(roi_path, "r") as fp:
129+
roi_content = fp.read().strip()
130+
roi = geojson.loads(roi_content)
131+
roi_polygon = Polygon(roi['coordinates'][0])
132+
99133
# use local model file in ~/models/2D_versatile_HE/
100134
model = StarDist2D(None, name='2D_versatile_HE', basedir=MODEL_DATA_DIR)
101135

@@ -110,7 +144,6 @@ def main():
110144
axis=(0, 1) # normalize channels independently
111145
)
112146

113-
114147
# Stardist model prediction with thresholds
115148
_, details = model.predict_instances(
116149
img,
@@ -119,13 +152,30 @@ def main():
119152
n_tiles=model._guess_n_tiles(img)
120153
)
121154

122-
# writing ouputs
155+
# Filter nuclei if ROI is provided
156+
if roi_polygon is not None:
157+
filtered_coords = []
158+
filtered_probs = []
159+
160+
for i, nucleus_coords in enumerate(details['coord']):
161+
if is_nucleus_inside_roi(nucleus_coords, roi_polygon, image_height):
162+
filtered_coords.append(nucleus_coords)
163+
filtered_probs.append(details['prob'][i])
164+
else:
165+
filtered_coords = list(details['coord'])
166+
filtered_probs = details['prob'].tolist()
167+
168+
# writing outputs
123169
write_array(
124170
array_path=os.path.join(OUTPUT_DIR, "nuclei"),
125-
array_data=details['coord'],
126-
format_fn=lambda poly: from_stardist_to_geojson_string(poly, image_height)
171+
array_data=filtered_coords,
172+
format_fn=lambda poly: from_stardist_to_geojson_string(poly, image_height),
173+
)
174+
write_array(
175+
array_path=os.path.join(OUTPUT_DIR, "probs"),
176+
array_data=filtered_probs,
177+
format_fn=str,
127178
)
128-
write_array(array_path=os.path.join(OUTPUT_DIR, "probs"), array_data=details['prob'].tolist(), format_fn=str)
129179

130180

131181
if __name__ == "__main__":

0 commit comments

Comments
 (0)