-
Notifications
You must be signed in to change notification settings - Fork 3
Add targeted device selection for CUDA and 128GB M3 Max MPS #5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,6 +37,8 @@ | |
| import os | ||
| import math | ||
| import random | ||
| import platform | ||
| import subprocess | ||
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
|
|
@@ -57,6 +59,83 @@ | |
| _DEVSTRAL_MODEL = None | ||
|
|
||
|
|
||
| def _read_sysctl_value(name: str) -> Optional[str]: | ||
| """Read a sysctl value on macOS, returning None if unavailable.""" | ||
|
|
||
| if platform.system() != "Darwin": | ||
| return None | ||
| try: | ||
| output = subprocess.check_output(["sysctl", "-n", name]) | ||
| except (subprocess.CalledProcessError, FileNotFoundError, OSError): | ||
| return None | ||
| return output.decode().strip() | ||
|
|
||
|
|
||
| def _mac_total_memory_bytes() -> Optional[int]: | ||
| raw_value = _read_sysctl_value("hw.memsize") | ||
| if raw_value is None: | ||
| return None | ||
| try: | ||
| return int(raw_value) | ||
| except ValueError: | ||
| return None | ||
|
|
||
|
|
||
| def _apple_chip_name() -> Optional[str]: | ||
| chip_name = _read_sysctl_value("machdep.cpu.brand_string") | ||
| return chip_name.strip() if chip_name else None | ||
|
|
||
|
|
||
| def _is_supported_mps_device(required_mem_gb: int = 128) -> bool: | ||
| """Return True if running on an Apple M3 Max with at least required_mem_gb.""" | ||
|
|
||
| if platform.system() != "Darwin": | ||
| return False | ||
| mps_backend = getattr(torch.backends, "mps", None) | ||
| if mps_backend is None or not mps_backend.is_available(): | ||
| return False | ||
|
|
||
| chip_name = _apple_chip_name() | ||
| if not chip_name or "m3" not in chip_name.lower() or "max" not in chip_name.lower(): | ||
| return False | ||
|
|
||
| total_mem_bytes = _mac_total_memory_bytes() | ||
| required_mem_bytes = required_mem_gb * (1024**3) | ||
| if total_mem_bytes is None or total_mem_bytes < required_mem_bytes: | ||
| return False | ||
|
|
||
| return True | ||
|
|
||
|
|
||
| def select_device() -> torch.device: | ||
| """Choose the best available device with CUDA priority, then restricted MPS.""" | ||
|
|
||
| if torch.cuda.is_available(): | ||
| return torch.device("cuda") | ||
|
|
||
| mps_backend = getattr(torch.backends, "mps", None) | ||
| if mps_backend is not None: | ||
| if _is_supported_mps_device(): | ||
| return torch.device("mps") | ||
| if mps_backend.is_available(): | ||
| chip_name = _apple_chip_name() or "unknown" | ||
| total_mem_bytes = _mac_total_memory_bytes() | ||
| if total_mem_bytes is not None: | ||
| mem_gb = total_mem_bytes / (1024**3) | ||
| print( | ||
| "Warning: MPS backend detected but restricted to Apple M3 Max systems " | ||
| "with 128GB memory. " | ||
| f"Detected chip '{chip_name}' with {mem_gb:.1f}GB. Falling back to CPU." | ||
| ) | ||
|
Comment on lines
+125
to
+129
|
||
| else: | ||
| print( | ||
| "Warning: MPS backend detected but unable to verify Apple M3 Max 128GB " | ||
| "requirement. Falling back to CPU." | ||
| ) | ||
|
Comment on lines
+131
to
+134
|
||
|
|
||
| return torch.device("cpu") | ||
|
|
||
|
|
||
| @dataclass | ||
| class RewardConfig: | ||
| """Configuration for modular reward functions.""" | ||
|
|
@@ -993,7 +1072,8 @@ def main(): | |
| ########################################################################### | ||
| # Device Setup | ||
| ########################################################################### | ||
| device = "cuda" if torch.cuda.is_available() else "cpu" | ||
| device = select_device() | ||
| print(f"Using device: {device}") | ||
|
|
||
| ########################################################################### | ||
| # Stage 0: Gather Data | ||
|
|
@@ -1035,7 +1115,7 @@ def main(): | |
| tokenizer.pad_token_id = tokenizer.eos_token_id | ||
|
|
||
| model = AutoModelForCausalLM.from_pretrained( | ||
| base_ckpt, torch_dtype="auto", device_map="auto", trust_remote_code=True | ||
| base_ckpt, torch_dtype="auto", trust_remote_code=True | ||
| ) | ||
| print("Loaded base Qwen 7B Instruct model successfully.") | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using subprocess.check_output with user-controlled input could be vulnerable to command injection. Consider validating the `name` parameter against an allowlist of known sysctl keys before passing it to the subprocess call.