Skip to content

Commit e561842

Browse files
committed
Merge remote-tracking branch 'LennyN95/sequential_inference'
2 parents 11cd116 + 3f32df8 commit e561842

File tree

1 file changed

+28
-9
lines changed

1 file changed

+28
-9
lines changed

nnunetv2/inference/predict_from_raw_data.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,12 @@ def predict_from_files_sequential(self,
766766
empty_cache(self.device)
767767
return ret
768768

769+
def _getDefaultValue(env: str, dtype: type, default: any,) -> any:
770+
try:
771+
val = dtype(os.environ.get(env) or default)
772+
except:
773+
val = default
774+
return val
769775

770776
def predict_entry_point_modelfolder():
771777
import argparse
@@ -901,10 +907,10 @@ def predict_entry_point():
901907
help='Continue an aborted previous prediction (will not overwrite existing files)')
902908
parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
903909
help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
904-
parser.add_argument('-npp', type=int, required=False, default=3,
910+
parser.add_argument('-npp', type=int, required=False, default=_getDefaultValue('nnUNet_npp', int, 3),
905911
help='Number of processes used for preprocessing. More is not always better. Beware of '
906912
'out-of-RAM issues. Default: 3')
907-
parser.add_argument('-nps', type=int, required=False, default=3,
913+
parser.add_argument('-nps', type=int, required=False, default=_getDefaultValue('nnUNet_nps', int, 3),
908914
help='Number of processes used for segmentation export. More is not always better. Beware of '
909915
'out-of-RAM issues. Default: 3')
910916
parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
@@ -971,13 +977,26 @@ def predict_entry_point():
971977
args.f,
972978
checkpoint_name=args.chk
973979
)
974-
predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,
975-
overwrite=not args.continue_prediction,
976-
num_processes_preprocessing=args.npp,
977-
num_processes_segmentation_export=args.nps,
978-
folder_with_segs_from_prev_stage=args.prev_stage_predictions,
979-
num_parts=args.num_parts,
980-
part_id=args.part_id)
980+
981+
run_sequential = args.nps == 0 and args.npp == 0
982+
983+
if run_sequential:
984+
985+
print("Running in non-multiprocessing mode")
986+
predictor.predict_from_files_sequential(args.i, args.o, save_probabilities=args.save_probabilities,
987+
overwrite=not args.continue_prediction,
988+
folder_with_segs_from_prev_stage=args.prev_stage_predictions)
989+
990+
else:
991+
992+
predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,
993+
overwrite=not args.continue_prediction,
994+
num_processes_preprocessing=args.npp,
995+
num_processes_segmentation_export=args.nps,
996+
folder_with_segs_from_prev_stage=args.prev_stage_predictions,
997+
num_parts=args.num_parts,
998+
part_id=args.part_id)
999+
9811000
# r = predict_from_raw_data(args.i,
9821001
# args.o,
9831002
# model_folder,

0 commit comments

Comments
 (0)