diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index 1985afe52..4a7f97c8a 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -748,6 +748,12 @@ def predict_from_files_sequential(self, empty_cache(self.device) return ret +def _getDefaultValue(env: str, dtype: type, default: any,) -> any: + try: + val = dtype(os.environ.get(env) or default) + except: + val = default + return val def predict_entry_point_modelfolder(): import argparse @@ -883,10 +889,10 @@ def predict_entry_point(): help='Continue an aborted previous prediction (will not overwrite existing files)') parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth', help='Name of the checkpoint you want to use. Default: checkpoint_final.pth') - parser.add_argument('-npp', type=int, required=False, default=3, + parser.add_argument('-npp', type=int, required=False, default=_getDefaultValue('nnUNet_npp', int, 3), help='Number of processes used for preprocessing. More is not always better. Beware of ' 'out-of-RAM issues. Default: 3') - parser.add_argument('-nps', type=int, required=False, default=3, + parser.add_argument('-nps', type=int, required=False, default=_getDefaultValue('nnUNet_nps', int, 3), help='Number of processes used for segmentation export. More is not always better. Beware of ' 'out-of-RAM issues. Default: 3') parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None, @@ -953,13 +959,26 @@ def predict_entry_point(): args.f, checkpoint_name=args.chk ) - predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, - overwrite=not args.continue_prediction, - num_processes_preprocessing=args.npp, - num_processes_segmentation_export=args.nps, - folder_with_segs_from_prev_stage=args.prev_stage_predictions, - num_parts=args.num_parts, - part_id=args.part_id) + + run_sequential = args.nps == 0 and args.npp == 0 + + if run_sequential: + + print("Running in non-multiprocessing mode") + predictor.predict_from_files_sequential(args.i, args.o, save_probabilities=args.save_probabilities, + overwrite=not args.continue_prediction, + folder_with_segs_from_prev_stage=args.prev_stage_predictions) + + else: + + predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, + overwrite=not args.continue_prediction, + num_processes_preprocessing=args.npp, + num_processes_segmentation_export=args.nps, + folder_with_segs_from_prev_stage=args.prev_stage_predictions, + num_parts=args.num_parts, + part_id=args.part_id) + # r = predict_from_raw_data(args.i, # args.o, # model_folder,