@@ -89,14 +89,25 @@ def __init__(self, sync_dir=None, patch_size=572,
8989 total_mem = 0
9090 self .num_workers = min (multiprocessing .cpu_count (), max_workers )
9191 print (self .num_workers , 'workers assigned for data loader' )
92- print ('GPU Available' , torch .cuda .is_available ())
92+ print ('CUDA Available' , torch .cuda .is_available ())
93+
9394 if torch .cuda .is_available ():
9495 for i in range (torch .cuda .device_count ()):
9596 total_mem += torch .cuda .get_device_properties (i ).total_memory
97+
98+ print ('MPS Available' , torch .backends .mps .is_available ())
99+ # MPS only has one device.
100+ # There is no obvious way of getting memory for MPS
101+ # FIXME: setting arbitrary amount of memory.
102+ if torch .backends .mps .is_available ():
103+ total_mem = 24_589_934_592
104+
105+ if total_mem > 0 : # means CUDA or MPS found
96106 self .bs = total_mem // mem_per_item
97107 self .bs = min (12 , self .bs )
98108 else :
99109 self .bs = 1 # cpu is batch size of 1
110+
100111 print ('Batch size' , self .bs )
101112 self .optimizer = None
102113 # used to check for updates
@@ -287,6 +298,9 @@ def train_one_epoch(self):
287298 if not [is_photo (a ) for a in ls (val_annot_dir )]:
288299 return
289300
301+
302+ device = model_utils .get_device ()
303+
290304 if self .first_loop :
291305 self .first_loop = False
292306 self .write_message ('Training started' )
@@ -313,9 +327,9 @@ def train_one_epoch(self):
313327 defined_tiles ) in enumerate (train_loader ):
314328
315329 self .check_for_instructions ()
316- photo_tiles = photo_tiles .cuda ( )
317- foreground_tiles = foreground_tiles .cuda ( )
318- defined_tiles = defined_tiles .cuda ( )
330+ photo_tiles = photo_tiles .to ( device )
331+ foreground_tiles = foreground_tiles .to ( device )
332+ defined_tiles = defined_tiles .to ( device )
319333 self .optimizer .zero_grad ()
320334 outputs = self .model (photo_tiles )
321335 softmaxed = softmax (outputs , 1 )
0 commit comments