33from glob import glob
44
55import imageio .v3 as imageio
6+ import numpy as np
7+
8+ from skimage .segmentation import watershed
9+ from skimage .measure import label
610from torch_em .util import load_model
711from torch_em .util .prediction import predict_with_halo
812
9- INPUT_ROOT = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/more-annotations/LA_VISION_M04" # noqa
1013
14+ def _get_files (sgn = True ):
15+ input_root = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/more-annotations/LA_VISION_M04" # noqa
16+ input_root2 = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/more-annotations/LA_VISION_M04_2" # noqa
17+ input_root3 = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/more-annotations/LA_VISION_Mar05" # noqa
18+
19+ input_root_ihc = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/more-annotations/LA_VISION_Mar05-ihc" # noqa
20+
21+ if sgn :
22+ input_files = glob (os .path .join (input_root , "*.tif" )) + \
23+ glob (os .path .join (input_root2 , "*.tif" )) + \
24+ glob (os .path .join (input_root3 , "*.tif" ))
25+ else :
26+ input_files = glob (os .path .join (input_root_ihc , "*.tif" ))
27+
28+ return input_files
1129
12- def predict_blocks (model_path , name ):
30+
31+ def predict_blocks (model_path , name , sgn = True ):
1332 output_folder = os .path .join ("./predictions" , name )
1433 os .makedirs (output_folder , exist_ok = True )
1534
16- input_blocks = glob ( os . path . join ( INPUT_ROOT , "*.tif" ) )
35+ input_blocks = _get_files ( sgn )
1736
18- model = load_model ( model_path )
37+ model = None
1938 for path in input_blocks :
39+ out_path = os .path .join (output_folder , os .path .basename (path ))
40+ if os .path .exists (out_path ):
41+ continue
42+ if model is None :
43+ model = load_model (model_path )
2044 data = imageio .imread (path )
2145 pred = predict_with_halo (data , model , gpu_ids = [0 ], block_shape = [64 , 128 , 128 ], halo = [8 , 32 , 32 ])
22- out_path = os .path .join (output_folder , os .path .basename (path ))
2346 imageio .imwrite (out_path , pred , compression = "zlib" )
2447
2548
2649def _segment_impl (pred , dist_threshold = 0.5 ):
27- import numpy as np
28- from skimage .segmentation import watershed
29- from skimage .measure import label
30-
3150 fg , center_dist , boundary_dist = pred
3251 mask = fg > 0.5
3352
@@ -37,35 +56,42 @@ def _segment_impl(pred, dist_threshold=0.5):
3756 return seg
3857
3958
40- def check_segmentation (name ):
59+ def check_segmentation (name , sgn ):
4160 import napari
4261
43- input_blocks = sorted ( glob ( os . path . join ( INPUT_ROOT , "*.tif" )) )
62+ input_files = _get_files ( sgn )
4463
4564 output_folder = os .path .join ("./predictions" , name )
46- pred = sorted (glob (os .path .join (output_folder , "*.tif" )))
4765
48- for path , pred_path in zip ( input_blocks , pred ) :
66+ for path in input_files :
4967 image = imageio .imread (path )
68+ pred_path = os .path .join (output_folder , os .path .basename (path ))
5069 pred = imageio .imread (pred_path )
51- seg = _segment_impl (pred )
70+ if sgn :
71+ seg = _segment_impl (pred )
72+ else :
73+ seg = label (pred [0 ] > 0.5 )
5274 v = napari .Viewer ()
5375 v .add_image (image )
5476 v .add_image (pred )
5577 v .add_labels (seg )
78+ v .title = os .path .basename (path )
5679 napari .run ()
5780
5881
5982# Model path for original training on low-res SGNs:
6083# /mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/training/checkpoints/cochlea_distance_unet_low-res-sgn # noqa
6184def main ():
6285 parser = argparse .ArgumentParser ()
63- parser .add_argument ("--model_path" , "-m" , required = True )
6486 parser .add_argument ("--name" , "-n" , required = True )
87+ parser .add_argument ("--model_path" , "-m" )
88+ parser .add_argument ("--check" , action = "store_true" )
89+ parser .add_argument ("--ihc" , action = "store_true" )
6590 args = parser .parse_args ()
6691
67- # predict_blocks(args.model_path, args.name)
68- check_segmentation (args .name )
92+ predict_blocks (args .model_path , args .name , sgn = not args .ihc )
93+ if args .check :
94+ check_segmentation (args .name , sgn = not args .ihc )
6995
7096
7197if __name__ == "__main__" :
0 commit comments