77import  numpy  as  np 
88import  torch 
99import  torch .nn  as  nn 
10+ import  yaml 
1011from  pathos .multiprocessing  import  ThreadPool  as  Pool 
1112from  torch .utils .data  import  DataLoader 
1213from  tqdm  import  tqdm 
@@ -31,6 +32,7 @@ def __init__(
3132        batch_size : int  =  8 ,
3233        normalization : str  =  None ,
3334        device : str  =  "cuda" ,
35+         n_devices : int  =  1 ,
3436        save_masks : bool  =  True ,
3537        save_intermediate : bool  =  False ,
3638        save_dir : Union [Path , str ] =  None ,
@@ -72,6 +74,9 @@ def __init__(
7274                One of: "dataset", "minmax", "norm", "percentile", None. 
7375            device : str, default="cuda" 
7476                The device of the input and model. One of: "cuda", "cpu" 
77+             n_devices : int, default=1 
78+                 Number of devices (cpus/gpus) used for inference. 
79+                 The model will be copied into these devices. 
7580            save_masks : bool, default=False 
7681                If True, the resulting segmentation masks will be saved into `out_masks` 
7782                variable. 
@@ -95,6 +100,16 @@ def __init__(
95100            **postproc_kwargs: 
96101                Arbitrary keyword arguments for the post-processing. 
97102        """ 
103+         # basic inits 
104+         self .model  =  model 
105+         self .out_heads  =  self ._get_out_info ()  # the names and num channels of out heads 
106+         self .batch_size  =  batch_size 
107+         self .patch_size  =  patch_size 
108+         self .padding  =  padding 
109+         self .out_activations  =  out_activations 
110+         self .out_boundary_weights  =  out_boundary_weights 
111+         self .head_kwargs  =  self ._check_and_set_head_args ()
112+ 
98113        self .save_dir  =  Path (save_dir ) if  save_dir  is  not   None  else  None 
99114        self .save_masks  =  save_masks 
100115        self .save_intermediate  =  save_intermediate 
@@ -106,17 +121,17 @@ def __init__(
106121            folder_ds , batch_size = batch_size , shuffle = False , pin_memory = True 
107122        )
108123
109-         # model and device 
110-         self .model  =  model 
111-         if  device  ==  "cpu" :
112-             self .model .cpu ()
113-             self .device  =  torch .device ("cpu" )
114-         if  torch .cuda .is_available () and  device  ==  "cuda" :
115-             self .model .cuda ()
116-             self .device  =  torch .device ("cuda" )
117- 
118-         self .model .eval ()
124+         # Set post processor 
125+         self .postprocessor  =  PostProcessor (
126+             instance_postproc ,
127+             inst_key = self .model .inst_key ,
128+             aux_key = self .model .aux_key ,
129+             type_post_proc = type_post_proc ,
130+             sem_post_proc = sem_post_proc ,
131+             ** postproc_kwargs ,
132+         )
119133
134+         # load weights and set devices 
120135        if  checkpoint_path  is  not   None :
121136            ckpt  =  torch .load (
122137                checkpoint_path , map_location = lambda  storage , loc : storage 
@@ -130,30 +145,41 @@ def __init__(
130145            except  BaseException  as  e :
131146                print (e )
132147
133-         # 
148+         assert  device  in  ("cuda" , "cpu" )
149+         if  device  ==  "cpu" :
150+             self .device  =  torch .device ("cpu" )
151+         if  torch .cuda .is_available () and  device  ==  "cuda" :
152+             self .device  =  torch .device ("cuda" )
153+ 
154+             if  torch .cuda .device_count () >  1  and  n_devices  >  1 :
155+                 self .model  =  nn .DataParallel (self .model , device_ids = range (n_devices ))
156+ 
157+         self .model .to (self .device )
158+         self .model .eval ()
159+ 
160+         # Helper class to perform forward + extra processing 
134161        self .predictor  =  Predictor (
135162            model = self .model ,
136163            patch_size = patch_size ,
137164            normalization = normalization ,
138165            device = self .device ,
139166        )
140-         self .out_heads  =  self ._get_out_info ()  # the names and num channels of out heads 
141-         self .batch_size  =  batch_size 
142-         self .patch_size  =  patch_size 
143-         self .padding  =  padding 
144-         self .out_activations  =  out_activations 
145-         self .out_boundary_weights  =  out_boundary_weights 
146-         self .head_kwargs  =  self ._check_and_set_head_args ()
147167
148-         # 
149-         self .postprocessor  =  PostProcessor (
150-             instance_postproc ,
151-             inst_key = self .model .inst_key ,
152-             aux_key = self .model .aux_key ,
153-             type_post_proc = type_post_proc ,
154-             sem_post_proc = sem_post_proc ,
155-             ** postproc_kwargs ,
156-         )
168+     @classmethod  
169+     def  from_yaml (cls , model : nn .Module , yaml_path : str ):
170+         """Initialize the inferer from a yaml-file. 
171+ 
172+         Parameters 
173+         ---------- 
174+             model : nn.Module 
175+                 Initialized segmentation model. 
176+             yaml_path : str 
177+                 Path to the yaml file containing rest of the params 
178+         """ 
179+         with  open (yaml_path , "r" ) as  stream :
180+             kwargs  =  yaml .full_load (stream )
181+ 
182+         return  cls (model , ** kwargs )
157183
158184    @abstractmethod  
159185    def  _infer_batch (self ):
0 commit comments