11import re
2+ from functools import partial
23from pathlib import Path
34from typing import Any , Dict , List , Tuple , Union
45
56import cv2
67import numpy as np
78import scipy .io as sio
8- from pathos .multiprocessing import ThreadPool as Pool
9- from tqdm import tqdm
109
1110from .mask_utils import (
1211 bounding_box ,
1514 get_inst_types ,
1615 label_semantic ,
1716)
17+ from .multiproc import run_pool
1818
1919
2020class FileHandler :
@@ -58,30 +58,34 @@ def read_mat(
5858 key : str = "inst_map" ,
5959 retype : bool = True ,
6060 return_all : bool = False ,
61- ) -> Union [np .ndarray , None ]:
61+ ) -> Union [np .ndarray , Dict [ str , np . ndarray ], None ]:
6262 """Read a mask from a .mat file.
6363
6464 If a mask is not found, return None
6565
6666 Parameters
6767 ----------
6868 path : str or Path
69- Path to the image file.
69+ Path to the .mat file.
7070 key : str, default="inst_map"
7171 Name/key of the mask type that is being read from .mat
7272 retype : bool, default=True
7373 Convert the matrix type.
7474 return_all : bool, default=False
7575 Return the whole dict. Overrides the `key` arg.
7676
77- Returns
78- -------
79- np.ndarray or None:
80- The mask indice matrix. Shape (H, W)
8177
8278 Raises
8379 ------
8480 ValueError: If an illegal key is given.
81+
82+ Returns
83+ -------
84+ Union[np.ndarray, List[np.ndarray], None]:
85+ if return_all == False:
86+ The instance/type/semantic labelled mask. Shape: (H, W).
87+ if return_all == True:
88+ All the masks in the .mat file returned in a dictionary.
8589 """
8690 dtypes = {
8791 "inst_map" : "int32" ,
@@ -468,7 +472,8 @@ def save_masks_parallel(
468472 classes_type : Dict [str , str ] = None ,
469473 classes_sem : Dict [str , str ] = None ,
470474 offsets : bool = False ,
471- progress_bar : bool = False ,
475+ pooltype : str = "thread" ,
476+ maptype : str = "amap" ,
472477 ** kwargs ,
473478 ) -> None :
474479 """Save the model output masks to a folder. (multi-threaded).
@@ -493,31 +498,44 @@ def save_masks_parallel(
493498 offsets : bool, default=False
494499 If True, geojson coords are shifted by the offsets that are encoded in
495500 the filenames (e.g. "x-1000_y-4000.png"). Ignored if `format` != ".json"
496- progress_bar : bool, default=False
497- If True, a tqdm progress bar is shown.
501+ pooltype : str, default="thread"
502+ The pathos pooltype. Allowed: ("process", "thread", "serial").
503+ Defaults to "thread". (Fastest in benchmarks.)
504+ maptype : str, default="amap"
505+ The map type of the pathos Pool object.
506+ Allowed: ("map", "amap", "imap", "uimap")
507+ Defaults to "amap". (Fastest in benchmarks).
498508 """
499- formats = [ format ] * len ( maps )
500- geo_formats = [ geo_format ] * len ( maps )
501- classes_type = [ classes_type ] * len ( maps )
502- classes_sem = [ classes_sem ] * len ( maps )
503- offsets = [ offsets ] * len ( maps )
504- args = tuple (
505- zip ( fnames , maps , formats , geo_formats , classes_type , classes_sem , offsets )
509+ func = partial (
510+ FileHandler . _save_masks ,
511+ format = format ,
512+ geo_format = geo_format ,
513+ classes_type = classes_type ,
514+ classes_sem = classes_sem ,
515+ offsets = offsets ,
506516 )
507517
508- with Pool () as pool :
509- if progress_bar :
510- it = tqdm (pool .imap (FileHandler ._save_masks , args ), total = len (maps ))
511- else :
512- it = pool .imap (FileHandler ._save_masks , args )
513-
514- for _ in it :
515- pass
518+ args = tuple (zip (fnames , maps ))
519+ run_pool (func , args , ret = False , pooltype = pooltype , maptype = maptype )
516520
517521 @staticmethod
518- def _save_masks (args : Tuple [Dict [str , np .ndarray ], str , str ]) -> None :
522+ def _save_masks (
523+ args : Tuple [str , Dict [str , np .ndarray ]],
524+ format : str ,
525+ geo_format : str ,
526+ classes_type : Dict [str , str ],
527+ classes_sem : Dict [str , str ],
528+ offsets : bool ,
529+ ) -> None :
519530 """Unpacks the args for `save_mask` to enable multi-threading."""
520- return FileHandler .save_masks (* args )
531+ return FileHandler .save_masks (
532+ * args ,
533+ format = format ,
534+ geo_format = geo_format ,
535+ classes_type = classes_type ,
536+ classes_sem = classes_sem ,
537+ offsets = offsets ,
538+ )
521539
522540 @staticmethod
523541 def get_split (string : str ) -> List [str ]:
0 commit comments