Skip to content

Commit 236f32e

Browse files
authored
Fix training on CPU (#4788)
1 parent 93317e0 commit 236f32e

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

lib/src/otx/backend/native/engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from otx.types.export import OTXExportFormatType
4242
from otx.types.precision import OTXPrecisionType
4343
from otx.types.task import OTXTaskType
44-
from otx.utils.device import is_xpu_available
44+
from otx.utils.device import get_available_device, is_xpu_available
4545
from otx.utils.utils import measure_flops
4646

4747
if TYPE_CHECKING:
@@ -909,6 +909,8 @@ def configure_accelerator(self) -> None:
909909
],
910910
)
911911
self._cache.args["precision"] = None
912+
elif (self._device.accelerator == DeviceType.cpu) or (get_available_device() == "cpu"):
913+
self._cache.args["precision"] = "32"
912914

913915
def configure_loggers(self, logger: Logger | Iterable[Logger] | bool | None = None) -> Logger | Iterable[Logger]:
914916
"""Sets up the loggers for the trainer.

0 commit comments

Comments
 (0)