From 5787a9ac53c42b1ba21cf2efa0fedfb776e5b223 Mon Sep 17 00:00:00 2001 From: LennyN95 Date: Mon, 17 Feb 2025 17:11:08 +0100 Subject: [PATCH 1/2] add nnUNet_npp and nnUNet_nps environment variables --- nnunetv2/inference/predict_from_raw_data.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index 1985afe52..cf787576d 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -748,6 +748,12 @@ def predict_from_files_sequential(self, empty_cache(self.device) return ret +def _getDefaultValue(env: str, dtype: type, default: any,) -> any: + try: + val = dtype(os.environ.get(env) or default) + except: + val = default + return val def predict_entry_point_modelfolder(): import argparse @@ -883,10 +889,10 @@ def predict_entry_point(): help='Continue an aborted previous prediction (will not overwrite existing files)') parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth', help='Name of the checkpoint you want to use. Default: checkpoint_final.pth') - parser.add_argument('-npp', type=int, required=False, default=3, + parser.add_argument('-npp', type=int, required=False, default=_getDefaultValue('nnUNet_npp', int, 3), help='Number of processes used for preprocessing. More is not always better. Beware of ' 'out-of-RAM issues. Default: 3') - parser.add_argument('-nps', type=int, required=False, default=3, + parser.add_argument('-nps', type=int, required=False, default=_getDefaultValue('nnUNet_nps', int, 3), help='Number of processes used for segmentation export. More is not always better. Beware of ' 'out-of-RAM issues. Default: 3') parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None, From 3f32df8986fe15e999ae2da799750a7f9a774351 Mon Sep 17 00:00:00 2001 From: LennyN95 Date: Mon, 17 Feb 2025 17:11:32 +0100 Subject: [PATCH 2/2] use predict_from_files_sequential when nps and npp are zero --- nnunetv2/inference/predict_from_raw_data.py | 27 +++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index cf787576d..4a7f97c8a 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -959,13 +959,26 @@ def predict_entry_point(): args.f, checkpoint_name=args.chk ) - predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, - overwrite=not args.continue_prediction, - num_processes_preprocessing=args.npp, - num_processes_segmentation_export=args.nps, - folder_with_segs_from_prev_stage=args.prev_stage_predictions, - num_parts=args.num_parts, - part_id=args.part_id) + + run_sequential = args.nps == 0 and args.npp == 0 + + if run_sequential: + + print("Running in non-multiprocessing mode") + predictor.predict_from_files_sequential(args.i, args.o, save_probabilities=args.save_probabilities, + overwrite=not args.continue_prediction, + folder_with_segs_from_prev_stage=args.prev_stage_predictions) + + else: + + predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, + overwrite=not args.continue_prediction, + num_processes_preprocessing=args.npp, + num_processes_segmentation_export=args.nps, + folder_with_segs_from_prev_stage=args.prev_stage_predictions, + num_parts=args.num_parts, + part_id=args.part_id) + # r = predict_from_raw_data(args.i, # args.o, # model_folder,