Skip to content

Commit 7d318f5

Browse files
committed
Update inference script
1 parent 02f28ec commit 7d318f5

File tree

1 file changed

+82
-40
lines changed

1 file changed

+82
-40
lines changed

celldetection_scripts/cpn_inference.py

Lines changed: 82 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def dict_collate_fn(batch, check_padding=True, img_min_ndim=2) -> Union[OrderedD
4040
results[k] = cd.data.padding_stack(*items, axis=0)
4141
else:
4242
if isinstance(items[0], torch.Tensor):
43-
results[k] = torch.stack(items, axis=0)
43+
results[k] = torch.stack(items, dim=0)
4444
else:
4545
results[k] = np.stack(items, axis=0)
4646
if image_like:
@@ -364,7 +364,8 @@ def cpn_inference(
364364
brightness: float = 0.0,
365365
percentile: Optional[List[float]] = None,
366366
model_parameters: str = '',
367-
verbose: bool = True
367+
verbose: bool = True,
368+
skip_existing: bool = False
368369
):
369370
"""
370371
Process contour proposals for instance segmentation using specified parameters.
@@ -424,6 +425,7 @@ def cpn_inference(
424425
percentile (list[float], optional): Percentile norm. Performs min-max normalization with specified percentiles. Default is None.
425426
model_parameters (str): Model parameters. Pass as string in "key=value,key1=value1" format. Default is ''.
426427
verbose (bool): Verbosity toggle.
428+
skip_existing(bool): Whether to inputs with existing output files.
427429
"""
428430

429431
args = dict(locals())
@@ -470,7 +472,7 @@ def resolve_inputs_(collection, x, tag='inputs'):
470472
model_list = []
471473
for m in models:
472474
if isinstance(m, nn.Module):
473-
model_list.append(models)
475+
model_list.append(m)
474476
else:
475477
assert isinstance(m, str)
476478
if m.startswith('http://') or m.startswith('https://') or m.startswith('cd://') or (
@@ -490,7 +492,7 @@ def resolve_inputs_(collection, x, tag='inputs'):
490492
if verbose and model_parameters is not None and len(model_parameters):
491493
print('Changing the following model parameters:', model_parameters)
492494

493-
if devices.isnumeric():
495+
if isinstance(devices, str) and devices.isnumeric():
494496
devices = int(devices)
495497

496498
if verbose:
@@ -513,13 +515,18 @@ def resolve_inputs_(collection, x, tag='inputs'):
513515

514516
makedirs(outputs, exist_ok=True)
515517

516-
def load_inputs(x, dataset_name, method, tag, idx):
518+
def load_inputs(x, dataset_name, method, tag, idx, ext_checks=('.h5',)):
517519
if isinstance(x, np.ndarray):
518520
dst = join(outputs, f'ndarray_{idx}' + '{ext}')
519521
image = x
520522
else:
521523
prefix, ext = splitext(basename(x))
522524
dst = join(outputs, prefix + '{ext}')
525+
526+
if skip_existing:
527+
if any(isfile(dst.format(ext=ext)) for ext in ext_checks):
528+
raise FileExistsError
529+
523530
if x.startswith('http://') or x.startswith('https://'):
524531
image = cd.fetch_image(x)
525532
elif ext in ('.h5', '.hdf5'):
@@ -536,8 +543,16 @@ def load_inputs(x, dataset_name, method, tag, idx):
536543
image = cd.load_image(x, method=method)
537544
return image, dst
538545

546+
output_list = []
539547
for src_idx, src in enumerate(input_list):
540-
img, dst = load_inputs(src, inputs_dataset, inputs_method, 'inputs', idx=src_idx)
548+
try:
549+
img, dst = load_inputs(src, inputs_dataset, inputs_method, 'inputs', idx=src_idx)
550+
except FileExistsError:
551+
if verbose:
552+
print('Skipping input, because output exists already:', src)
553+
continue
554+
dst_h5 = dst.format(ext='.h5')
555+
541556
if isinstance(src, np.ndarray):
542557
inputs_tup = 'ndarray',
543558
else:
@@ -585,6 +600,9 @@ def load_inputs(x, dataset_name, method, tag, idx):
585600
))
586601

