diff --git a/focoos/hub/focoos_hub.py b/focoos/hub/focoos_hub.py
index 5cbb371..1d165c9 100644
--- a/focoos/hub/focoos_hub.py
+++ b/focoos/hub/focoos_hub.py
@@ -37,7 +37,12 @@
 
 logger = get_logger("HUB")
 
-SUPPORTED_MODEL_FAMILIES = [ModelFamily.BISENETFORMER, ModelFamily.DETR, ModelFamily.MASKFORMER]
+SUPPORTED_MODEL_FAMILIES = [
+    ModelFamily.BISENETFORMER,
+    ModelFamily.DETR,
+    ModelFamily.MASKFORMER,
+    ModelFamily.IMAGE_CLASSIFIER,
+]
 
 
 class FocoosHUB:
diff --git a/focoos/models/focoos_model.py b/focoos/models/focoos_model.py
index ad4ade7..09099cd 100644
--- a/focoos/models/focoos_model.py
+++ b/focoos/models/focoos_model.py
@@ -114,10 +114,17 @@ def __init__(self, model: BaseModelNN, model_info: ModelInfo):
         self.processor.eval()
         self.model = model.eval()
 
-        try:
-            self.model = self.model.cuda()
-        except Exception:
-            logger.warning("Unable to use CUDA")
+        if torch.cuda.is_available():
+            try:
+                self.model = self.model.cuda()
+            except Exception:
+                logger.warning("Unable to use CUDA")
+
+        if torch.backends.mps.is_available():
+            try:
+                self.model = self.model.to(device="mps")
+            except Exception:
+                logger.warning("Unable to use MPS")
 
         if self.model_info.weights_uri:
             self._load_weights()
@@ -153,8 +160,17 @@ def _setup_model_for_training(self, train_args: TrainerArgs, data_train: MapData
         """
         device = get_cpu_name()
         system_info = get_system_info()
-        if system_info.gpu_info and system_info.gpu_info.devices and len(system_info.gpu_info.devices) > 0:
+        if (
+            train_args.device == "cuda"
+            and system_info.gpu_info
+            and system_info.gpu_info.devices
+            and len(system_info.gpu_info.devices) > 0
+        ):
             device = system_info.gpu_info.devices[0].gpu_name
+        elif train_args.device == "mps" and torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
 
         self.model_info.ref = None
         self.model_info.train_args = train_args  # type: ignore
@@ -392,7 +408,7 @@ def export(
         runtime_type: RuntimeType = RuntimeType.TORCHSCRIPT_32,
         onnx_opset: int = 18,
         out_dir: Optional[str] = None,
-        device: Literal["cuda", "cpu", "auto"] = "auto",
+        device: Literal["cuda", "cpu", "mps", "auto"] = "auto",
         simplify_onnx: bool = True,
         overwrite: bool = True,
         image_size: Optional[Union[int, Tuple[int, int]]] = None,
diff --git a/focoos/ports.py b/focoos/ports.py
index 34164d3..c319d9a 100644
--- a/focoos/ports.py
+++ b/focoos/ports.py
@@ -676,6 +676,7 @@ class GPUInfo(PydanticBase):
     gpu_cuda_version: Optional[str] = None
     total_gpu_memory_gb: Optional[float] = None
    devices: Optional[list[GPUDevice]] = None
+    mps_available: Optional[bool] = None
 
 
 class SystemInfo(PydanticBase):
@@ -743,6 +744,7 @@ def pprint(self, level: Literal["INFO", "DEBUG"] = "DEBUG"):
            output_lines.append(f" - total_memory_gb: {value.get('total_gpu_memory_gb')} GB")
            output_lines.append(f" - gpu_driver: {value.get('gpu_driver')}")
            output_lines.append(f" - gpu_cuda_version: {value.get('gpu_cuda_version')}")
+           output_lines.append(f" - mps_available: {value.get('mps_available')}")
            if value.get("devices"):
                output_lines.append(" - devices:")
                for device in value.get("devices", []):
@@ -950,16 +952,21 @@ class DatasetSplitType(str, Enum):
 
 def get_gpus_count():
     try:
-        import torch.cuda
+        import torch
 
-        return torch.cuda.device_count()
+        if torch.backends.mps.is_available():
+            return 1
+        elif torch.cuda.is_available():
+            return torch.cuda.device_count()
+        else:
+            return 0
     except ImportError:
         return 0
 
 
 SchedulerType = Literal["POLY", "FIXED", "COSINE", "MULTISTEP"]
 OptimizerType = Literal["ADAMW", "SGD", "RMSPROP"]
-DeviceType = Literal["cuda", "cpu"]
+DeviceType = Literal["cuda", "cpu", "mps"]
 
 
 @dataclass
diff --git a/focoos/trainer/trainer.py b/focoos/trainer/trainer.py
index 79c8e97..d0bef2b 100644
--- a/focoos/trainer/trainer.py
+++ b/focoos/trainer/trainer.py
@@ -77,6 +77,8 @@ def train(
         self.resume = args.resume
         self.finished = False
 
+        self.model.to(self.args.device)
+
         self.args.run_name = self.args.run_name.strip()
         # Setup logging and environment
         self.output_dir = os.path.join(self.args.output_dir, self.args.run_name)
@@ -631,8 +633,12 @@ def __init__(
         self.gather_metric_period = gather_metric_period
         self.zero_grad_before_forward = zero_grad_before_forward
 
+        if not torch.cuda.is_available():
+            logger.warning("[UnifiedTrainerLoop] CUDA is not available, training without AMP!")
+            amp = False
+
         # AMP setup
-        if amp:
+        if amp and torch.cuda.is_available():
             if grad_scaler is None:
                 # the init_scale avoids the first step to be too large
                 # and the scheduler.step() warning
@@ -725,8 +731,7 @@ def run_step(self):
         if self.zero_grad_before_forward:
             self.optimizer.zero_grad()
 
-        if self.amp:
-            assert torch.cuda.is_available(), "[UnifiedTrainerLoop] CUDA is required for AMP training!"
+        if self.amp and torch.cuda.is_available():
             with autocast(enabled=self.amp, dtype=self.precision, device_type="cuda"):
                 # we need to have preprocess data here
                 images, targets = self.processor.preprocess(data, dtype=self.precision, device=self.model.device)
diff --git a/focoos/utils/system.py b/focoos/utils/system.py
index 70ab5af..295046c 100644
--- a/focoos/utils/system.py
+++ b/focoos/utils/system.py
@@ -66,7 +66,7 @@ def get_gpu_info() -> GPUInfo:
         GPUInfo: An object containing comprehensive GPU information including
                  devices list, driver version, CUDA version and GPU count.
     """
-    gpu_info = GPUInfo()
+    gpu_info = GPUInfo(mps_available=torch.backends.mps.is_available())
     gpus_device = []
     try:
         # Get all GPU information in a single query
diff --git a/uv.lock b/uv.lock
index 852ee1b..a0581cb 100644
--- a/uv.lock
+++ b/uv.lock
@@ -677,7 +677,7 @@ wheels = [
 
 [[package]]
 name = "focoos"
-version = "0.20.2"
+version = "0.22.0"
 source = { editable = "." }
 dependencies = [
     { name = "colorama" },
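
Note: the hunks above all gate device placement on the same pair of runtime checks, torch.cuda.is_available() and torch.backends.mps.is_available(), with a CPU fallback. Below is a minimal sketch of that selection logic, assuming only an importable torch; pick_device is a hypothetical illustration, not a helper added by this patch:

    import torch

    def pick_device(requested: str = "auto") -> torch.device:
        # Hypothetical helper mirroring the availability checks this
        # patch adds: prefer CUDA, then Apple-silicon MPS, then CPU.
        if requested == "auto":
            if torch.cuda.is_available():
                return torch.device("cuda")
            if torch.backends.mps.is_available():
                return torch.device("mps")
            return torch.device("cpu")
        return torch.device(requested)  # honor an explicit "cuda"/"mps"/"cpu"

    # Example usage, roughly what FocoosModel.__init__ now does for model placement:
    # model = model.to(pick_device())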