Skip to content

Commit 889e4b2

Browse files
committed
Merge branch 'mac-trainer-m2'
2 parents 757cbf8 + 9e2344b commit 889e4b2

File tree

3 files changed

+32
-13
lines changed

3 files changed

+32
-13
lines changed

trainer/requirements.txt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ pytest==8.4.2
55
scikit-image==0.25.2
66
scipy==1.16.2
77
pillow==12.0.0
8-
torch==2.6.0+cu124
9-
torchvision==0.21.0+cu124
108

9+
# ---- PyTorch: platform-specific pins ----
10+
# macOS (MPS/CPU wheels from PyPI)
11+
torch==2.6.0; sys_platform == "darwin"
12+
torchvision==0.21.0; sys_platform == "darwin"
13+
14+
# Windows/Linux CUDA 12.4
15+
torch==2.6.0+cu124; sys_platform != "darwin"
16+
torchvision==0.21.0+cu124; sys_platform != "darwin"

trainer/src/model_utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ def get_latest_model_paths(model_dir, k):
4848

4949
def load_model(model_path):
5050
model = UNetGNRes()
51-
if torch.cuda.is_available():
51+
if torch.cuda.is_available() or torch.backends.mps.is_available():
5252
try:
5353
model.load_state_dict(torch.load(model_path))
5454
model = torch.nn.DataParallel(model)
5555
except:
5656
model = torch.nn.DataParallel(model)
5757
model.load_state_dict(torch.load(model_path))
58-
model.cuda()
58+
model.to(device)
5959
else:
6060
# if you are running on a CPU-only machine, please use torch.load with
6161
# map_location=torch.device('cpu') to map your storages to the CPU.
@@ -77,8 +77,8 @@ def create_first_model_with_random_weights(model_dir):
7777
model_path = os.path.join(model_dir, model_name)
7878
torch.save(model.state_dict(), model_path)
7979

80-
if torch.cuda.is_available():
81-
model.cuda()
80+
if torch.cuda.is_available() or torch.backends.mps.is_available():
81+
model.to(device)
8282
return model
8383

8484

@@ -292,14 +292,13 @@ def unet_segment(cnn, image, bs, in_w, out_w, threshold=0.5):
292292
tile_idx += 1
293293
tiles_to_process.append(tile)
294294
tiles_for_gpu = torch.from_numpy(np.array(tiles_to_process))
295-
if torch.cuda.is_available():
296-
tiles_for_gpu.cuda()
295+
tiles_for_gpu = tiles_for_gpu.to(device)
297296
tiles_for_gpu = tiles_for_gpu.float()
298297
batches.append(tiles_for_gpu)
299298

300299
output_tiles = []
301300
for gpu_tiles in batches:
302-
outputs = cnn(gpu_tiles.cuda())
301+
outputs = cnn(gpu_tiles.to(device))
303302
softmaxed = softmax(outputs, 1)
304303
foreground_probs = softmaxed[:, 1, :] # just the foreground probability.
305304
if threshold is not None:

trainer/src/trainer.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,25 @@ def __init__(self, sync_dir=None, patch_size=572,
8989
total_mem = 0
9090
self.num_workers=min(multiprocessing.cpu_count(), max_workers)
9191
print(self.num_workers, 'workers assigned for data loader')
92-
print('GPU Available', torch.cuda.is_available())
92+
print('CUDA Available', torch.cuda.is_available())
93+
9394
if torch.cuda.is_available():
9495
for i in range(torch.cuda.device_count()):
9596
total_mem += torch.cuda.get_device_properties(i).total_memory
97+
98+
print('MPS Available', torch.backends.mps.is_available())
99+
# MPS only has one device.
100+
# There is no obvious way of getting memory for MPS
101+
# FIXME: setting arbitrary amount of memory.
102+
if torch.backends.mps.is_available():
103+
total_mem = 24_589_934_592
104+
105+
if total_mem > 0: # means CUDA or MPS found
96106
self.bs = total_mem // mem_per_item
97107
self.bs = min(12, self.bs)
98108
else:
99109
self.bs = 1 # cpu is batch size of 1
110+
100111
print('Batch size', self.bs)
101112
self.optimizer = None
102113
# used to check for updates
@@ -287,6 +298,9 @@ def train_one_epoch(self):
287298
if not [is_photo(a) for a in ls(val_annot_dir)]:
288299
return
289300

301+
302+
device = model_utils.get_device()
303+
290304
if self.first_loop:
291305
self.first_loop = False
292306
self.write_message('Training started')
@@ -313,9 +327,9 @@ def train_one_epoch(self):
313327
defined_tiles) in enumerate(train_loader):
314328

315329
self.check_for_instructions()
316-
photo_tiles = photo_tiles.cuda()
317-
foreground_tiles = foreground_tiles.cuda()
318-
defined_tiles = defined_tiles.cuda()
330+
photo_tiles = photo_tiles.to(device)
331+
foreground_tiles = foreground_tiles.to(device)
332+
defined_tiles = defined_tiles.to(device)
319333
self.optimizer.zero_grad()
320334
outputs = self.model(photo_tiles)
321335
softmaxed = softmax(outputs, 1)

0 commit comments

Comments
 (0)