2222from monailabel .interfaces .exception import MONAILabelError , MONAILabelException
2323from monailabel .interfaces .utils .transform import run_transforms
2424from monailabel .transform .writer import Writer
25+ from monailabel .utils .others .generic import device_list
2526
2627logger = logging .getLogger (__name__ )
2728
@@ -57,7 +58,7 @@ class InferTask:
5758
5859 def __init__ (
5960 self ,
60- path : Union [str , Sequence [str ]],
61+ path : Union [None , str , Sequence [str ]],
6162 network : Union [None , Any ],
6263 type : Union [str , InferType ],
6364 labels : Union [str , None , Sequence [str ], Dict [Any , Any ]],
@@ -70,6 +71,7 @@ def __init__(
7071 config : Union [None , Dict [str , Any ]] = None ,
7172 load_strict : bool = False ,
7273 roi_size = None ,
74+ preload = False ,
7375 ):
7476 """
7577 :param path: Model File Path. Supports multiple paths to support versions (Last item will be picked as latest)
@@ -84,8 +86,9 @@ def __init__(
8486 :param config: K,V pairs to be part of user config
8587 :param load_strict: Load model in strict mode
8688 :param roi_size: ROI size for scanning window inference
89+ :param preload: Preload model/network on all available GPU devices
8790 """
88- self .path = path
91+ self .path = [] if not path else [ path ] if isinstance ( path , str ) else path
8992 self .network = network
9093 self .type = type
9194 self .labels = [] if labels is None else [labels ] if isinstance (labels , str ) else labels
@@ -99,8 +102,9 @@ def __init__(
99102 self .roi_size = roi_size
100103
101104 self ._networks : Dict = {}
105+
102106 self ._config : Dict [str , Any ] = {
103- # "device": "cuda" ,
107+ # "device": device_list() ,
104108 # "result_extension": None,
105109 # "result_dtype": None,
106110 # "result_compress": False
@@ -111,6 +115,11 @@ def __init__(
111115 if config :
112116 self ._config .update (config )
113117
118+ if preload :
119+ for device in device_list ():
120+ logger .info (f"Preload Network for device: { device } " )
121+ self ._get_network (device )
122+
114123 def info (self ) -> Dict [str , Any ]:
115124 return {
116125 "type" : self .type ,
@@ -127,19 +136,19 @@ def is_valid(self) -> bool:
127136 if self .network or self .type == InferType .SCRIBBLES :
128137 return True
129138
130- paths = [ self . path ] if isinstance ( self . path , str ) else self .path
139+ paths = self .path
131140 for path in reversed (paths ):
132- if os .path .exists (path ):
141+ if path and os .path .exists (path ):
133142 return True
134143 return False
135144
136145 def get_path (self ):
137146 if not self .path :
138147 return None
139148
140- paths = [ self . path ] if isinstance ( self . path , str ) else self .path
149+ paths = self .path
141150 for path in reversed (paths ):
142- if os .path .exists (path ):
151+ if path and os .path .exists (path ):
143152 return path
144153 return None
145154
@@ -247,7 +256,10 @@ def __call__(self, request) -> Tuple[str, Dict[str, Any]]:
247256
248257 # device
249258 device = req .get ("device" , "cuda" )
259+ if device .startswith ("cuda" ) and not torch .cuda .is_available ():
260+ device = "cpu"
250261 req ["device" ] = device
262+
251263 logger .setLevel (req .get ("logging" , "INFO" ).upper ())
252264 logger .info (f"Infer Request (final): { req } " )
253265
@@ -346,9 +358,6 @@ def _get_network(self, device):
346358 f"Model Path ({ self .path } ) does not exist/valid" ,
347359 )
348360
349- if device .startswith ("cuda" ) and not torch .cuda .is_available ():
350- device = "cpu"
351-
352361 cached = self ._networks .get (device )
353362 statbuf = os .stat (path ) if path else None
354363 network = None
@@ -360,16 +369,15 @@ def _get_network(self, device):
360369
361370 if network is None :
362371 if self .network :
363- network = self .network
372+ network = copy .deepcopy (self .network )
373+ network .to (torch .device (device ))
374+
364375 if path :
365376 checkpoint = torch .load (path , map_location = torch .device (device ))
366377 model_state_dict = checkpoint .get (self .model_state_dict , checkpoint )
367378 network .load_state_dict (model_state_dict , strict = self .load_strict )
368379 else :
369- network = torch .jit .load (path , map_location = torch .device (device ))
370-
371- if device .startswith ("cuda" ):
372- network = network .cuda (device )
380+ network = torch .jit .load (path , map_location = torch .device (device )).to (torch .device (device ))
373381
374382 network .eval ()
375383 self ._networks [device ] = (network , statbuf .st_mtime if statbuf else 0 )
@@ -388,22 +396,18 @@ def run_inferer(self, data, convert_to_batch=True, device="cuda"):
388396 """
389397
390398 inferer = self .inferer (data )
391- logger .info ("Inferer:: {} => {}" .format (inferer .__class__ .__name__ , inferer .__dict__ ))
392-
393- device = device if device else "cuda"
394- if device .startswith ("cuda" ) and not torch .cuda .is_available ():
395- device = "cpu"
399+ logger .info ("Inferer:: {} => {} => {}" .format (device , inferer .__class__ .__name__ , inferer .__dict__ ))
396400
397401 network = self ._get_network (device )
398402 if network :
399403 inputs = data [self .input_key ]
400404 inputs = inputs if torch .is_tensor (inputs ) else torch .from_numpy (inputs )
401405 inputs = inputs [None ] if convert_to_batch else inputs
402- if device .startswith ("cuda" ):
403- inputs = inputs .cuda (torch .device (device ))
406+ inputs = inputs .to (torch .device (device ))
404407
405408 with torch .no_grad ():
406409 outputs = inferer (inputs , network )
410+
407411 if device .startswith ("cuda" ):
408412 torch .cuda .empty_cache ()
409413
0 commit comments