Skip to content

Commit 504351f

Browse files
committed
Small fixes
1 parent d0d41ec commit 504351f

File tree

3 files changed

+26
-23
lines changed

3 files changed

+26
-23
lines changed

flamingo_tools/segmentation/postprocessing.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -325,12 +325,12 @@ def graph_connected_components(coords, min_edge_distance, min_component_length):
325325
graph.add_node(num, pos=pos)
326326

327327
# create edges between points whose distance is less than threshold min_edge_distance
328-
for i in coords:
329-
for j in coords:
330-
if i < j:
331-
dist = math.dist(coords[i], coords[j])
328+
for num_i, pos_i in coords.items():
329+
for num_j, pos_j in coords.items():
330+
if num_i < num_j:
331+
dist = math.dist(pos_i, pos_j)
332332
if dist <= min_edge_distance:
333-
graph.add_edge(i, j, weight=dist)
333+
graph.add_edge(num_i, num_j, weight=dist)
334334

335335
components = list(nx.connected_components(graph))
336336

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,8 @@ def run_unet_prediction_preprocess_slurm(
432432
if not os.path.isdir(os.path.join(output_folder, "mask.zarr")):
433433
find_mask(input_path, input_key, output_folder)
434434

435-
calc_mean_and_std(input_path, input_key, output_folder)
435+
if not os.path.isfile(os.path.join(output_folder, "mean_std.json")):
436+
calc_mean_and_std(input_path, input_key, output_folder)
436437

437438

438439
def run_unet_prediction_slurm(
@@ -504,15 +505,23 @@ def run_unet_prediction_slurm(
504505
def run_unet_segmentation_slurm(
505506
output_folder: str,
506507
min_size: int,
508+
center_distance_threshold: float = 0.4,
507509
boundary_distance_threshold: float = 0.5,
510+
fg_threshold: float = 0.5,
508511
) -> None:
509512
"""Create segmentation from prediction.
510513
511514
Args:
512515
output_folder: The output folder for storing the segmentation related data.
513516
min_size: The minimal size of segmented objects in the output.
517+
center_distance_threshold: The threshold applied to the distance center predictions to derive seeds.
518+
boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
519+
By default this is set to 'None', in which case the boundary distances are not used for the seeds.
520+
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
514521
"""
515522
min_size = int(min_size)
516523
pmap_out = os.path.join(output_folder, "predictions.zarr")
517-
distance_watershed_implementation(pmap_out, output_folder, boundary_distance_threshold=boundary_distance_threshold,
524+
distance_watershed_implementation(pmap_out, output_folder, center_distance_threshold=center_distance_threshold,
525+
boundary_distance_threshold=boundary_distance_threshold,
526+
fg_threshold=fg_threshold,
518527
min_size=min_size)

scripts/extract_block.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
import zarr
1111

1212
import flamingo_tools.s3_utils as s3_utils
13+
from flamingo_tools.file_utils import read_image_data
1314

1415

1516
def main(
1617
input_path: str,
1718
coords: List[int],
18-
output_dir: str = None,
19-
input_key: str = "setup0/timepoint0/s0",
19+
output_dir: Optional[str] = None,
20+
input_key: Optional[str] = None,
2021
output_key: Optional[str] = None,
2122
resolution: float = 0.38,
2223
roi_halo: List[int] = [128, 128, 64],
@@ -88,24 +89,17 @@ def main(
8889
roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo))
8990

9091
if s3:
91-
s3_path, fs = s3_utils.get_s3_path(input_path, bucket_name=s3_bucket_name,
92-
service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
92+
input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=s3_bucket_name,
93+
service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
9394

94-
with zarr.open(s3_path, mode="r") as f:
95-
raw = f[input_key][roi]
96-
97-
elif ".tif" in input_path:
98-
raw = read_tif(input_path)[roi]
99-
100-
else:
101-
with zarr.open(input_path, mode="r") as f:
102-
raw = f[input_key][roi]
95+
data_ = read_image_data(input_path, input_key)
96+
data_roi = data_[roi]
10397

10498
if tif:
105-
imageio.imwrite(output_file, raw, compression="zlib")
99+
imageio.imwrite(output_file, data_roi, compression="zlib")
106100
else:
107101
with zarr.open(output_file, mode="w") as f_out:
108-
f_out.create_dataset(output_key, data=raw, compression="gzip")
102+
f_out.create_dataset(output_key, data=data_roi, compression="gzip")
109103

110104

111105
if __name__ == "__main__":
@@ -118,7 +112,7 @@ def main(
118112
parser.add_argument('-c', "--coord", type=str, required=True,
119113
help="3D coordinate as center of extracted block, json-encoded.")
120114

121-
parser.add_argument('-k', "--input_key", type=str, default="setup0/timepoint0/s0",
115+
parser.add_argument('-k', "--input_key", type=str, default=None,
122116
help="Input key for data in input file.")
123117
parser.add_argument("--output_key", type=str, default=None,
124118
help="Output key for data in output file.")

0 commit comments

Comments
 (0)