Skip to content

Commit 85df5ad

Browse files
ali-izhar and nbovee authored
Bug fix: cpu metrics missing functionality restored (#12)
* bug fix: missing functions
* mps device support
* bump version
* return formatting to ruff standard
* silence ruff checks in locations that could impede functionality if we do not cleanly fix.
* add pytest to dev dependencies
* add other requirements files for cuda versions
* correct typo in Readme

---------

Co-authored-by: nbovee <12849851+nbovee@users.noreply.github.com>
1 parent 123d5cd commit 85df5ad

File tree

24 files changed

+1826
-10488
lines changed

24 files changed

+1826
-10488
lines changed

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
```bash
4545
git clone https://github.com/nbovee/tracr.git && cd tracr
4646
python3 -m venv venv && source venv/bin/activate
47-
pip install -r requirements.txt
47+
pip install -r requirements.txt # alternatively, use the requirements-cu###.txt file for your cuda version.
4848
```
4949

5050
2. **Configure devices** by copying and editing the template
@@ -327,7 +327,10 @@ Select optimal split points based on:
327327
<summary>Unit Testing</summary>
328328
- If issues present themselves, the provided unit tests may have some insight to the error. Please run the following, and refine to individual files for further details:
329329
330-
```python -m unittest discovery -s ./tests```
330+
```python -m unittest discover -s ./tests```
331+
or if using uv, the command will be:
332+
333+
```uv run -m unittest discover -s ./tests```
331334
</details>
332335

333336
<summary>Connection Issues</summary>

config/modelsplit_template.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# DEFAULT CONFIGURATIONS
1414
# ================================================================
1515
default:
16-
device: "cuda" # [OPTIONAL] Computing device: 'cuda' (GPU) or 'cpu'. Default: 'cuda' if available, else 'cpu'
16+
device: "cuda" # [OPTIONAL] Computing device: 'cuda' (NVIDIA GPU), 'mps' (Apple Silicon GPU), or 'cpu'. Default: 'cuda' if available, 'mps' on Apple Silicon if available, else 'cpu'
1717
save_layer_images: false # [OPTIONAL] Save intermediate layer images. Default: false
1818
collect_metrics: false # [OPTIONAL] Collect detailed metrics per layer (time-consuming for cpu systems). Default: false
1919

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "tracr"
3-
version = "0.4.0"
3+
version = "0.4.1"
44
description = "An experimental framework for computational offloading and distributed neural network inference through split computing"
55
readme = "README.md"
66
requires-python = ">=3.10"
@@ -61,6 +61,11 @@ conflicts = [
6161
[tool.ruff.lint]
6262
ignore = ["F841"]
6363

64+
[dependency-groups]
65+
dev = [
66+
"pytest>=8.3.5",
67+
]
68+
6469
[tool.uv.sources]
6570
torch = [
6671
{ index = "pytorch-cpu", extra = "cpu", marker = "platform_system != 'Darwin'" },

requirements-cu118.txt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# This file was autogenerated by uv via the following command:
2+
# uv pip compile pyproject.toml -o requirements-cu118.txt --no-deps --extra cu118 --extra full
3+
blosc2==3.2.0
4+
# via tracr (pyproject.toml)
5+
loguru==0.7.3
6+
# via tracr (pyproject.toml)
7+
pandas==2.2.3
8+
# via tracr (pyproject.toml)
9+
paramiko==3.5.1
10+
# via tracr (pyproject.toml)
11+
pyyaml==6.0.2
12+
# via tracr (pyproject.toml)
13+
rich==13.9.4
14+
# via tracr (pyproject.toml)
15+
rpyc==6.0.1
16+
# via tracr (pyproject.toml)
17+
tomli==2.2.1
18+
# via tracr (pyproject.toml)
19+
torch==2.6.0+cu118
20+
# via tracr (pyproject.toml)
21+
torchaudio==2.6.0+cu118
22+
# via tracr (pyproject.toml)
23+
torchinfo==1.8.0
24+
# via tracr (pyproject.toml)
25+
torchvision==0.21.0+cu118
26+
# via tracr (pyproject.toml)
27+
tqdm==4.67.1
28+
# via tracr (pyproject.toml)
29+
ultralytics==8.3.93
30+
# via tracr (pyproject.toml)

requirements-cu121.txt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# This file was autogenerated by uv via the following command:
2+
# uv pip compile pyproject.toml -o requirements-cu121.txt --no-deps --extra cu121 --extra full
3+
blosc2==3.2.0
4+
# via tracr (pyproject.toml)
5+
loguru==0.7.3
6+
# via tracr (pyproject.toml)
7+
pandas==2.2.3
8+
# via tracr (pyproject.toml)
9+
paramiko==3.5.1
10+
# via tracr (pyproject.toml)
11+
pyyaml==6.0.2
12+
# via tracr (pyproject.toml)
13+
rich==13.9.4
14+
# via tracr (pyproject.toml)
15+
rpyc==6.0.1
16+
# via tracr (pyproject.toml)
17+
tomli==2.2.1
18+
# via tracr (pyproject.toml)
19+
torch==2.5.1+cu121
20+
# via tracr (pyproject.toml)
21+
torchaudio==2.5.1+cu121
22+
# via tracr (pyproject.toml)
23+
torchinfo==1.8.0
24+
# via tracr (pyproject.toml)
25+
torchvision==0.20.1+cu121
26+
# via tracr (pyproject.toml)
27+
tqdm==4.67.1
28+
# via tracr (pyproject.toml)
29+
ultralytics==8.3.93
30+
# via tracr (pyproject.toml)

requirements-cu124.txt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# This file was autogenerated by uv via the following command:
2+
# uv pip compile pyproject.toml -o requirements-cu124.txt --no-deps --extra cu124 --extra full
3+
blosc2==3.2.0
4+
# via tracr (pyproject.toml)
5+
loguru==0.7.3
6+
# via tracr (pyproject.toml)
7+
pandas==2.2.3
8+
# via tracr (pyproject.toml)
9+
paramiko==3.5.1
10+
# via tracr (pyproject.toml)
11+
pyyaml==6.0.2
12+
# via tracr (pyproject.toml)
13+
rich==13.9.4
14+
# via tracr (pyproject.toml)
15+
rpyc==6.0.1
16+
# via tracr (pyproject.toml)
17+
tomli==2.2.1
18+
# via tracr (pyproject.toml)
19+
torch==2.6.0+cu124
20+
# via tracr (pyproject.toml)
21+
torchaudio==2.6.0+cu124
22+
# via tracr (pyproject.toml)
23+
torchinfo==1.8.0
24+
# via tracr (pyproject.toml)
25+
torchvision==0.21.0+cu124
26+
# via tracr (pyproject.toml)
27+
tqdm==4.67.1
28+
# via tracr (pyproject.toml)
29+
ultralytics==8.3.93
30+
# via tracr (pyproject.toml)

server.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,17 @@
2525
if str(project_root) not in sys.path:
2626
sys.path.append(str(project_root))
2727

28-
from src.api import (
28+
from src.api import ( # noqa: E402
2929
DataCompression,
3030
DeviceManager,
3131
ExperimentManager,
3232
DeviceType,
3333
start_logging_server,
3434
shutdown_logging_server,
35-
DataCompression,
35+
DataCompression, # noqa: F811
3636
read_yaml_file,
3737
)
38-
from src.api.network.protocols import (
38+
from src.api.network.protocols import ( # noqa: E402
3939
LENGTH_PREFIX_SIZE,
4040
ACK_MESSAGE,
4141
SERVER_COMPRESSION_SETTINGS,
@@ -61,11 +61,43 @@ def get_device(requested_device: str = "cuda") -> str:
6161
logger.info("CPU device explicitly requested")
6262
return "cpu"
6363

64-
if requested_device in ("cuda", "gpu", "mps") and torch.cuda.is_available():
64+
if requested_device == "cuda" and torch.cuda.is_available():
6565
logger.info("CUDA is available and will be used")
6666
return "cuda"
6767

68-
logger.warning("CUDA requested but not available, falling back to CPU")
68+
# Check for MPS (Apple Silicon GPUs)
69+
if (
70+
requested_device == "mps"
71+
and hasattr(torch.backends, "mps")
72+
and torch.backends.mps.is_available()
73+
):
74+
logger.info("MPS (Apple Silicon GPU) is available and will be used")
75+
return "mps"
76+
77+
# If we're here, requested GPU is not available - try alternatives
78+
if requested_device in ("cuda", "gpu", "mps"):
79+
# If any GPU was requested, try all available options in priority order
80+
if torch.cuda.is_available():
81+
logger.info(
82+
f"{requested_device.upper()} requested but not available, using CUDA instead"
83+
)
84+
return "cuda"
85+
86+
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
87+
logger.info(
88+
f"{requested_device.upper()} requested but not available, using MPS (Apple Silicon GPU) instead"
89+
)
90+
return "mps"
91+
92+
logger.warning(
93+
f"{requested_device.upper()} requested but no GPU available, falling back to CPU"
94+
)
95+
return "cpu"
96+
97+
# For any other requested device, fall back to CPU
98+
logger.warning(
99+
f"Requested device '{requested_device}' not recognized, falling back to CPU"
100+
)
69101
return "cpu"
70102

71103

@@ -353,7 +385,7 @@ def _receive_config(self, conn: socket.socket) -> dict:
353385
# Deserialize using pickle
354386
try:
355387
config = pickle.loads(config_data)
356-
logger.debug(f"Successfully received and parsed configuration")
388+
logger.debug("Successfully received and parsed configuration")
357389
return config
358390
except Exception as e:
359391
logger.error(f"Failed to deserialize config: {e}")

src/api/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""API module"""
22

3-
from .core import *
4-
from .devices import *
5-
from .experiments import *
6-
from .inference import *
7-
from .network import *
8-
from .utils import *
3+
from .core import * # noqa: F403
4+
from .devices import * # noqa: F403
5+
from .experiments import * # noqa: F403
6+
from .inference import * # noqa: F403
7+
from .network import * # noqa: F403
8+
from .utils import * # noqa: F403
99

10-
__all__ = [
10+
__all__ = [ # noqa: F405
1111
"core",
1212
"devices",
1313
"experiments",

src/api/devices/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717
from .discovery import LAN
1818
from ..network.ssh import SSHKeyHandler, SSHConfig, create_ssh_client
19-
from ..network.protocols import SSH_PORT, SSH_CONNECTIVITY_TIMEOUT, DEFAULT_PORT
19+
from ..network.protocols import SSH_PORT, SSH_CONNECTIVITY_TIMEOUT, DEFAULT_PORT # noqa: F401
2020
from ..utils.utils import get_repo_root
2121

2222
logger = logging.getLogger("split_computing_logger")

src/api/experiments/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Core experiment infrastructure for split computing"""
22

33
import logging
4-
import sys
4+
import sys # noqa: F401
55
import time
66
from dataclasses import dataclass, field
77
from pathlib import Path

0 commit comments

Comments
 (0)