Skip to content

Commit 9e2344b

Browse files
committed
enable training with m2 mac
1 parent f6cf4fd commit 9e2344b

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

trainer/src/model_utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ def get_latest_model_paths(model_dir, k):
4848

4949
def load_model(model_path):
5050
model = UNetGNRes()
51-
if torch.cuda.is_available():
51+
if torch.cuda.is_available() or torch.backends.mps.is_available():
5252
try:
5353
model.load_state_dict(torch.load(model_path))
5454
model = torch.nn.DataParallel(model)
5555
except:
5656
model = torch.nn.DataParallel(model)
5757
model.load_state_dict(torch.load(model_path))
58-
model.cuda()
58+
model.to(device)
5959
else:
6060
# if you are running on a CPU-only machine, please use torch.load with
6161
# map_location=torch.device('cpu') to map your storages to the CPU.
@@ -77,8 +77,8 @@ def create_first_model_with_random_weights(model_dir):
7777
model_path = os.path.join(model_dir, model_name)
7878
torch.save(model.state_dict(), model_path)
7979

80-
if torch.cuda.is_available():
81-
model.cuda()
80+
if torch.cuda.is_available() or torch.backends.mps.is_available():
81+
model.to(device)
8282
return model
8383

8484

@@ -292,14 +292,13 @@ def unet_segment(cnn, image, bs, in_w, out_w, threshold=0.5):
292292
tile_idx += 1
293293
tiles_to_process.append(tile)
294294
tiles_for_gpu = torch.from_numpy(np.array(tiles_to_process))
295-
if torch.cuda.is_available():
296-
tiles_for_gpu.cuda()
295+
tiles_for_gpu = tiles_for_gpu.to(device)
297296
tiles_for_gpu = tiles_for_gpu.float()
298297
batches.append(tiles_for_gpu)
299298

300299
output_tiles = []
301300
for gpu_tiles in batches:
302-
outputs = cnn(gpu_tiles.cuda())
301+
outputs = cnn(gpu_tiles.to(device))
303302
softmaxed = softmax(outputs, 1)
304303
foreground_probs = softmaxed[:, 1, :] # just the foreground probability.
305304
if threshold is not None:

trainer/src/trainer.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,25 @@ def __init__(self, sync_dir=None, patch_size=572,
8989
total_mem = 0
9090
self.num_workers=min(multiprocessing.cpu_count(), max_workers)
9191
print(self.num_workers, 'workers assigned for data loader')
92-
print('GPU Available', torch.cuda.is_available())
92+
print('CUDA Available', torch.cuda.is_available())
93+
9394
if torch.cuda.is_available():
9495
for i in range(torch.cuda.device_count()):
9596
total_mem += torch.cuda.get_device_properties(i).total_memory
97+
98+
print('MPS Available', torch.backends.mps.is_available())
99+
# MPS only has one device.
100+
# There is no obvious way of getting memory for MPS
101+
# FIXME: setting arbitrary amount of memory.
102+
if torch.backends.mps.is_available():
103+
total_mem = 24_589_934_592
104+
105+
if total_mem > 0: # means CUDA or MPS found
96106
self.bs = total_mem // mem_per_item
97107
self.bs = min(12, self.bs)
98108
else:
99109
self.bs = 1 # cpu is batch size of 1
110+
100111
print('Batch size', self.bs)
101112
self.optimizer = None
102113
# used to check for updates
@@ -287,6 +298,9 @@ def train_one_epoch(self):
287298
if not [is_photo(a) for a in ls(val_annot_dir)]:
288299
return
289300

301+
302+
device = model_utils.get_device()
303+
290304
if self.first_loop:
291305
self.first_loop = False
292306
self.write_message('Training started')
@@ -313,9 +327,9 @@ def train_one_epoch(self):
313327
defined_tiles) in enumerate(train_loader):
314328

315329
self.check_for_instructions()
316-
photo_tiles = photo_tiles.cuda()
317-
foreground_tiles = foreground_tiles.cuda()
318-
defined_tiles = defined_tiles.cuda()
330+
photo_tiles = photo_tiles.to(device)
331+
foreground_tiles = foreground_tiles.to(device)
332+
defined_tiles = defined_tiles.to(device)
319333
self.optimizer.zero_grad()
320334
outputs = self.model(photo_tiles)
321335
softmaxed = softmax(outputs, 1)

0 commit comments

Comments
 (0)