@@ -162,14 +162,15 @@ def __init__(
162162 except BaseException as e :
163163 print (e )
164164
165- assert device in ("cuda" , "cpu" )
165+ assert device in ("cuda" , "cpu" , "mps" )
166166 if device == "cpu" :
167167 self .device = torch .device ("cpu" )
168- if torch .cuda .is_available () and device == "cuda" :
168+ elif torch .cuda .is_available () and device == "cuda" :
169169 self .device = torch .device ("cuda" )
170-
171170 if torch .cuda .device_count () > 1 and n_devices > 1 :
172171 self .model = nn .DataParallel (self .model , device_ids = range (n_devices ))
172+ elif torch .backends .mps .is_available () and device == "mps" :
173+ self .device = torch .device ("mps" )
173174
174175 self .model .to (self .device )
175176 self .model .eval ()
@@ -245,6 +246,14 @@ def infer(self) -> None:
245246 for n , m in zip (names , soft_masks ):
246247 self .soft_masks [n ] = m
247248
249+ # Quick kludge to add soft type and sem to seg_results
250+ for soft , seg in zip (soft_masks , seg_results ):
251+ if "type" in soft .keys ():
252+ seg ["soft_type" ] = soft ["type" ]
253+ if "sem" in soft .keys ():
254+ seg ["soft_sem" ] = soft ["sem" ]
255+
256+ # save to cache or disk
248257 if self .save_dir is None :
249258 for n , m in zip (names , seg_results ):
250259 self .out_masks [n ] = m
0 commit comments