1+ import  warnings 
12from  abc  import  ABC , abstractmethod 
23from  collections  import  OrderedDict 
34from  itertools  import  chain 
89import  torch 
910import  torch .nn  as  nn 
1011import  yaml 
11- from  pathos .multiprocessing  import  ThreadPool  as  Pool 
1212from  torch .utils .data  import  DataLoader 
1313from  tqdm  import  tqdm 
1414
15- from  ..utils  import  tensor_to_ndarray 
16- from  ..utils .save_utils  import  mask2mat 
15+ from  ..utils  import  FileHandler , tensor_to_ndarray 
1716from  .folder_dataset  import  FolderDataset 
1817from  .post_processor  import  PostProcessor 
1918from  .predictor  import  Predictor 
@@ -33,14 +32,14 @@ def __init__(
3332        normalization : str  =  None ,
3433        device : str  =  "cuda" ,
3534        n_devices : int  =  1 ,
36-         save_masks : bool  =  True ,
3735        save_intermediate : bool  =  False ,
3836        save_dir : Union [Path , str ] =  None ,
37+         save_format : str  =  ".mat" ,
3938        checkpoint_path : Union [Path , str ] =  None ,
4039        n_images : int  =  None ,
4140        type_post_proc : Callable  =  None ,
4241        sem_post_proc : Callable  =  None ,
43-         ** postproc_kwargs ,
42+         ** kwargs ,
4443    ) ->  None :
4544        """Inference for an image folder. 
4645
@@ -77,16 +76,14 @@ def __init__(
7776            n_devices : int, default=1 
7877                Number of devices (cpus/gpus) used for inference. 
7978                The model will be copied into these devices. 
80-             save_masks : bool, default=False 
81-                 If True, the resulting segmentation masks will be saved into `out_masks` 
82-                 variable. 
83-             save_intermediate : bool, default=False 
84-                 If True, intermediate soft masks will be saved into `soft_masks` var. 
8579            save_dir : bool, optional 
8680                Path to save directory. If None, no masks will be saved to disk as .mat 
87-                 files. If not None, overrides `save_masks`, thus for every batch the 
88-                 segmentation results are saved into disk and the intermediate results 
89-                 are flushed. 
81+                 or .json files. Instead the masks will be saved in `self.out_masks`. 
82+             save_intermediate : bool, default=False 
83+                 If True, intermediate soft masks will be saved into `soft_masks` var. 
84+             save_format : str, default=".mat" 
85+                 The file format for the saved output masks. One of (".mat", ".json"). 
86+                 The ".json" option will save masks into geojson format. 
9087            checkpoint_path : Path | str, optional 
9188                Path to the model weight checkpoints. 
9289            n_images : int, optional 
@@ -97,8 +94,8 @@ def __init__(
9794            sem_post_proc : Callable, optional 
9895                A post-processing function for the semantc seg maps. If not None, 
9996                overrides the default. 
100-             **postproc_kwargs : 
101-                 Arbitrary keyword arguments for the  post-processing. 
97+             **kwargs : 
98+                 Arbitrary keyword arguments expecially  for post-processing and saving . 
10299        """ 
103100        # basic inits 
104101        self .model  =  model 
@@ -109,14 +106,25 @@ def __init__(
109106        self .out_activations  =  out_activations 
110107        self .out_boundary_weights  =  out_boundary_weights 
111108        self .head_kwargs  =  self ._check_and_set_head_args ()
109+         self .kwargs  =  kwargs 
112110
113111        self .save_dir  =  Path (save_dir ) if  save_dir  is  not   None  else  None 
114-         self .save_masks  =  save_masks 
115112        self .save_intermediate  =  save_intermediate 
113+         self .save_format  =  save_format 
116114
117115        # dataloader 
118116        self .path  =  Path (input_folder )
117+ 
119118        folder_ds  =  FolderDataset (self .path , n_images = n_images )
119+         if  self .save_dir  is  None  and  len (folder_ds .fnames ) >  40 :
120+             warnings .warn (
121+                 "`save_dir` is None. Thus, the outputs are be saved in `out_masks` " 
122+                 "class variable. If the input folder contains many images, running " 
123+                 "inference will likely flood the memory depending on the size and " 
124+                 "number of the images. Consider saving outputs to disk by providing " 
125+                 "`save_dir` argument." 
126+             )
127+ 
120128        self .dataloader  =  DataLoader (
121129            folder_ds , batch_size = batch_size , shuffle = False , pin_memory = True 
122130        )
@@ -128,7 +136,7 @@ def __init__(
128136            aux_key = self .model .aux_key ,
129137            type_post_proc = type_post_proc ,
130138            sem_post_proc = sem_post_proc ,
131-             ** postproc_kwargs ,
139+             ** kwargs ,
132140        )
133141
134142        # load weights and set devices 
@@ -188,10 +196,16 @@ def _infer_batch(self):
188196    def  infer (self ) ->  None :
189197        """Run inference and post-processing for the images. 
190198
191-         NOTE: Saves outputs in `self.out_masks` or to disk (.mat) files. 
192- 
193-         `self.out_masks` is a nested dict: E.g. 
194-             {"image1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}} 
199+         NOTE: 
200+         - Saves outputs in `self.out_masks` or to disk (.mat/.json) files. 
201+         - If `save_intermediate` is set to True, also intermiediate model outputs are 
202+             saved to `self.soft_masks` 
203+         - `self.out_masks` and `self.soft_masks` are nested dicts: E.g. 
204+                 {"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}} 
205+         - If masks are saved to geojson .json files, more key word arguments 
206+             need to be given at class initialization. Namely: `geo_format`, 
207+             `classes_type`, `classes_sem`, `offsets`. See more in the 
208+             `FileHandler.save_masks` docs. 
195209        """ 
196210        self .soft_masks  =  {}
197211        self .out_masks  =  {}
@@ -223,89 +237,25 @@ def infer(self) -> None:
223237                            self .soft_masks [n ] =  m 
224238
225239                    if  self .save_dir  is  None :
226-                         if  self .save_masks :
227-                             for  n , m  in  zip (names , seg_results ):
228-                                 self .out_masks [n ] =  m 
240+                         for  n , m  in  zip (names , seg_results ):
241+                             self .out_masks [n ] =  m 
229242                    else :
230243                        loader .set_postfix_str ("Saving results to disk" )
231244                        if  self .batch_size  >  1 :
232-                             self .save_parallel (seg_results , names , self .save_dir )
245+                             fnames  =  [Path (self .save_dir ) /  n  for  n  in  names ]
246+                             FileHandler .save_masks_parallel (
247+                                 maps = seg_results ,
248+                                 fnames = fnames ,
249+                                 ** {** self .kwargs , "format" : self .save_format },
250+                             )
233251                        else :
234252                            for  n , m  in  zip (names , seg_results ):
235-                                 self .save_mask (m , n , self .save_dir )
236- 
237-     @staticmethod  
238-     def  save_mask (
239-         maps : Dict [str , np .ndarray ],
240-         fname : str ,
241-         save_dir : Union [str , Path ],
242-         format : str  =  ".mat" ,
243-     ) ->  None :
244-         """Save model outputs to .mat or geojson. 
245- 
246-         Parameters 
247-         ---------- 
248-             maps : Dict[str, np.ndarray] 
249-                 model output names mapped to model outputs. 
250-                 E.g. {"sem": np.ndarray, "type": np.ndarray, "inst": np.ndarray} 
251-             fname : str 
252-                 Name for the output-file. 
253-             save_dir : Path or str 
254-                 Path to the save directory. 
255-             format : str 
256-                 One of ".mat" or "geojson" 
257-         """ 
258-         allowed  =  (".mat" , ".json" )
259-         if  format  not  in   allowed :
260-             raise  ValueError (
261-                 f"Illegal file-format. Got: { format }  . Allowed formats: { allowed }  " 
262-             )
263- 
264-         if  format  ==  ".mat" :
265-             mask2mat (fname , save_dir , ** maps )
266-         else :
267-             pass 
268- 
269-         return  True 
270- 
271-     @staticmethod  
272-     def  save_parallel (
273-         maps : List [Dict [str , np .ndarray ]],
274-         fnames : List [str ],
275-         save_dir : Union [Path , str ],
276-         format : str  =  ".mat" ,
277-         progress_bar : bool  =  False ,
278-     ) ->  None :
279-         """Save the model output masks to a folder. (multi-threaded). 
280- 
281-         Parameters 
282-         ---------- 
283-             maps : List[Dict[str, np.ndarray]] 
284-                 The model output map dictionaries in a list. 
285-             fnames : List[str] 
286-                 Name for the output-files. (In the same order with `maps`) 
287-             save_dir : Path or str 
288-                 Path to the save directory. 
289-             format : str 
290-                 One of ".mat" or "geojson" 
291-             progress_bar : bool, default=False 
292-                 If True, a tqdm progress bar is shown. 
293-         """ 
294-         args  =  tuple (zip (maps , fnames , [save_dir ] *  len (maps ), [format ] *  len (maps )))
295- 
296-         with  Pool () as  pool :
297-             if  progress_bar :
298-                 it  =  tqdm (pool .imap (BaseInferer ._save_mask , args ), total = len (maps ))
299-             else :
300-                 it  =  pool .imap (BaseInferer ._save_mask , args )
301- 
302-             for  _  in  it :
303-                 pass 
304- 
305-     @staticmethod  
306-     def  _save_mask (args : Tuple [Dict [str , np .ndarray ], str , str ]) ->  None :
307-         """Unpacks the args for `save_mask` to enable multi-threading.""" 
308-         return  BaseInferer .save_mask (* args )
253+                                 fname  =  Path (self .save_dir ) /  n 
254+                                 FileHandler .save_masks (
255+                                     fname = fname ,
256+                                     maps = m ,
257+                                     ** {** self .kwargs , "format" : self .save_format },
258+                                 )
309259
310260    def  _strip_state_dict (self , ckpt : Dict ) ->  OrderedDict :
311261        """Strip te first 'model.' (generated by lightning) from the state dict keys.""" 
0 commit comments