@@ -40,7 +40,7 @@ def dict_collate_fn(batch, check_padding=True, img_min_ndim=2) -> Union[OrderedD
4040 results [k ] = cd .data .padding_stack (* items , axis = 0 )
4141 else :
4242 if isinstance (items [0 ], torch .Tensor ):
43- results [k ] = torch .stack (items , axis = 0 )
43+ results [k ] = torch .stack (items , dim = 0 )
4444 else :
4545 results [k ] = np .stack (items , axis = 0 )
4646 if image_like :
@@ -364,7 +364,8 @@ def cpn_inference(
364364 brightness : float = 0.0 ,
365365 percentile : Optional [List [float ]] = None ,
366366 model_parameters : str = '' ,
367- verbose : bool = True
367+ verbose : bool = True ,
368+ skip_existing : bool = False
368369):
369370 """
370371 Process contour proposals for instance segmentation using specified parameters.
@@ -424,6 +425,7 @@ def cpn_inference(
424425 percentile (list[float], optional): Percentile norm. Performs min-max normalization with specified percentiles. Default is None.
425426 model_parameters (str): Model parameters. Pass as string in "key=value,key1=value1" format. Default is ''.
426427 verbose (bool): Verbosity toggle.
428+ skip_existing (bool): Whether to skip inputs with existing output files.
427429 """
428430
429431 args = dict (locals ())
@@ -470,7 +472,7 @@ def resolve_inputs_(collection, x, tag='inputs'):
470472 model_list = []
471473 for m in models :
472474 if isinstance (m , nn .Module ):
473- model_list .append (models )
475+ model_list .append (m )
474476 else :
475477 assert isinstance (m , str )
476478 if m .startswith ('http://' ) or m .startswith ('https://' ) or m .startswith ('cd://' ) or (
@@ -490,7 +492,7 @@ def resolve_inputs_(collection, x, tag='inputs'):
490492 if verbose and model_parameters is not None and len (model_parameters ):
491493 print ('Changing the following model parameters:' , model_parameters )
492494
493- if devices .isnumeric ():
495+ if isinstance ( devices , str ) and devices .isnumeric ():
494496 devices = int (devices )
495497
496498 if verbose :
@@ -513,13 +515,18 @@ def resolve_inputs_(collection, x, tag='inputs'):
513515
514516 makedirs (outputs , exist_ok = True )
515517
516- def load_inputs (x , dataset_name , method , tag , idx ):
518+ def load_inputs (x , dataset_name , method , tag , idx , ext_checks = ( '.h5' ,) ):
517519 if isinstance (x , np .ndarray ):
518520 dst = join (outputs , f'ndarray_{ idx } ' + '{ext}' )
519521 image = x
520522 else :
521523 prefix , ext = splitext (basename (x ))
522524 dst = join (outputs , prefix + '{ext}' )
525+
526+ if skip_existing :
527+ if any (isfile (dst .format (ext = ext )) for ext in ext_checks ):
528+ raise FileExistsError
529+
523530 if x .startswith ('http://' ) or x .startswith ('https://' ):
524531 image = cd .fetch_image (x )
525532 elif ext in ('.h5' , '.hdf5' ):
@@ -536,8 +543,16 @@ def load_inputs(x, dataset_name, method, tag, idx):
536543 image = cd .load_image (x , method = method )
537544 return image , dst
538545
546+ output_list = []
539547 for src_idx , src in enumerate (input_list ):
540- img , dst = load_inputs (src , inputs_dataset , inputs_method , 'inputs' , idx = src_idx )
548+ try :
549+ img , dst = load_inputs (src , inputs_dataset , inputs_method , 'inputs' , idx = src_idx )
550+ except FileExistsError :
551+ if verbose :
552+ print ('Skipping input, because output exists already:' , src )
553+ continue
554+ dst_h5 = dst .format (ext = '.h5' )
555+
541556 if isinstance (src , np .ndarray ):
542557 inputs_tup = 'ndarray' ,
543558 else :
@@ -585,6 +600,9 @@ def load_inputs(x, dataset_name, method, tag, idx):
585600 ))
586601
587602 is_dist = is_available () and is_initialized ()
603+ output = cd .asnumpy (y )
604+ output_list .append (output )
605+ out_files = dict ()
588606 if (is_dist and get_rank () == 0 ) or not is_dist :
589607 props = properties
590608 do_props = props is not None and len (props )
@@ -594,31 +612,40 @@ def load_inputs(x, dataset_name, method, tag, idx):
594612 if do_labels :
595613 labels_ = cd .data .contours2labels (y ['contours' ], img .shape [:2 ])
596614 if labels :
597- y ['labels' ] = labels_
615+ y ['labels' ] = output [ 'labels' ] = labels_
598616 if flat_labels_ :
599617 flat_labels_ = cd .data .resolve_label_channels (labels_ )
600618 if flat_labels_ :
601- y ['flat_labels' ] = flat_labels_
619+ y ['flat_labels' ] = output [ 'flat_labels' ] = flat_labels_
602620
603- cd .to_h5 (dst .format (ext = '.h5' ), ** cd .asnumpy (y ), # json since None values in attrs are not supported
621+ out_files ['h5' ] = dst_h5
622+ cd .to_h5 (dst_h5 , ** output , # json since None values in attrs are not supported
604623 attributes = dict (contours = dict (args = cd .dict_to_json_string (args ))))
605624 if do_props : # TODO: Use mask in properties (writing out labels)
606625 if flat_labels_ :
607626 assert flat_labels_ is not None
608- cd .data .labels2property_table (flat_labels_ , props , spacing = spacing ,
609- separator = separator ).to_csv (dst .format (ext = '_flat.csv' ))
627+ tab = cd .data .labels2property_table (flat_labels_ , props , spacing = spacing ,
628+ separator = separator )
629+ output ['properties_flat' ] = tab
630+ out_files ['properties_flat' ] = dst_flat_csv = dst .format (ext = '_flat.csv' )
631+ tab .to_csv (dst_flat_csv )
610632 if labels or not flat_labels_ :
611633 assert labels_ is not None
612- cd .data .labels2property_table (labels_ , props , spacing = spacing , separator = separator ).to_csv (
613- dst .format (ext = '.csv' ))
634+ tab = cd .data .labels2property_table (labels_ , props , spacing = spacing , separator = separator )
635+ output ['properties' ] = tab
636+ out_files ['properties' ] = dst_csv = dst .format (ext = '.csv' )
637+ tab .to_csv (dst_csv )
614638
615639 if overlay :
616640 if do_labels :
617641 assert labels_ is not None or flat_labels_ is not None
618642 label_vis = img_as_ubyte (cd .label_cmap (flat_labels_ if labels_ is None else labels_ ))
619643 else :
620644 label_vis = cd .data .contours2overlay (y ['contours' ], img .shape [:2 ])
621- tifffile .imwrite (dst .format (ext = '_overlay.tif' ), label_vis , compression = 'ZLIB' )
645+ dst_ove_tif = dst .format (ext = '_overlay.tif' )
646+ tifffile .imwrite (dst_ove_tif , label_vis , compression = 'ZLIB' )
647+ output ['overlay' ] = label_vis
648+ out_files ['overlay' ] = dst_ove_tif
622649
623650 if demo_figure :
624651 from matplotlib import pyplot as plt
@@ -627,28 +654,40 @@ def load_inputs(x, dataset_name, method, tag, idx):
627654 cd .plot_boxes (y ['boxes' ])
628655 loc = cd .asnumpy (y ['locations' ])
629656 plt .scatter (loc [:, 0 ], loc [:, 1 ], marker = 'x' )
630- cd .save_fig (dst .format (ext = '_demo.png' ))
657+ out_files ['demo_figure' ] = dst_demo = dst .format (ext = '_demo.png' )
658+ cd .save_fig (dst_demo )
659+ if len (out_files ):
660+ output ['files' ] = out_files
661+ return output_list
631662
632663
633664def main ():
665+ from inspect import signature
666+
667+ par = signature (cpn_inference ).parameters
668+
669+ def d (name ):
670+ return par [name ].default
671+
634672 parser = argparse .ArgumentParser ('Contour Proposal Networks for Instance Segmentation' )
635673
636674 parser .add_argument ('-i' , '--inputs' , nargs = '+' , type = str ,
637675 help = 'Inputs. Either filename, name pattern (glob), or URL (leading http:// or https://).' )
638676 parser .add_argument ('-o' , '--outputs' , default = 'outputs' , type = str , help = 'output path' )
639- parser .add_argument ('--inputs_method' , default = 'imageio' ,
677+ parser .add_argument ('--inputs_method' , default = d ( 'inputs_method' ) ,
640678 help = 'Method used for loading non-hdf5 inputs.' )
641- parser .add_argument ('--inputs_dataset' , default = 'image' , help = 'Dataset name for hdf5 inputs.' )
679+ parser .add_argument ('--inputs_dataset' , default = d ('inputs_dataset' ),
680+ help = 'Dataset name for hdf5 inputs.' )
642681 parser .add_argument ('-m' , '--models' , nargs = '+' ,
643682 help = 'Model. Either filename, name pattern (glob), URL (leading http:// or https://), or '
644683 'hosted model name (leading cd://). '
645684 'Example: `--model \' cd://ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c\' `' )
646- parser .add_argument ('--masks' , default = None , nargs = '+' , type = str ,
685+ parser .add_argument ('--masks' , default = d ( 'masks' ) , nargs = '+' , type = str ,
647686 help = 'Masks. Either filename, name pattern (glob), or URL (leading http:// or https://). '
648687 'A mask determines where the model searches for objects. Regions with values <= 0'
649688 'are ignored. Hence, objects will only be found where the mask is positive. '
650689 'Masks are linked to inputs by order. If masks are used, all inputs must have one.' )
651- parser .add_argument ('--point_masks' , default = None , nargs = '+' , type = str ,
690+ parser .add_argument ('--point_masks' , default = d ( 'point_masks' ) , nargs = '+' , type = str ,
652691 help = 'Point masks. Either filename, name pattern (glob), or URL (leading http:// or https://). '
653692 'A point mask is a mask image with positive values at an object`s location. '
654693 'The model aims to convert points to contours. '
@@ -658,30 +697,30 @@ def main():
658697 'Otherwise (default), the points in `point_masks` are considered non-exclusive, meaning '
659698 'other objects are detected and segmented in addition. '
660699 'Note that this option overrides `masks`.' )
661- parser .add_argument ('--masks_dataset' , default = 'mask' , help = 'Dataset name for hdf5 inputs.' )
662- parser .add_argument ('--point_masks_dataset' , default = 'point_mask' , help = 'Dataset name for hdf5 inputs.' )
663- parser .add_argument ('--devices' , default = 'auto' , type = str , help = 'Devices.' )
664- parser .add_argument ('--accelerator' , default = 'auto' , type = str , help = 'Accelerator.' )
665- parser .add_argument ('--strategy' , default = 'auto' , type = str , help = 'Strategy.' )
700+ parser .add_argument ('--masks_dataset' , default = d ( 'masks_dataset' ) , help = 'Dataset name for hdf5 inputs.' )
701+ parser .add_argument ('--point_masks_dataset' , default = d ( 'point_masks_dataset' ) , help = 'Dataset name for hdf5 inputs.' )
702+ parser .add_argument ('--devices' , default = d ( 'devices' ) , type = str , help = 'Devices.' )
703+ parser .add_argument ('--accelerator' , default = d ( 'accelerator' ) , type = str , help = 'Accelerator.' )
704+ parser .add_argument ('--strategy' , default = d ( 'strategy' ) , type = str , help = 'Strategy.' )
666705 parser .add_argument ('--precision' , default = '32-true' , type = str ,
667706 help = 'Precision. One of (64, 64-true, 32, 32-true, 16, 16-mixed, bf16, bf16-mixed)' )
668- parser .add_argument ('--num_workers' , default = 0 , type = int , help = 'Number of workers.' )
669- parser .add_argument ('--prefetch_factor' , default = 2 , type = int ,
707+ parser .add_argument ('--num_workers' , default = d ( 'num_workers' ) , type = int , help = 'Number of workers.' )
708+ parser .add_argument ('--prefetch_factor' , default = d ( 'prefetch_factor' ) , type = int ,
670709 help = 'Number of batches loaded in advance by each worker.' )
671710 parser .add_argument ('--pin_memory' , action = 'store_true' ,
672711 help = 'If set, the data loader will copy Tensors into device/CUDA '
673712 'pinned memory before returning them.' )
674- parser .add_argument ('--batch_size' , default = 1 , type = int , help = 'How many samples per batch to load.' )
675- parser .add_argument ('--tile_size' , default = 1024 , nargs = '+' , type = int ,
713+ parser .add_argument ('--batch_size' , default = d ( 'batch_size' ) , type = int , help = 'How many samples per batch to load.' )
714+ parser .add_argument ('--tile_size' , default = d ( 'tile_size' ) , nargs = '+' , type = int ,
676715 help = 'Tile/window size for sliding window processing.' )
677- parser .add_argument ('--stride' , default = 768 , nargs = '+' , type = int ,
716+ parser .add_argument ('--stride' , default = d ( 'stride' ) , nargs = '+' , type = int ,
678717 help = 'Stride for sliding window processing.' )
679- parser .add_argument ('--border_removal' , default = 4 , type = int ,
718+ parser .add_argument ('--border_removal' , default = d ( 'border_removal' ) , type = int ,
680719 help = 'Number of border pixels for the removal of '
681720 'partial objects during tiled inference.' )
682- parser .add_argument ('--stitching_rule' , default = 'nms' , type = str ,
721+ parser .add_argument ('--stitching_rule' , default = d ( 'stitching_rule' ) , type = str ,
683722 help = 'Stitching rule to use for collating results from sliding window processing.' )
684- parser .add_argument ('--min_vote' , default = 1 , type = int ,
723+ parser .add_argument ('--min_vote' , default = d ( 'min_vote' ) , type = int ,
685724 help = 'Required smallest vote count for a detected object to be accepted. '
686725 'Only used for ensembles. Minimum vote count is 1, maximum the number of '
687726 'models that are part of the ensemble.' )
@@ -697,21 +736,23 @@ def main():
697736 parser .add_argument ('--truncated_images' , action = 'store_true' ,
698737 help = 'Whether to support truncated images.' )
699738 parser .add_argument ('-p' , '--properties' , nargs = '*' , help = 'Region properties' )
700- parser .add_argument ('--spacing' , default = 1. , type = float ,
739+ parser .add_argument ('--spacing' , default = d ( 'spacing' ) , type = float ,
701740 help = 'The pixel spacing. Relevant for pixel-based region properties.' )
702- parser .add_argument ('--separator' , default = '-' , type = str ,
741+ parser .add_argument ('--separator' , default = d ( 'separator' ) , type = str ,
703742 help = 'Separator string for region properties that are written to multiple columns. '
704743 'Default is "-" as in bbox-0, bbox-1, bbox-2, bbox-4.' )
705744
706- parser .add_argument ('--gamma' , default = 1. , type = float , help = 'Gamma value for gamma transform.' )
707- parser .add_argument ('--contrast' , default = 1. , type = float , help = 'Factor for contrast adjustment.' )
708- parser .add_argument ('--brightness' , default = 0. , type = float , help = 'Factor for brightness adjustment.' )
709- parser .add_argument ('--percentile' , default = None , nargs = '+' , type = float ,
745+ parser .add_argument ('--gamma' , default = d ( 'gamma' ) , type = float , help = 'Gamma value for gamma transform.' )
746+ parser .add_argument ('--contrast' , default = d ( 'contrast' ) , type = float , help = 'Factor for contrast adjustment.' )
747+ parser .add_argument ('--brightness' , default = d ( 'brightness' ) , type = float , help = 'Factor for brightness adjustment.' )
748+ parser .add_argument ('--percentile' , default = d ( 'percentile' ) , nargs = '+' , type = float ,
710749 help = 'Percentile norm. Performs min-max normalization with specified percentiles.'
711750 'Specify either two values `(min, max)` or just `max` interpreted as '
712751 '(1 - max, max).' )
713- parser .add_argument ('--model_parameters' , default = '' , type = str ,
752+ parser .add_argument ('--model_parameters' , default = d ( 'model_parameters' ) , type = str ,
714753 help = 'Model parameters. Pass as string in "key=value,key1=value1" format' )
754+ parser .add_argument ('--skip_existing' , action = 'store_true' ,
755+ help = 'Whether to skip existing files. ' )
715756
716757 args , unknown = parser .parse_known_args ()
717758
@@ -756,7 +797,8 @@ def main():
756797 contrast = args .contrast ,
757798 brightness = args .brightness ,
758799 percentile = args .percentile ,
759- model_parameters = args .model_parameters
800+ model_parameters = args .model_parameters ,
801+ skip_existing = args .skip_existing
760802 )
761803
762804 if not (is_available () and is_initialized ()) or get_rank () == 0 : # because why not
0 commit comments