587602
is_dist = is_available() and is_initialized()
603+
output = cd.asnumpy(y)
604+
output_list.append(output)
605+
out_files = dict()
588606
if (is_dist and get_rank() == 0) or not is_dist:
589607
props = properties
590608
do_props = props is not None and len(props)
@@ -594,31 +612,40 @@ def load_inputs(x, dataset_name, method, tag, idx):
594612
if do_labels:
595613
labels_ = cd.data.contours2labels(y['contours'], img.shape[:2])
596614
if labels:
597-
y['labels'] = labels_
615+
y['labels'] = output['labels'] = labels_
598616
if flat_labels_:
599617
flat_labels_ = cd.data.resolve_label_channels(labels_)
600618
if flat_labels_:
601-
y['flat_labels'] = flat_labels_
619+
y['flat_labels'] = output['flat_labels'] = flat_labels_
602620

603-
cd.to_h5(dst.format(ext='.h5'), **cd.asnumpy(y), # json since None values in attrs are not supported
621+
out_files['h5'] = dst_h5
622+
cd.to_h5(dst_h5, **output, # json since None values in attrs are not supported
604623
attributes=dict(contours=dict(args=cd.dict_to_json_string(args))))
605624
if do_props: # TODO: Use mask in properties (writing out labels)
606625
if flat_labels_:
607626
assert flat_labels_ is not None
608-
cd.data.labels2property_table(flat_labels_, props, spacing=spacing,
609-
separator=separator).to_csv(dst.format(ext='_flat.csv'))
627+
tab = cd.data.labels2property_table(flat_labels_, props, spacing=spacing,
628+
separator=separator)
629+
output['properties_flat'] = tab
630+
out_files['properties_flat'] = dst_flat_csv = dst.format(ext='_flat.csv')
631+
tab.to_csv(dst_flat_csv)
610632
if labels or not flat_labels_:
611633
assert labels_ is not None
612-
cd.data.labels2property_table(labels_, props, spacing=spacing, separator=separator).to_csv(
613-
dst.format(ext='.csv'))
634+
tab = cd.data.labels2property_table(labels_, props, spacing=spacing, separator=separator)
635+
output['properties'] = tab
636+
out_files['properties'] = dst_csv = dst.format(ext='.csv')
637+
tab.to_csv(dst_csv)
614638

615639
if overlay:
616640
if do_labels:
617641
assert labels_ is not None or flat_labels_ is not None
618642
label_vis = img_as_ubyte(cd.label_cmap(flat_labels_ if labels_ is None else labels_))
619643
else:
620644
label_vis = cd.data.contours2overlay(y['contours'], img.shape[:2])
621-
tifffile.imwrite(dst.format(ext='_overlay.tif'), label_vis, compression='ZLIB')
645+
dst_ove_tif = dst.format(ext='_overlay.tif')
646+
tifffile.imwrite(dst_ove_tif, label_vis, compression='ZLIB')
647+
output['overlay'] = label_vis
648+
out_files['overlay'] = dst_ove_tif
622649

623650
if demo_figure:
624651
from matplotlib import pyplot as plt
@@ -627,28 +654,40 @@ def load_inputs(x, dataset_name, method, tag, idx):
627654
cd.plot_boxes(y['boxes'])
628655
loc = cd.asnumpy(y['locations'])
629656
plt.scatter(loc[:, 0], loc[:, 1], marker='x')
630-
cd.save_fig(dst.format(ext='_demo.png'))
657+
out_files['demo_figure'] = dst_demo = dst.format(ext='_demo.png')
658+
cd.save_fig(dst_demo)
659+
if len(out_files):
660+
output['files'] = out_files
661+
return output_list
631662

632663

633664
def main():
665+
from inspect import signature
666+
667+
par = signature(cpn_inference).parameters
668+
669+
def d(name):
670+
return par[name].default
671+
634672
parser = argparse.ArgumentParser('Contour Proposal Networks for Instance Segmentation')
635673

636674
parser.add_argument('-i', '--inputs', nargs='+', type=str,
637675
help='Inputs. Either filename, name pattern (glob), or URL (leading http:// or https://).')
638676
parser.add_argument('-o', '--outputs', default='outputs', type=str, help='output path')
639-
parser.add_argument('--inputs_method', default='imageio',
677+
parser.add_argument('--inputs_method', default=d('inputs_method'),
640678
help='Method used for loading non-hdf5 inputs.')
641-
parser.add_argument('--inputs_dataset', default='image', help='Dataset name for hdf5 inputs.')
679+
parser.add_argument('--inputs_dataset', default=d('inputs_dataset'),
680+
help='Dataset name for hdf5 inputs.')
642681
parser.add_argument('-m', '--models', nargs='+',
643682
help='Model. Either filename, name pattern (glob), URL (leading http:// or https://), or '
644683
'hosted model name (leading cd://). '
645684
'Example: `--model \'cd://ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c\'`')
646-
parser.add_argument('--masks', default=None, nargs='+', type=str,
685+
parser.add_argument('--masks', default=d('masks'), nargs='+', type=str,
647686
help='Masks. Either filename, name pattern (glob), or URL (leading http:// or https://). '
648687
'A mask determines where the model searches for objects. Regions with values <= 0'
649688
'are ignored. Hence, objects will only be found where the mask is positive. '
650689
'Masks are linked to inputs by order. If masks are used, all inputs must have one.')
651-
parser.add_argument('--point_masks', default=None, nargs='+', type=str,
690+
parser.add_argument('--point_masks', default=d('point_masks'), nargs='+', type=str,
652691
help='Point masks. Either filename, name pattern (glob), or URL (leading http:// or https://). '
653692
'A point mask is a mask image with positive values at an object`s location. '
654693
'The model aims to convert points to contours. '
@@ -658,30 +697,30 @@ def main():
658697
'Otherwise (default), the points in `point_masks` are considered non-exclusive, meaning '
659698
'other objects are detected and segmented in addition. '
660699
'Note that this option overrides `masks`.')
661-
parser.add_argument('--masks_dataset', default='mask', help='Dataset name for hdf5 inputs.')
662-
parser.add_argument('--point_masks_dataset', default='point_mask', help='Dataset name for hdf5 inputs.')
663-
parser.add_argument('--devices', default='auto', type=str, help='Devices.')
664-
parser.add_argument('--accelerator', default='auto', type=str, help='Accelerator.')
665-
parser.add_argument('--strategy', default='auto', type=str, help='Strategy.')
700+
parser.add_argument('--masks_dataset', default=d('masks_dataset'), help='Dataset name for hdf5 inputs.')
701+
parser.add_argument('--point_masks_dataset', default=d('point_masks_dataset'), help='Dataset name for hdf5 inputs.')
702+
parser.add_argument('--devices', default=d('devices'), type=str, help='Devices.')
703+
parser.add_argument('--accelerator', default=d('accelerator'), type=str, help='Accelerator.')
704+
parser.add_argument('--strategy', default=d('strategy'), type=str, help='Strategy.')
666705
parser.add_argument('--precision', default='32-true', type=str,
667706
help='Precision. One of (64, 64-true, 32, 32-true, 16, 16-mixed, bf16, bf16-mixed)')
668-
parser.add_argument('--num_workers', default=0, type=int, help='Number of workers.')
669-
parser.add_argument('--prefetch_factor', default=2, type=int,
707+
parser.add_argument('--num_workers', default=d('num_workers'), type=int, help='Number of workers.')
708+
parser.add_argument('--prefetch_factor', default=d('prefetch_factor'), type=int,
670709
help='Number of batches loaded in advance by each worker.')
671710
parser.add_argument('--pin_memory', action='store_true',
672711
help='If set, the data loader will copy Tensors into device/CUDA '
673712
'pinned memory before returning them.')
674-
parser.add_argument('--batch_size', default=1, type=int, help='How many samples per batch to load.')
675-
parser.add_argument('--tile_size', default=1024, nargs='+', type=int,
713+
parser.add_argument('--batch_size', default=d('batch_size'), type=int, help='How many samples per batch to load.')
714+
parser.add_argument('--tile_size', default=d('tile_size'), nargs='+', type=int,
676715
help='Tile/window size for sliding window processing.')
677-
parser.add_argument('--stride', default=768, nargs='+', type=int,
716+
parser.add_argument('--stride', default=d('stride'), nargs='+', type=int,
678717
help='Stride for sliding window processing.')
679-
parser.add_argument('--border_removal', default=4, type=int,
718+
parser.add_argument('--border_removal', default=d('border_removal'), type=int,
680719
help='Number of border pixels for the removal of '
681720
'partial objects during tiled inference.')
682-
parser.add_argument('--stitching_rule', default='nms', type=str,
721+
parser.add_argument('--stitching_rule', default=d('stitching_rule'), type=str,
683722
help='Stitching rule to use for collating results from sliding window processing.')
684-
parser.add_argument('--min_vote', default=1, type=int,
723+
parser.add_argument('--min_vote', default=d('min_vote'), type=int,
685724
help='Required smallest vote count for a detected object to be accepted. '
686725
'Only used for ensembles. Minimum vote count is 1, maximum the number of '
687726
'models that are part of the ensemble.')
@@ -697,21 +736,23 @@ def main():
697736
parser.add_argument('--truncated_images', action='store_true',
698737
help='Whether to support truncated images.')
699738
parser.add_argument('-p', '--properties', nargs='*', help='Region properties')
700-
parser.add_argument('--spacing', default=1., type=float,
739+
parser.add_argument('--spacing', default=d('spacing'), type=float,
701740
help='The pixel spacing. Relevant for pixel-based region properties.')
702-
parser.add_argument('--separator', default='-', type=str,
741+
parser.add_argument('--separator', default=d('separator'), type=str,
703742
help='Separator string for region properties that are written to multiple columns. '
704743
'Default is "-" as in bbox-0, bbox-1, bbox-2, bbox-4.')
705744

706-
parser.add_argument('--gamma', default=1., type=float, help='Gamma value for gamma transform.')
707-
parser.add_argument('--contrast', default=1., type=float, help='Factor for contrast adjustment.')
708-
parser.add_argument('--brightness', default=0., type=float, help='Factor for brightness adjustment.')
709-
parser.add_argument('--percentile', default=None, nargs='+', type=float,
745+
parser.add_argument('--gamma', default=d('gamma'), type=float, help='Gamma value for gamma transform.')
746+
parser.add_argument('--contrast', default=d('contrast'), type=float, help='Factor for contrast adjustment.')
747+
parser.add_argument('--brightness', default=d('brightness'), type=float, help='Factor for brightness adjustment.')
748+
parser.add_argument('--percentile', default=d('percentile'), nargs='+', type=float,
710749
help='Percentile norm. Performs min-max normalization with specified percentiles.'
711750
'Specify either two values `(min, max)` or just `max` interpreted as '
712751
'(1 - max, max).')
713-
parser.add_argument('--model_parameters', default='', type=str,
752+
parser.add_argument('--model_parameters', default=d('model_parameters'), type=str,
714753
help='Model parameters. Pass as string in "key=value,key1=value1" format')
754+
parser.add_argument('--skip_existing', action='store_true',
755+
help='Whether to skip existing files. ')
715756

716757
args, unknown = parser.parse_known_args()
717758

@@ -756,7 +797,8 @@ def main():
756797
contrast=args.contrast,
757798
brightness=args.brightness,
758799
percentile=args.percentile,
759-
model_parameters=args.model_parameters
800+
model_parameters=args.model_parameters,
801+
skip_existing=args.skip_existing
760802
)
761803

762804
if not (is_available() and is_initialized()) or get_rank() == 0: # because why not

0 commit comments

Comments
 (0)