Skip to content

Commit 6dfd23c

Browse files
committed
added MPS default tiling
1 parent f7cca9a commit 6dfd23c

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

synapse_net/inference/util.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -432,13 +432,22 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
432432
else:
433433
raise NotImplementedError(f"Infererence with a GPU with {vram} GB VRAM is not supported.")
434434

435-
print(f"Determined tile size: {tile}")
435+
print(f"Determined tile size for CUDA: {tile}")
436436
tiling = {"tile": tile, "halo": halo}
437437

438+
elif torch.backends.mps.is_available(): # Check for Apple Silicon (MPS)
439+
print("Using Metal Performance Shaders (MPS) for inference.")
440+
# MPS memory detection is limited, so use conservative tiling
441+
tile = {"x": 512, "y": 512, "z": 64}
442+
halo = {"x": 64, "y": 64, "z": 16}
443+
print(f"Determined tile size for MPS: {tile}")
444+
tiling = {"tile": tile, "halo": halo}
445+
446+
438447
# I am not sure what is reasonable on a cpu. For now choosing very small tiling.
439448
# (This will not work well on a CPU in any case.)
440449
else:
441-
print("Determining default tiling")
450+
print("Determining default tiling for CPU")
442451
tiling = {
443452
"tile": {"x": 96, "y": 96, "z": 16},
444453
"halo": {"x": 16, "y": 16, "z": 4},

0 commit comments

Comments (0)