
Commit 8def43a

Merge pull request #164 from FocoosAI/feat/train-on-AppleSilicon
allow training on mps device (apple silicon)
2 parents 0136006 + a4412be commit 8def43a

File tree

6 files changed: +48, -15 lines


focoos/hub/focoos_hub.py

Lines changed: 6 additions & 1 deletion
@@ -37,7 +37,12 @@
 logger = get_logger("HUB")
 
 
-SUPPORTED_MODEL_FAMILIES = [ModelFamily.BISENETFORMER, ModelFamily.DETR, ModelFamily.MASKFORMER]
+SUPPORTED_MODEL_FAMILIES = [
+    ModelFamily.BISENETFORMER,
+    ModelFamily.DETR,
+    ModelFamily.MASKFORMER,
+    ModelFamily.IMAGE_CLASSIFIER,
+]
 
 
 class FocoosHUB:
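
With this change the HUB also accepts image-classification models alongside the detection and segmentation families. A minimal sketch of the kind of membership check such a whitelist enables (the `ensure_supported` helper and the `focoos.ports` import path are assumptions, not part of this commit):

```python
# Hypothetical guard showing how SUPPORTED_MODEL_FAMILIES might gate HUB operations.
from focoos.ports import ModelFamily  # assumed location of ModelFamily

SUPPORTED_MODEL_FAMILIES = [
    ModelFamily.BISENETFORMER,
    ModelFamily.DETR,
    ModelFamily.MASKFORMER,
    ModelFamily.IMAGE_CLASSIFIER,
]


def ensure_supported(family: ModelFamily) -> None:
    """Raise early if the HUB cannot handle this model family."""
    if family not in SUPPORTED_MODEL_FAMILIES:
        raise ValueError(f"Model family {family} is not supported by the FocoosHUB")
```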

focoos/models/focoos_model.py

Lines changed: 22 additions & 6 deletions
@@ -114,10 +114,17 @@ def __init__(self, model: BaseModelNN, model_info: ModelInfo):
         self.processor.eval()
         self.model = model.eval()
 
-        try:
-            self.model = self.model.cuda()
-        except Exception:
-            logger.warning("Unable to use CUDA")
+        if torch.cuda.is_available():
+            try:
+                self.model = self.model.cuda()
+            except Exception:
+                logger.warning("Unable to use CUDA")
+
+        if torch.backends.mps.is_available():
+            try:
+                self.model = self.model.to(device="mps")
+            except Exception:
+                logger.warning("Unable to use MPS")
 
         if self.model_info.weights_uri:
             self._load_weights()

@@ -153,8 +160,17 @@ def _setup_model_for_training(self, train_args: TrainerArgs, data_train: MapData
         """
         device = get_cpu_name()
         system_info = get_system_info()
-        if system_info.gpu_info and system_info.gpu_info.devices and len(system_info.gpu_info.devices) > 0:
+        if (
+            train_args.device == "cuda"
+            and system_info.gpu_info
+            and system_info.gpu_info.devices
+            and len(system_info.gpu_info.devices) > 0
+        ):
             device = system_info.gpu_info.devices[0].gpu_name
+        elif train_args.device == "mps" and torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
         self.model_info.ref = None
 
         self.model_info.train_args = train_args  # type: ignore

@@ -392,7 +408,7 @@ def export(
         runtime_type: RuntimeType = RuntimeType.TORCHSCRIPT_32,
         onnx_opset: int = 18,
         out_dir: Optional[str] = None,
-        device: Literal["cuda", "cpu", "auto"] = "auto",
+        device: Literal["cuda", "cpu", "mps", "auto"] = "auto",
         simplify_onnx: bool = True,
         overwrite: bool = True,
         image_size: Optional[Union[int, Tuple[int, int]]] = None,
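
The constructor now guards each device move behind an availability check: CUDA is tried first, then Apple's MPS backend, and a failed move logs a warning instead of raising. A standalone sketch of that resolution order (the helper name is illustrative, not the library's API):

```python
import torch


def resolve_device() -> torch.device:
    # Mirror the fallback order in FocoosModel.__init__: CUDA if present,
    # otherwise MPS on Apple Silicon, otherwise CPU. The two accelerator
    # backends never coexist on one machine in practice.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


model = torch.nn.Linear(4, 2).to(resolve_device())
```

Note that `_setup_model_for_training` applies the same idea but driven by `train_args.device`: a user who explicitly requests `"cuda"` on a machine without NVIDIA GPUs now falls back to `"cpu"` rather than crashing.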

focoos/ports.py

Lines changed: 10 additions & 3 deletions
@@ -676,6 +676,7 @@ class GPUInfo(PydanticBase):
     gpu_cuda_version: Optional[str] = None
     total_gpu_memory_gb: Optional[float] = None
     devices: Optional[list[GPUDevice]] = None
+    mps_available: Optional[bool] = None
 
 
 class SystemInfo(PydanticBase):

@@ -743,6 +744,7 @@ def pprint(self, level: Literal["INFO", "DEBUG"] = "DEBUG"):
             output_lines.append(f" - total_memory_gb: {value.get('total_gpu_memory_gb')} GB")
             output_lines.append(f" - gpu_driver: {value.get('gpu_driver')}")
             output_lines.append(f" - gpu_cuda_version: {value.get('gpu_cuda_version')}")
+            output_lines.append(f" - mps_available: {value.get('mps_available')}")
             if value.get("devices"):
                 output_lines.append(" - devices:")
                 for device in value.get("devices", []):

@@ -950,16 +952,21 @@ class DatasetSplitType(str, Enum):
 
 def get_gpus_count():
     try:
-        import torch.cuda
+        import torch
 
-        return torch.cuda.device_count()
+        if torch.backends.mps.is_available():
+            return 1
+        elif torch.cuda.is_available():
+            return torch.cuda.device_count()
+        else:
+            return 0
     except ImportError:
         return 0
 
 
 SchedulerType = Literal["POLY", "FIXED", "COSINE", "MULTISTEP"]
 OptimizerType = Literal["ADAMW", "SGD", "RMSPROP"]
-DeviceType = Literal["cuda", "cpu"]
+DeviceType = Literal["cuda", "cpu", "mps"]
 
 
 @dataclass
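
`get_gpus_count` now treats Apple Silicon as a single logical device (the MPS backend addresses exactly one), and the `DeviceType` literal gains `"mps"` so trainer configs can request it directly. A hedged usage sketch, assuming the imports match this commit's layout of `focoos/ports.py`:

```python
import torch

from focoos.ports import DeviceType, get_gpus_count  # assumed import path

device: DeviceType = "cpu"
if get_gpus_count() > 0:
    # MPS reports exactly one device; use it when CUDA is absent.
    device = "cuda" if torch.cuda.is_available() else "mps"
print(f"selected device: {device}")
```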

focoos/trainer/trainer.py

Lines changed: 8 additions & 3 deletions
@@ -77,6 +77,8 @@ def train(
         self.resume = args.resume
         self.finished = False
 
+        self.model.to(self.args.device)
+
         self.args.run_name = self.args.run_name.strip()
         # Setup logging and environment
         self.output_dir = os.path.join(self.args.output_dir, self.args.run_name)

@@ -631,8 +633,12 @@ def __init__(
         self.gather_metric_period = gather_metric_period
         self.zero_grad_before_forward = zero_grad_before_forward
 
+        if not torch.cuda.is_available():
+            logger.warning("[UnifiedTrainerLoop] CUDA is not available, training without AMP!")
+            amp = False
+
         # AMP setup
-        if amp:
+        if amp and torch.cuda.is_available():
             if grad_scaler is None:
                 # the init_scale avoids the first step to be too large
                 # and the scheduler.step() warning

@@ -725,8 +731,7 @@ def run_step(self):
         if self.zero_grad_before_forward:
             self.optimizer.zero_grad()
 
-        if self.amp:
-            assert torch.cuda.is_available(), "[UnifiedTrainerLoop] CUDA is required for AMP training!"
+        if self.amp and torch.cuda.is_available():
             with autocast(enabled=self.amp, dtype=self.precision, device_type="cuda"):
                 # we need to have preprocess data here
                 images, targets = self.processor.preprocess(data, dtype=self.precision, device=self.model.device)
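
The new `self.model.to(self.args.device)` call in `train` is what actually places the model on the device resolved from `TrainerArgs`. AMP, meanwhile, is no longer hard-asserted on CUDA: when CUDA is missing, the loop warns once at construction, forces `amp = False`, and `run_step` takes the full-precision path. A minimal standalone sketch of that degrade-gracefully pattern (not the trainer's exact code):

```python
import torch
from torch.amp import autocast

amp = True  # user requested mixed precision

# Downgrade instead of asserting: MPS and CPU runs proceed at full precision.
if not torch.cuda.is_available():
    print("CUDA is not available, training without AMP!")
    amp = False

layer = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)

if amp and torch.cuda.is_available():
    layer = layer.cuda()
    with autocast(device_type="cuda", dtype=torch.float16):
        out = layer(x.cuda())
else:
    out = layer(x)  # full-precision fallback used on MPS/CPU
```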

focoos/utils/system.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def get_gpu_info() -> GPUInfo:
         GPUInfo: An object containing comprehensive GPU information including devices list,
         driver version, CUDA version and GPU count.
     """
-    gpu_info = GPUInfo()
+    gpu_info = GPUInfo(mps_available=torch.backends.mps.is_available())
     gpus_device = []
     try:
         # Get all GPU information in a single query
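
`get_gpu_info` now stamps MPS availability into the report, so the flag printed by `SystemInfo.pprint` comes straight from `torch.backends.mps.is_available()`. A short sketch of reading it back (assuming `get_system_info` in `focoos.utils.system` is the public accessor, as the focoos_model.py diff above suggests):

```python
from focoos.utils.system import get_system_info  # assumed accessor

info = get_system_info()
# mps_available is Optional[bool]; it stays None if the probe never ran.
if info.gpu_info and info.gpu_info.mps_available:
    print("Apple Silicon MPS backend detected")
```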

uv.lock

Lines changed: 1 addition & 1 deletion
Generated lockfile; diff not rendered.
