diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py
index db3f5334ed162..67b1c8bf7bfd6 100644
--- a/docs/source-pytorch/conf.py
+++ b/docs/source-pytorch/conf.py
@@ -604,10 +604,7 @@ def package_list_from_file(file):
from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
-from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
-from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
-from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE
-from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
+from lightning.fabric.utilities.imports import _COMET_AVAILABLE, _MLFLOW_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE
"""
coverage_skip_undoc_in_source = True
diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt
index fdc814c7f7fa1..0c784b94864ab 100644
--- a/requirements/fabric/base.txt
+++ b/requirements/fabric/base.txt
@@ -6,3 +6,4 @@ fsspec[http] >=2022.5.0, <2025.11.0
packaging >=20.0, <=25.0
typing-extensions >4.5.0, <4.16.0
lightning-utilities >=0.10.0, <0.16.0
+pytorch-lightning-enterprise >=0.0.1dev4
diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt
index 014b223b1f012..96bb0ace3e115 100644
--- a/requirements/pytorch/base.txt
+++ b/requirements/pytorch/base.txt
@@ -9,3 +9,4 @@ torchmetrics >0.7.0, <1.9.0
packaging >=20.0, <=25.0
typing-extensions >4.5.0, <4.16.0
lightning-utilities >=0.10.0, <0.16.0
+pytorch-lightning-enterprise >=0.0.1dev4
diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py
index db2cf2586e1ba..0cf42821424ee 100644
--- a/src/lightning/fabric/accelerators/xla.py
+++ b/src/lightning/fabric/accelerators/xla.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
+import warnings
from typing import Any, Union
import torch
@@ -20,7 +21,11 @@
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
-from lightning.fabric.utilities.device_parser import _check_data_type
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+
+_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
+_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
+_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
class XLAAccelerator(Accelerator):
@@ -31,38 +36,38 @@ class XLAAccelerator(Accelerator):
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
- if not _XLA_AVAILABLE:
- raise ModuleNotFoundError(str(_XLA_AVAILABLE))
- if not _using_pjrt():
- raise RuntimeError("The XLA XRT runtime is not supported anymore.")
+ _raise_enterprise_not_available()
super().__init__(*args, **kwargs)
+ from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+ self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs)
+
@override
def setup_device(self, device: torch.device) -> None:
- pass
+ return self.accelerator_impl.setup_device(device)
@override
def teardown(self) -> None:
- pass
+ return self.accelerator_impl.teardown()
@staticmethod
@override
def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
"""Accelerator device parsing logic."""
- return _parse_tpu_devices(devices)
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+ return EnterpriseXLAAccelerator.parse_devices(devices)
@staticmethod
@override
def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
"""Gets parallel devices for the Accelerator."""
- devices = _parse_tpu_devices(devices)
- if isinstance(devices, int):
- return [torch.device("xla", i) for i in range(devices)]
- # list of devices is not supported, just a specific index, fine to access [0]
- return [torch.device("xla", devices[0])]
- # we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
- # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
- # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+ return EnterpriseXLAAccelerator.get_parallel_devices(devices)
@staticmethod
@override
@@ -71,16 +76,10 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
@functools.lru_cache(maxsize=1)
def auto_device_count() -> int:
"""Get the devices when set to auto."""
- if not _XLA_AVAILABLE:
- return 0
- if _XLA_GREATER_EQUAL_2_1:
- from torch_xla._internal import tpu
-
- return tpu.num_available_devices()
- from torch_xla.experimental import tpu
+ _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
- device_count_on_version = {2: 8, 3: 8, 4: 4}
- return device_count_on_version.get(tpu.version(), 8)
+ return EnterpriseXLAAccelerator.auto_device_count()
@staticmethod
@override
@@ -92,6 +91,9 @@ def is_available() -> bool:
# XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
# when `torch_xla` is imported but not used
return False
+ except ModuleNotFoundError as e:
+ warnings.warn(str(e))
+ return False
@staticmethod
@override
@@ -106,74 +108,3 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
cls,
description=cls.__name__,
)
-
-
-# PJRT support requires this minimum version
-_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
-_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
-_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
-
-
-def _using_pjrt() -> bool:
- # `using_pjrt` is removed in torch_xla 2.5
- if _XLA_GREATER_EQUAL_2_5:
- from torch_xla import runtime as xr
-
- return xr.device_type() is not None
- # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
- if _XLA_GREATER_EQUAL_2_1:
- from torch_xla import runtime as xr
-
- return xr.using_pjrt()
-
- from torch_xla.experimental import pjrt
-
- return pjrt.using_pjrt()
-
-
-def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
- """Parses the TPU devices given in the format as accepted by the
- :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.
-
- Args:
- devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used
- An int 8 or string '8' indicates that all 8 cores with multi-processing should be used
- A single element list of int or string can be used to indicate the specific TPU core to use.
-
- Returns:
- A list of tpu cores to be used.
-
- """
- _check_data_type(devices)
- if isinstance(devices, str):
- devices = _parse_tpu_devices_str(devices)
- _check_tpu_devices_valid(devices)
- return devices
-
-
-def _check_tpu_devices_valid(devices: object) -> None:
- device_count = XLAAccelerator.auto_device_count()
- if (
- # support number of devices
- isinstance(devices, int)
- and devices in {1, device_count}
- # support picking a specific device
- or isinstance(devices, (list, tuple))
- and len(devices) == 1
- and 0 <= devices[0] <= device_count - 1
- ):
- return
- raise ValueError(
- f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
- )
-
-
-def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
- devices = devices.strip()
- try:
- return int(devices)
- except ValueError:
- try:
- return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
- except ValueError:
- raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")
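Note on the guard used throughout this diff: every facade above calls `_raise_enterprise_not_available()` before importing from `pytorch_lightning_enterprise`, but the helper itself is added to `lightning.fabric.utilities.imports` outside this diff. A minimal sketch of what such a guard could look like, assuming it is built on `RequirementCache`; the cache name and error message are illustrative, not the actual implementation:

```python
# Hypothetical sketch -- the real helper lives in lightning.fabric.utilities.imports
# and is not shown in this diff.
from lightning_utilities.core.imports import RequirementCache

_ENTERPRISE_AVAILABLE = RequirementCache("pytorch-lightning-enterprise")


def _raise_enterprise_not_available() -> None:
    # Fail fast with an actionable message before any pytorch_lightning_enterprise import.
    if not _ENTERPRISE_AVAILABLE:
        raise ModuleNotFoundError(
            "This feature requires the `pytorch-lightning-enterprise` package."
            f" Install it with `pip install pytorch-lightning-enterprise`. {_ENTERPRISE_AVAILABLE!s}"
        )
```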
diff --git a/src/lightning/fabric/plugins/environments/kubeflow.py b/src/lightning/fabric/plugins/environments/kubeflow.py
index 23a1c0d1753af..8013ec2a65720 100644
--- a/src/lightning/fabric/plugins/environments/kubeflow.py
+++ b/src/lightning/fabric/plugins/environments/kubeflow.py
@@ -13,11 +13,11 @@
# limitations under the License.
import logging
-import os
from typing_extensions import override
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
log = logging.getLogger(__name__)
@@ -33,20 +33,28 @@ class KubeflowEnvironment(ClusterEnvironment):
"""
+    def __init__(self) -> None:
+        super().__init__()
+        _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.environments.kubeflow import (
+ KubeflowEnvironment as EnterpriseKubeflowEnvironment,
+ )
+
+ self.kubeflow_impl = EnterpriseKubeflowEnvironment()
+
@property
@override
def creates_processes_externally(self) -> bool:
- return True
+ return self.kubeflow_impl.creates_processes_externally
@property
@override
def main_address(self) -> str:
- return os.environ["MASTER_ADDR"]
+ return self.kubeflow_impl.main_address
@property
@override
def main_port(self) -> int:
- return int(os.environ["MASTER_PORT"])
+ return self.kubeflow_impl.main_port
@staticmethod
@override
@@ -55,24 +63,24 @@ def detect() -> bool:
@override
def world_size(self) -> int:
- return int(os.environ["WORLD_SIZE"])
+ return self.kubeflow_impl.world_size()
@override
def set_world_size(self, size: int) -> None:
- log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+ return self.kubeflow_impl.set_world_size(size)
@override
def global_rank(self) -> int:
- return int(os.environ["RANK"])
+ return self.kubeflow_impl.global_rank()
@override
def set_global_rank(self, rank: int) -> None:
- log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
+ return self.kubeflow_impl.set_global_rank(rank)
@override
def local_rank(self) -> int:
- return 0
+ return self.kubeflow_impl.local_rank()
@override
def node_rank(self) -> int:
- return self.global_rank()
+ return self.kubeflow_impl.node_rank()
diff --git a/src/lightning/fabric/plugins/environments/lsf.py b/src/lightning/fabric/plugins/environments/lsf.py
index f0a07d61d9f03..0b62bf502303a 100644
--- a/src/lightning/fabric/plugins/environments/lsf.py
+++ b/src/lightning/fabric/plugins/environments/lsf.py
@@ -13,12 +13,11 @@
# limitations under the License.
import logging
import os
-import socket
from typing_extensions import override
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.cloud_io import get_filesystem
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
log = logging.getLogger(__name__)
@@ -50,36 +49,32 @@ class LSFEnvironment(ClusterEnvironment):
def __init__(self) -> None:
super().__init__()
- self._main_address = self._get_main_address()
- self._main_port = self._get_main_port()
- self._node_rank = self._get_node_rank()
- self._set_init_progress_group_env_vars()
-
- def _set_init_progress_group_env_vars(self) -> None:
- # set environment variables needed for initializing torch distributed process group
- os.environ["MASTER_ADDR"] = str(self._main_address)
- log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
- os.environ["MASTER_PORT"] = str(self._main_port)
- log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
+
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.environments.lsf import (
+ LSFEnvironment as EnterpriseLSFEnvironment,
+ )
+
+ self.lsf_impl = EnterpriseLSFEnvironment()
@property
@override
def creates_processes_externally(self) -> bool:
"""LSF creates subprocesses, i.e., PyTorch Lightning does not need to spawn them."""
- return True
+ return self.lsf_impl.creates_processes_externally
@property
@override
def main_address(self) -> str:
"""The main address is read from an OpenMPI host rank file in the environment variable
``LSB_DJOB_RANKFILE``."""
- return self._main_address
+ return self.lsf_impl.main_address
@property
@override
def main_port(self) -> int:
"""The main port is calculated from the LSF job ID."""
- return self._main_port
+ return self.lsf_impl.main_port
@staticmethod
@override
@@ -91,110 +86,28 @@ def detect() -> bool:
@override
def world_size(self) -> int:
"""The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``."""
- world_size = os.environ.get("JSM_NAMESPACE_SIZE")
- if world_size is None:
- raise ValueError(
- "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
- " Make sure you run your executable with `jsrun`."
- )
- return int(world_size)
+ return self.lsf_impl.world_size()
@override
def set_world_size(self, size: int) -> None:
- log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+ return self.lsf_impl.set_world_size(size)
@override
def global_rank(self) -> int:
"""The world size is read from the environment variable ``JSM_NAMESPACE_RANK``."""
- global_rank = os.environ.get("JSM_NAMESPACE_RANK")
- if global_rank is None:
- raise ValueError(
- "Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
- " Make sure you run your executable with `jsrun`."
- )
- return int(global_rank)
+ return self.lsf_impl.global_rank()
@override
def set_global_rank(self, rank: int) -> None:
- log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
+ return self.lsf_impl.set_global_rank(rank)
@override
def local_rank(self) -> int:
"""The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
- local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK")
- if local_rank is None:
- raise ValueError(
- "Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
- " Make sure you run your executable with `jsrun`."
- )
- return int(local_rank)
+ return self.lsf_impl.local_rank()
@override
def node_rank(self) -> int:
"""The node rank is determined by the position of the current hostname in the OpenMPI host rank file stored in
``LSB_DJOB_RANKFILE``."""
- return self._node_rank
-
- def _get_node_rank(self) -> int:
- """A helper method for getting the node rank.
-
- The node rank is determined by the position of the current node in the list of hosts used in the job. This is
- calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list.
-
- """
- hosts = self._read_hosts()
- count: dict[str, int] = {}
- for host in hosts:
- if host not in count:
- count[host] = len(count)
- return count[socket.gethostname()]
-
- @staticmethod
- def _read_hosts() -> list[str]:
- """Read compute hosts that are a part of the compute job.
-
- LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes.
- Each job is assigned a launch node. This launch node will be the first node in the list contained in
- ``LSB_DJOB_RANKFILE``.
-
- """
- var = "LSB_DJOB_RANKFILE"
- rankfile = os.environ.get(var)
- if rankfile is None:
- raise ValueError("Did not find the environment variable `LSB_DJOB_RANKFILE`")
- if not rankfile:
- raise ValueError("The environment variable `LSB_DJOB_RANKFILE` is empty")
-
- fs = get_filesystem(rankfile)
- with fs.open(rankfile, "r") as f:
- ret = [line.strip() for line in f]
- # remove the launch node (i.e. the first node in LSB_DJOB_RANKFILE) from the list
- return ret[1:]
-
- def _get_main_address(self) -> str:
- """A helper for getting the main address.
-
- The main address is assigned to the first node in the list of nodes used for the job.
-
- """
- hosts = self._read_hosts()
- return hosts[0]
-
- @staticmethod
- def _get_main_port() -> int:
- """A helper function for accessing the main port.
-
- Uses the LSF job ID so all ranks can compute the main port.
-
- """
- # check for user-specified main port
- if "MASTER_PORT" in os.environ:
- log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
- return int(os.environ["MASTER_PORT"])
- if "LSB_JOBID" in os.environ:
- port = int(os.environ["LSB_JOBID"])
- # all ports should be in the 10k+ range
- port = port % 1000 + 10000
- log.debug(f"calculated LSF main port: {port}")
- return port
- raise ValueError("Could not find job id in environment variable LSB_JOBID")
+ return self.lsf_impl.node_rank()
diff --git a/src/lightning/fabric/plugins/environments/slurm.py b/src/lightning/fabric/plugins/environments/slurm.py
index 4d98b7ed6a8eb..4ac79ce0014e9 100644
--- a/src/lightning/fabric/plugins/environments/slurm.py
+++ b/src/lightning/fabric/plugins/environments/slurm.py
@@ -14,7 +14,6 @@
import logging
import os
-import re
import shutil
import signal
import sys
@@ -23,7 +22,7 @@
from typing_extensions import override
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.imports import _IS_WINDOWS
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _raise_enterprise_not_available
from lightning.fabric.utilities.rank_zero import rank_zero_warn
from lightning.fabric.utilities.warnings import PossibleUserWarning
@@ -46,57 +45,35 @@ class SLURMEnvironment(ClusterEnvironment):
def __init__(self, auto_requeue: bool = True, requeue_signal: Optional[signal.Signals] = None) -> None:
super().__init__()
- self.auto_requeue = auto_requeue
- if requeue_signal is None and not _IS_WINDOWS:
- requeue_signal = signal.SIGUSR1
- self.requeue_signal = requeue_signal
- self._validate_srun_used()
- self._validate_srun_variables()
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.environments.slurm import (
+ SLURMEnvironment as EnterpriseSLURMEnvironment,
+ )
+
+ self.slurm_impl = EnterpriseSLURMEnvironment(auto_requeue=auto_requeue, requeue_signal=requeue_signal)
+
+ @property
+ def auto_requeue(self) -> bool:
+ return self.slurm_impl.auto_requeue
+
+ @property
+ def requeue_signal(self) -> Optional[signal.Signals]:
+ return self.slurm_impl.requeue_signal
@property
@override
def creates_processes_externally(self) -> bool:
- return True
+ return self.slurm_impl.creates_processes_externally
@property
@override
def main_address(self) -> str:
- root_node = os.environ.get("MASTER_ADDR")
- if root_node is None:
- nodelist = os.environ.get("SLURM_NODELIST", "127.0.0.1")
- root_node = self.resolve_root_node_address(nodelist)
- os.environ["MASTER_ADDR"] = root_node
-
- log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
- return root_node
+ return self.slurm_impl.main_address
@property
@override
def main_port(self) -> int:
- # -----------------------
- # SLURM JOB = PORT number
- # -----------------------
- # this way every process knows what port to use
- job_id = os.environ.get("SLURM_JOB_ID")
- if job_id is not None:
- # use the last 4 numbers in the job id as the id
- default_port = job_id[-4:]
- # all ports should be in the 10k+ range
- default_port = int(default_port) + 15000
- else:
- default_port = 12910
-
- # -----------------------
- # PORT NUMBER = MASTER_PORT
- # -----------------------
- # in case the user passed it in
- if "MASTER_PORT" in os.environ:
- default_port = int(os.environ["MASTER_PORT"])
- else:
- os.environ["MASTER_PORT"] = str(default_port)
-
- log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
- return default_port
+ return self.slurm_impl.main_port
@staticmethod
@override
@@ -118,58 +95,40 @@ def job_name() -> Optional[str]:
@staticmethod
def job_id() -> Optional[int]:
- # in interactive mode, don't make logs use the same job id
- if _is_slurm_interactive_mode():
- return None
-
- job_id = os.environ.get("SLURM_JOB_ID")
- if job_id is None:
- return None
- try:
- return int(job_id)
- except ValueError:
- return None
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.environments.slurm import (
+ SLURMEnvironment as EnterpriseSLURMEnvironment,
+ )
+
+ return EnterpriseSLURMEnvironment.job_id()
@override
def world_size(self) -> int:
- return int(os.environ["SLURM_NTASKS"])
+ return self.slurm_impl.world_size()
@override
def set_world_size(self, size: int) -> None:
- log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+ return self.slurm_impl.set_world_size(size)
@override
def global_rank(self) -> int:
- return int(os.environ["SLURM_PROCID"])
+ return self.slurm_impl.global_rank()
@override
def set_global_rank(self, rank: int) -> None:
- log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
+ return self.slurm_impl.set_global_rank(rank)
@override
def local_rank(self) -> int:
- return int(os.environ["SLURM_LOCALID"])
+ return self.slurm_impl.local_rank()
@override
def node_rank(self) -> int:
- return int(os.environ["SLURM_NODEID"])
+ return self.slurm_impl.node_rank()
@override
def validate_settings(self, num_devices: int, num_nodes: int) -> None:
- if _is_slurm_interactive_mode():
- return
- ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
- if ntasks_per_node is not None and int(ntasks_per_node) != num_devices:
- raise ValueError(
- f"You set `devices={num_devices}` in Lightning, but the number of tasks per node configured in SLURM"
- f" `--ntasks-per-node={ntasks_per_node}` does not match. HINT: Set `devices={ntasks_per_node}`."
- )
- nnodes = os.environ.get("SLURM_NNODES")
- if nnodes is not None and int(nnodes) != num_nodes:
- raise ValueError(
- f"You set `num_nodes={num_nodes}` in Lightning, but the number of nodes configured in SLURM"
- f" `--nodes={nnodes}` does not match. HINT: Set `num_nodes={nnodes}`."
- )
+ return self.slurm_impl.validate_settings(num_devices, num_nodes)
@staticmethod
def resolve_root_node_address(nodes: str) -> str:
@@ -182,9 +141,12 @@ def resolve_root_node_address(nodes: str) -> str:
- the range notation with brackets, e.g., 'host[5-9]' yields 'host5' as the root
"""
- nodes = re.sub(r"\[(.*?)[,-].*\]", "\\1", nodes) # Take the first node of every node range
- nodes = re.sub(r"\[(.*?)\]", "\\1", nodes) # handle special case where node range is single number
- return nodes.split(" ")[0].split(",")[0]
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.environments.slurm import (
+ SLURMEnvironment as EnterpriseSLURMEnvironment,
+ )
+
+ return EnterpriseSLURMEnvironment.resolve_root_node_address(nodes)
@staticmethod
def _validate_srun_used() -> None:
@@ -216,12 +178,12 @@ def _validate_srun_variables() -> None:
for a complete list of supported srun variables.
"""
- ntasks = int(os.environ.get("SLURM_NTASKS", "1"))
- if ntasks > 1 and "SLURM_NTASKS_PER_NODE" not in os.environ:
- raise RuntimeError(
- f"You set `--ntasks={ntasks}` in your SLURM bash script, but this variable is not supported."
- f" HINT: Use `--ntasks-per-node={ntasks}` instead."
- )
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.environments.slurm import (
+ SLURMEnvironment as EnterpriseSLURMEnvironment,
+ )
+
+ return EnterpriseSLURMEnvironment._validate_srun_variables()
def _is_srun_used() -> bool:
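The cluster environments above (Kubeflow, LSF, SLURM) keep their public surface and only forward to the enterprise implementations, so existing call sites should be unaffected. A small usage sketch, assuming `pytorch-lightning-enterprise` is installed and the process runs inside a SLURM job:

```python
# Illustrative only: the facade API is unchanged; every call is forwarded to the
# pytorch_lightning_enterprise backend held in `slurm_impl`.
from lightning.fabric.plugins.environments import SLURMEnvironment

if SLURMEnvironment.detect():
    env = SLURMEnvironment(auto_requeue=True)
    # Resolved from the SLURM_* / MASTER_* environment variables by the backend.
    print(env.main_address, env.main_port)
    print(env.global_rank(), env.local_rank(), env.world_size())
```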
diff --git a/src/lightning/fabric/plugins/environments/torchelastic.py b/src/lightning/fabric/plugins/environments/torchelastic.py
index 4003dd30dec09..62ddd442df7b2 100644
--- a/src/lightning/fabric/plugins/environments/torchelastic.py
+++ b/src/lightning/fabric/plugins/environments/torchelastic.py
@@ -13,13 +13,12 @@
# limitations under the License.
import logging
-import os
import torch.distributed
from typing_extensions import override
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.rank_zero import rank_zero_warn
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
log = logging.getLogger(__name__)
@@ -27,29 +26,29 @@
class TorchElasticEnvironment(ClusterEnvironment):
"""Environment for fault-tolerant and elastic training with `torchelastic `_"""
+ def __init__(self) -> None:
+ super().__init__()
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.environments.torchelastic import (
+ TorchElasticEnvironment as EnterpriseTorchElasticEnvironment,
+ )
+
+ self.torchelastic_impl = EnterpriseTorchElasticEnvironment()
+
@property
@override
def creates_processes_externally(self) -> bool:
- return True
+ return self.torchelastic_impl.creates_processes_externally
@property
@override
def main_address(self) -> str:
- if "MASTER_ADDR" not in os.environ:
- rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
- os.environ["MASTER_ADDR"] = "127.0.0.1"
- log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
- return os.environ["MASTER_ADDR"]
+ return self.torchelastic_impl.main_address
@property
@override
def main_port(self) -> int:
- if "MASTER_PORT" not in os.environ:
- rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
- os.environ["MASTER_PORT"] = "12910"
- log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
-
- return int(os.environ["MASTER_PORT"])
+ return self.torchelastic_impl.main_port
@staticmethod
@override
@@ -60,34 +59,28 @@ def detect() -> bool:
@override
def world_size(self) -> int:
- return int(os.environ["WORLD_SIZE"])
+ return self.torchelastic_impl.world_size()
@override
def set_world_size(self, size: int) -> None:
- log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+ return self.torchelastic_impl.set_world_size(size)
@override
def global_rank(self) -> int:
- return int(os.environ["RANK"])
+ return self.torchelastic_impl.global_rank()
@override
def set_global_rank(self, rank: int) -> None:
- log.debug(
- "TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
- )
+ return self.torchelastic_impl.set_global_rank(rank)
@override
def local_rank(self) -> int:
- return int(os.environ["LOCAL_RANK"])
+ return self.torchelastic_impl.local_rank()
@override
def node_rank(self) -> int:
- return int(os.environ.get("GROUP_RANK", 0))
+ return self.torchelastic_impl.node_rank()
@override
def validate_settings(self, num_devices: int, num_nodes: int) -> None:
- if num_devices * num_nodes != self.world_size():
- raise ValueError(
- f"You set `devices={num_devices}` and `num_nodes={num_nodes}` in Lightning, but the product"
- f" ({num_devices} * {num_nodes}) does not match the world size ({self.world_size()})."
- )
+ return self.torchelastic_impl.validate_settings(num_devices, num_nodes)
diff --git a/src/lightning/fabric/plugins/environments/xla.py b/src/lightning/fabric/plugins/environments/xla.py
index b8350872f22d9..515668fa2b154 100644
--- a/src/lightning/fabric/plugins/environments/xla.py
+++ b/src/lightning/fabric/plugins/environments/xla.py
@@ -17,8 +17,9 @@
from typing_extensions import override
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1, XLAAccelerator
+from lightning.fabric.accelerators.xla import XLAAccelerator
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
log = logging.getLogger(__name__)
@@ -32,26 +33,28 @@ class XLAEnvironment(ClusterEnvironment):
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
- if not _XLA_AVAILABLE:
- raise ModuleNotFoundError(str(_XLA_AVAILABLE))
- super().__init__(*args, **kwargs)
+ super().__init__()
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.environments.xla import (
+ XLAEnvironment as EnterpriseXLAEnvironment,
+ )
+
+ self.xla_impl = EnterpriseXLAEnvironment(*args, **kwargs)
@property
@override
def creates_processes_externally(self) -> bool:
- return False
+ return self.xla_impl.creates_processes_externally
@property
@override
def main_address(self) -> str:
- # unused by lightning
- raise NotImplementedError
+ return self.xla_impl.main_address
@property
@override
def main_port(self) -> int:
- # unused by lightning
- raise NotImplementedError
+ return self.xla_impl.main_port
@staticmethod
@override
@@ -66,18 +69,11 @@ def world_size(self) -> int:
The output is cached for performance.
"""
- if _XLA_GREATER_EQUAL_2_1:
- from torch_xla import runtime as xr
-
- return xr.world_size()
-
- import torch_xla.core.xla_model as xm
-
- return xm.xrt_world_size()
+ return self.xla_impl.world_size()
@override
def set_world_size(self, size: int) -> None:
- log.debug("XLAEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+ return self.xla_impl.set_world_size(size)
@override
@functools.lru_cache(maxsize=1)
@@ -87,18 +83,11 @@ def global_rank(self) -> int:
The output is cached for performance.
"""
- if _XLA_GREATER_EQUAL_2_1:
- from torch_xla import runtime as xr
-
- return xr.global_ordinal()
-
- import torch_xla.core.xla_model as xm
-
- return xm.get_ordinal()
+ return self.xla_impl.global_rank()
@override
def set_global_rank(self, rank: int) -> None:
- log.debug("XLAEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
+ return self.xla_impl.set_global_rank(rank)
@override
@functools.lru_cache(maxsize=1)
@@ -108,14 +97,7 @@ def local_rank(self) -> int:
The output is cached for performance.
"""
- if _XLA_GREATER_EQUAL_2_1:
- from torch_xla import runtime as xr
-
- return xr.local_ordinal()
-
- import torch_xla.core.xla_model as xm
-
- return xm.get_local_ordinal()
+ return self.xla_impl.local_rank()
@override
@functools.lru_cache(maxsize=1)
@@ -125,11 +107,4 @@ def node_rank(self) -> int:
The output is cached for performance.
"""
- if _XLA_GREATER_EQUAL_2_1:
- from torch_xla import runtime as xr
-
- return xr.host_index()
- import torch_xla.core.xla_env_vars as xenv
- from torch_xla.utils.utils import getenv_as
-
- return getenv_as(xenv.HOST_ORDINAL, int, 0)
+ return self.xla_impl.node_rank()
diff --git a/src/lightning/fabric/plugins/io/xla.py b/src/lightning/fabric/plugins/io/xla.py
index 146fa2f33b510..da6ea414cac97 100644
--- a/src/lightning/fabric/plugins/io/xla.py
+++ b/src/lightning/fabric/plugins/io/xla.py
@@ -12,17 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import os
from typing import Any, Optional
-import torch
-from lightning_utilities.core.apply_func import apply_to_collection
-from lightning_utilities.core.imports import RequirementCache
from typing_extensions import override
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
-from lightning.fabric.utilities.cloud_io import get_filesystem
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import _PATH
log = logging.getLogger(__name__)
@@ -36,9 +31,11 @@ class XLACheckpointIO(TorchCheckpointIO):
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
- if not _XLA_AVAILABLE:
- raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(*args, **kwargs)
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.io.xla import XLACheckpointIO as EnterpriseXLACheckpointIO
+
+ self.xla_impl = EnterpriseXLACheckpointIO(*args, **kwargs)
@override
def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
@@ -54,21 +51,4 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio
If ``storage_options`` arg is passed in
"""
- if storage_options is not None:
- raise TypeError(
- "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
- f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
- " to define how you'd like to use `storage_options`."
- )
- fs = get_filesystem(path)
- fs.makedirs(os.path.dirname(path), exist_ok=True)
- if RequirementCache("omegaconf"):
- # workaround for https://github.com/pytorch/xla/issues/2773
- from omegaconf import DictConfig, ListConfig, OmegaConf
-
- checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
- import torch_xla.core.xla_model as xm
-
- cpu_data = xm._maybe_convert_to_cpu(checkpoint, convert=True)
- log.debug(f"Saving checkpoint: {path}")
- torch.save(cpu_data, path)
+ return self.xla_impl.save_checkpoint(checkpoint, path, storage_options)
diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py
index 4c648f2b97181..4ff20ee9fab1e 100644
--- a/src/lightning/fabric/plugins/precision/bitsandbytes.py
+++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -11,34 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import functools
-import logging
-import math
-import os
-import warnings
-from collections import OrderedDict
-from contextlib import AbstractContextManager, ExitStack
-from functools import partial
-from types import ModuleType
-from typing import Any, Callable, Literal, Optional, cast
+from contextlib import AbstractContextManager
+from typing import Any, Literal, Optional
import torch
-from lightning_utilities import apply_to_collection
from lightning_utilities.core.imports import RequirementCache
-from torch import Tensor
-from torch.nn import init
-from torch.nn.modules.module import _IncompatibleKeys
-from typing_extensions import Self, override
+from typing_extensions import override
from lightning.fabric.plugins.precision.precision import Precision
-from lightning.fabric.plugins.precision.utils import (
- _ClassReplacementContextManager,
- _convert_fp_tensor,
- _DtypeContextManager,
-)
-from lightning.fabric.utilities.types import _DEVICE
-
-log = logging.getLogger(__name__)
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes")
@@ -73,376 +54,34 @@ def __init__(
dtype: Optional[torch.dtype] = None,
ignore_modules: Optional[set[str]] = None,
) -> None:
- _import_bitsandbytes()
-
- if dtype is None:
- # try to be smart about the default selection
- if mode.startswith("int8"):
- dtype = torch.float16
- else:
- dtype = (
- torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
- )
- if mode.startswith("int8") and dtype is not torch.float16:
- # this limitation is mentioned in https://huggingface.co/blog/hf-bitsandbytes-integration#usage
- raise ValueError(f"{mode!r} only works with `dtype=torch.float16`, but you chose `{dtype}`")
+ super().__init__()
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.precision.bitsandbytes import (
+ BitsandbytesPrecision as EnterpriseBitsandbytesPrecision,
+ )
- globals_ = globals()
- mode_to_cls = {
- "nf4": globals_["_NF4Linear"],
- "nf4-dq": globals_["_NF4DQLinear"],
- "fp4": globals_["_FP4Linear"],
- "fp4-dq": globals_["_FP4DQLinear"],
- "int8-training": globals_["_Linear8bitLt"],
- "int8": globals_["_Int8LinearInference"],
- }
- self._linear_cls = mode_to_cls[mode]
- self.dtype = dtype
- self.ignore_modules = ignore_modules or set()
+ self.bitsandbytes_impl = EnterpriseBitsandbytesPrecision(mode=mode, dtype=dtype, ignore_modules=ignore_modules)
@override
def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
- # avoid naive users thinking they quantized their model
- if not any(isinstance(m, torch.nn.Linear) for m in module.modules()):
- raise TypeError(
- "You are using the bitsandbytes precision plugin, but your model has no Linear layers. This plugin"
- " won't work for your model."
- )
-
- # convert modules if they haven't been converted already
- bnb = _import_bitsandbytes()
- if not any(isinstance(m, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)) for m in module.modules()):
- # this will not quantize the model but only replace the layer classes
- _convert_layers(module, self._linear_cls, self.ignore_modules)
-
- # set the compute dtype if necessary
- for m in module.modules():
- if isinstance(m, bnb.nn.Linear4bit):
- m.compute_dtype = self.dtype
- m.compute_type_is_set = False
- return module
+ return self.bitsandbytes_impl.convert_module(module)
@override
def tensor_init_context(self) -> AbstractContextManager:
- return _DtypeContextManager(self.dtype)
+ return self.bitsandbytes_impl.tensor_init_context()
@override
def module_init_context(self) -> AbstractContextManager:
- if self.ignore_modules:
- # cannot patch the Linear class if the user wants to skip some submodules
- raise RuntimeError(
- "Instantiating your model under the `init_module` context manager is not supported when used with"
- f" `BitsandbytesPrecision(..., ignore_modules={self.ignore_modules})` as this"
- " may initialize the layers on-device, defeating the purpose of quantization. You can remove"
- " `ignore_modules` or remove the `init_module` context manager."
- )
- dtype_ctx = self.tensor_init_context()
- # TODO: this could also support replacing `Embedding` and `Conv1D`
- context_manager = _ClassReplacementContextManager({"torch.nn.Linear": self._linear_cls})
- stack = ExitStack()
- stack.enter_context(dtype_ctx)
- stack.enter_context(context_manager)
- return stack
+ return self.bitsandbytes_impl.module_init_context()
@override
def forward_context(self) -> AbstractContextManager:
- return _DtypeContextManager(self.dtype)
+ return self.bitsandbytes_impl.forward_context()
@override
def convert_input(self, data: Any) -> Any:
- return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.dtype)
+ return self.bitsandbytes_impl.convert_input(data)
@override
def convert_output(self, data: Any) -> Any:
- return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
-
-
-def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_dict: OrderedDict, *_: Any) -> None:
- # There is only one key that ends with `*.weight`, the other one is the bias
- weight_key = next((name for name in state_dict if name.endswith("weight")), None)
- if weight_key is None:
- return
- # Load the weight from the state dict and re-quantize it
- weight = state_dict.pop(weight_key)
- quantize_fn(weight)
-
-
-def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _IncompatibleKeys) -> None:
- # since we manually loaded the weight in the `_quantize_on_load_hook` hook, we need to avoid this missing key false
- # positive
- for key in reversed(incompatible_keys.missing_keys):
- if key.endswith("weight"):
- incompatible_keys.missing_keys.remove(key)
-
-
-def _replace_param(
- param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[tuple] = None
-) -> torch.nn.Parameter:
- bnb = _import_bitsandbytes()
-
- # doing `param.data = weight` raises a RuntimeError if param.data was on meta-device, so
- # we need to re-create the parameters instead of overwriting the data
- if param.device.type == "meta":
- if isinstance(param, bnb.nn.Params4bit):
- return bnb.nn.Params4bit(
- data=data,
- requires_grad=data.requires_grad,
- quant_state=quant_state,
- blocksize=param.blocksize,
- compress_statistics=param.compress_statistics,
- quant_type=param.quant_type,
- quant_storage=param.quant_storage,
- module=param.module,
- bnb_quantized=param.bnb_quantized,
- )
- return torch.nn.Parameter(data, requires_grad=data.requires_grad)
- param.data = data
- if isinstance(param, bnb.nn.Params4bit):
- param.quant_state = quant_state
- return param
-
-
-@functools.lru_cache(maxsize=1)
-def _import_bitsandbytes() -> ModuleType:
- if not _BITSANDBYTES_AVAILABLE:
- raise ModuleNotFoundError(str(_BITSANDBYTES_AVAILABLE))
- # configuration for bitsandbytes before import
- nowelcome_set = "BITSANDBYTES_NOWELCOME" in os.environ
- if not nowelcome_set:
- os.environ["BITSANDBYTES_NOWELCOME"] = "1"
- warnings.filterwarnings("ignore", message=r".*bitsandbytes was compiled without GPU support.*")
- warnings.filterwarnings(
- "ignore", message=r"MatMul8bitLt: inputs will be cast from .* to float16 during quantization"
- )
- import bitsandbytes as bnb
-
- if not nowelcome_set:
- del os.environ["BITSANDBYTES_NOWELCOME"]
-
- class _Linear8bitLt(bnb.nn.Linear8bitLt):
- """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and re-quantizaton when loading
- the state dict."""
-
- def __init__(self, *args: Any, device: Optional[_DEVICE] = None, threshold: float = 6.0, **kwargs: Any) -> None:
- super().__init__(*args, device=device, threshold=threshold, **kwargs)
- self.weight = cast(bnb.nn.Int8Params, self.weight) # type: ignore[has-type]
- self.bias: Optional[torch.nn.Parameter] = self.bias
- # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
- # filling the device memory with float32 weights which could lead to OOM
- if torch.tensor(0, device=device).device.type == "cuda":
- self.quantize_()
- self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
- self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
-
- def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
- """Inplace quantize."""
- if weight is None:
- weight = self.weight.data
- if weight.data.dtype == torch.int8:
- # already quantized
- return
- assert isinstance(self.weight, bnb.nn.Int8Params)
- self.weight = self.quantize(self.weight, weight, device)
-
- @staticmethod
- def quantize(
- int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: Optional[torch.device]
- ) -> bnb.nn.Int8Params:
- device = device or torch.device("cuda")
- if device.type != "cuda":
- raise RuntimeError(f"Unexpected device type: {device.type}")
- # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302
- B = weight.contiguous().to(device=device, dtype=torch.float16)
- if int8params.has_fp16_weights:
- int8params.data = B
- else:
- # bitsandbytes >= 0.45 supports an improved API
- if hasattr(bnb.functional, "int8_vectorwise_quant"):
- CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
- else: # old method is deprecated in 0.45, removed in 0.46+.
- CB, _, SCB, _, _ = bnb.functional.double_quant(B)
-
- int8params.data = CB
- setattr(int8params, "CB", CB)
- setattr(int8params, "SCB", SCB)
- return int8params
-
- def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
- if self.weight.device.type == "meta":
- # need custom logic if int8params is on meta device
- raise NotImplementedError
- if self.weight.dtype == torch.uint8: # was quantized
- # need the original shape here
- raise NotImplementedError
- device = torch.device(device)
- weight = torch.empty_like(self.weight.data, device=device)
- if device.type == "cuda": # re-quantize
- self.quantize_(weight, device)
- else:
- self.weight = _replace_param(self.weight, weight)
- if self.bias is not None:
- self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
- return self
-
- def reset_parameters(self) -> None:
- # from `torch.nn.Linear.reset_parameters`
- if self.bias is not None:
- fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
- bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
- init.uniform_(self.bias, -bound, bound)
-
- linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
- if linear_init_finished and self.weight.dtype == torch.uint8: # was quantized
- # need the original shape here
- raise NotImplementedError
- weight = self.weight.data
- torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
- if linear_init_finished:
- if self.weight.device.type == "meta":
- # need custom logic if int8params is on meta device
- raise NotImplementedError
- if self.weight.device.type == "cuda": # re-quantize
- self.quantize_(weight)
- else:
- self.weight = _replace_param(self.weight, weight)
-
- class _Linear4bit(bnb.nn.Linear4bit):
- """Wraps `bnb.nn.Linear4bit` to enable: instantiation directly on the device, re-quantizaton when loading the
- state dict, meta-device initialization, and materialization."""
-
- def __init__(self, *args: Any, device: Optional[_DEVICE] = None, **kwargs: Any) -> None:
- super().__init__(*args, device=device, **kwargs)
- self.weight = cast(bnb.nn.Params4bit, self.weight) # type: ignore[has-type]
- self.bias: Optional[torch.nn.Parameter] = self.bias
- # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
- # filling the device memory with float32 weights which could lead to OOM
- if torch.tensor(0, device=device).device.type == "cuda":
- self.quantize_()
- self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
- self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
-
- def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
- """Inplace quantize."""
- if weight is None:
- weight = self.weight.data
- if weight.data.dtype == torch.uint8:
- # already quantized
- return
- assert isinstance(self.weight, bnb.nn.Params4bit)
- self.weight = self.quantize(self.weight, weight, device)
- self.weight.bnb_quantized = True
-
- @staticmethod
- def quantize(
- params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: Optional[torch.device]
- ) -> bnb.nn.Params4bit:
- device = device or torch.device("cuda")
- if device.type != "cuda":
- raise RuntimeError(f"Unexpected device type: {device.type}")
- # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L156-L159
- w = weight.contiguous().to(device=device, dtype=torch.half)
- w_4bit, quant_state = bnb.functional.quantize_4bit(
- w,
- blocksize=params4bit.blocksize,
- compress_statistics=params4bit.compress_statistics,
- quant_type=params4bit.quant_type,
- quant_storage=params4bit.quant_storage,
- )
- return _replace_param(params4bit, w_4bit, quant_state)
-
- def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
- if self.weight.dtype == torch.uint8: # was quantized
- # cannot init the quantized params directly
- weight = torch.empty(self.weight.quant_state.shape, device=device, dtype=torch.half)
- else:
- weight = torch.empty_like(self.weight.data, device=device)
- device = torch.device(device)
- if device.type == "cuda": # re-quantize
- self.quantize_(weight, device)
- else:
- self.weight = _replace_param(self.weight, weight)
- if self.bias is not None:
- self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
- return self
-
- def reset_parameters(self) -> None:
- # from `torch.nn.Linear.reset_parameters`
- if self.bias is not None:
- fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
- bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
- init.uniform_(self.bias, -bound, bound)
-
- linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
- if linear_init_finished and self.weight.dtype == torch.uint8: # was quantized
- # cannot init the quantized params directly
- weight = torch.empty(self.weight.quant_state.shape, device=self.weight.device, dtype=torch.half)
- else:
- weight = self.weight.data
- torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
- if linear_init_finished:
- if self.weight.device.type == "cuda": # re-quantize
- self.quantize_(weight)
- else:
- self.weight = _replace_param(self.weight, weight)
-
- # Use a class instead `functools.partial` to respect `isinstance` checks and attribute accesses
- class _Int8LinearInference(_Linear8bitLt):
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- super().__init__(*args, has_fp16_weights=False, **kwargs)
-
- class _FP4Linear(_Linear4bit):
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- super().__init__(*args, quant_type="fp4", compress_statistics=False, **kwargs)
-
- class _FP4DQLinear(_Linear4bit):
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- super().__init__(*args, quant_type="fp4", compress_statistics=True, **kwargs)
-
- class _NF4Linear(_Linear4bit):
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- super().__init__(*args, quant_type="nf4", compress_statistics=False, **kwargs)
-
- class _NF4DQLinear(_Linear4bit):
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- super().__init__(*args, quant_type="nf4", compress_statistics=True, **kwargs)
-
- # these classes are defined programmatically like this to avoid importing bitsandbytes in environments that have
- # it available but will not use it
- classes = {
- "_Linear8bitLt": _Linear8bitLt,
- "_Linear4bit": _Linear4bit,
- "_Int8LinearInference": _Int8LinearInference,
- "_FP4Linear": _FP4Linear,
- "_FP4DQLinear": _FP4DQLinear,
- "_NF4Linear": _NF4Linear,
- "_NF4DQLinear": _NF4DQLinear,
- }
- globals().update(classes)
-
- return bnb
-
-
-def _convert_layers(module: torch.nn.Module, linear_cls: type, ignore_modules: set[str], prefix: str = "") -> None:
- for name, child in module.named_children():
- fullname = f"{prefix}.{name}" if prefix else name
- if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
- log.debug(f"Replacing layer {fullname!r} with bitsandbytes equivalent")
- has_bias = child.bias is not None
- # since we are going to copy over the child's data, the device doesn't matter. I chose CPU
- # to avoid spiking CUDA memory even though initialization is slower
- # 4bit layers support quantizing from meta-device params so this is only relevant for 8-bit
- _Linear4bit = globals()["_Linear4bit"]
- device = torch.device("meta" if issubclass(linear_cls, _Linear4bit) else "cpu")
- replacement = linear_cls(
- child.in_features,
- child.out_features,
- bias=has_bias,
- device=device,
- )
- if has_bias:
- replacement.bias = _replace_param(replacement.bias, child.bias.data.clone())
- state = {"quant_state": replacement.weight.quant_state if issubclass(linear_cls, _Linear4bit) else None}
- replacement.weight = _replace_param(replacement.weight, child.weight.data.clone(), **state)
- module.__setattr__(name, replacement)
- else:
- _convert_layers(child, linear_cls, ignore_modules, prefix=fullname)
+ return self.bitsandbytes_impl.convert_output(data)
diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py
index 526095008f376..837fb2dad3aac 100644
--- a/src/lightning/fabric/plugins/precision/deepspeed.py
+++ b/src/lightning/fabric/plugins/precision/deepspeed.py
@@ -11,17 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from contextlib import AbstractContextManager, nullcontext
+from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Any, Literal
import torch
-from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module
-from typing_extensions import get_args, override
+from typing_extensions import override
from lightning.fabric.plugins.precision.precision import Precision
-from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import Steppable
if TYPE_CHECKING:
@@ -44,51 +43,37 @@ class DeepSpeedPrecision(Precision):
"""
def __init__(self, precision: _PRECISION_INPUT) -> None:
- supported_precision = get_args(_PRECISION_INPUT)
- if precision not in supported_precision:
- raise ValueError(
- f"`precision={precision!r})` is not supported in DeepSpeed."
- f" `precision` must be one of: {supported_precision}."
- )
- self.precision = precision
-
- precision_to_type = {
- "bf16-mixed": torch.bfloat16,
- "16-mixed": torch.float16,
- "bf16-true": torch.bfloat16,
- "16-true": torch.float16,
- "32-true": torch.float32,
- }
- self._desired_dtype = precision_to_type[self.precision]
+ super().__init__()
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.precision.deepspeed import (
+ DeepSpeedPrecisionFabric as EnterpriseDeepSpeedPrecision,
+ )
+
+ self.deepspeed_impl = EnterpriseDeepSpeedPrecision(precision=precision)
@override
def convert_module(self, module: Module) -> Module:
- if "true" in self.precision:
- return module.to(dtype=self._desired_dtype)
- return module
+ return self.deepspeed_impl.convert_module(module)
@override
def tensor_init_context(self) -> AbstractContextManager:
- if "true" not in self.precision:
- return nullcontext()
- return _DtypeContextManager(self._desired_dtype)
+ return self.deepspeed_impl.tensor_init_context()
@override
def module_init_context(self) -> AbstractContextManager:
- return self.tensor_init_context()
+ return self.deepspeed_impl.module_init_context()
@override
def convert_input(self, data: Any) -> Any:
- return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
+ return self.deepspeed_impl.convert_input(data)
@override
def convert_output(self, data: Any) -> Any:
- return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
+ return self.deepspeed_impl.convert_output(data)
@override
def backward(self, tensor: Tensor, model: "DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
- """Performs back-propagation using DeepSpeed's engine."""
- model.backward(tensor, *args, **kwargs)
+ return self.deepspeed_impl.backward(tensor, model, *args, **kwargs)
@override
def optimizer_step(
@@ -96,5 +81,20 @@ def optimizer_step(
optimizer: Steppable,
**kwargs: Any,
) -> Any:
- # DeepSpeed handles the optimizer step internally
- return optimizer.step(**kwargs)
+ return self.deepspeed_impl.optimizer_step(optimizer, **kwargs)
+
+ @property
+ def precision(self) -> _PRECISION_INPUT:
+ return self.deepspeed_impl.precision
+
+ @precision.setter
+ def precision(self, precision: _PRECISION_INPUT) -> None:
+ self.deepspeed_impl.precision = precision
+
+ @property
+ def _desired_dtype(self) -> torch.dtype:
+ return self.deepspeed_impl._desired_dtype
+
+ @_desired_dtype.setter
+ def _desired_dtype(self, dtype: torch.dtype) -> None:
+ self.deepspeed_impl._desired_dtype = dtype
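The property/setter pairs above keep `precision` and `_desired_dtype` readable and writable on the facade, so code that previously touched these attributes directly still works. A short sketch, assuming the enterprise implementation keeps the same precision-to-dtype mapping as the code it replaces:

```python
# Illustrative only: attribute access is proxied to the enterprise object in `deepspeed_impl`.
import torch

from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision

plugin = DeepSpeedPrecision("16-mixed")
assert plugin.precision == "16-mixed"          # read through deepspeed_impl
assert plugin._desired_dtype is torch.float16  # assumes the same mapping as before
plugin.precision = "bf16-mixed"                # setter forwards to deepspeed_impl
```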
diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py
index bf1e51ea6b2b0..3d3d2491763f4 100644
--- a/src/lightning/fabric/plugins/precision/transformer_engine.py
+++ b/src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -11,31 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
from collections.abc import Mapping
-from contextlib import AbstractContextManager, ExitStack
+from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import torch
-from lightning_utilities import apply_to_collection
-from lightning_utilities.core.imports import RequirementCache
-from torch import Tensor
from typing_extensions import override
from lightning.fabric.plugins.precision.precision import Precision
-from lightning.fabric.plugins.precision.utils import (
- _ClassReplacementContextManager,
- _convert_fp_tensor,
- _DtypeContextManager,
-)
-from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
if TYPE_CHECKING:
from transformer_engine.common.recipe import DelayedScaling
-_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0")
-log = logging.getLogger(__name__)
-
class TransformerEnginePrecision(Precision):
"""Plugin for training with fp8 precision via nvidia's
@@ -72,111 +60,79 @@ def __init__(
replace_layers: Optional[bool] = None,
fallback_compute_dtype: Optional[torch.dtype] = None,
) -> None:
- if not _TRANSFORMER_ENGINE_AVAILABLE:
- raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
- from transformer_engine.common.recipe import DelayedScaling
-
- if recipe is None:
- recipe = DelayedScaling()
- elif isinstance(recipe, Mapping):
- recipe = dict(recipe) # copy
- if "fp8_format" in recipe:
- from transformer_engine.common.recipe import Format
-
- recipe["fp8_format"] = getattr(Format, recipe["fp8_format"])
- recipe = DelayedScaling(**recipe)
-
- self.weights_dtype = weights_dtype
- self.recipe = recipe
- self.replace_layers = replace_layers
- self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype
+ super().__init__()
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.precision.transformer_engine import (
+ TransformerEnginePrecision as EnterpriseTransformerEnginePrecision,
+ )
+
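+ # Recipe handling, layer replacement, and fp8 autocast contexts are provided by the enterprise plugin.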
+ self.transformer_engine_impl = EnterpriseTransformerEnginePrecision(
+ weights_dtype=weights_dtype,
+ recipe=recipe,
+ replace_layers=replace_layers,
+ fallback_compute_dtype=fallback_compute_dtype,
+ )
+
+ @property
+ def weights_dtype(self) -> torch.dtype:
+ return self.transformer_engine_impl.weights_dtype
+
+ @weights_dtype.setter
+ def weights_dtype(self, value: torch.dtype) -> None:
+ self.transformer_engine_impl.weights_dtype = value
+
+ @property
+ def recipe(self) -> Union[Mapping[str, Any], "DelayedScaling"]:
+ return self.transformer_engine_impl.recipe
+
+ @recipe.setter
+ def recipe(self, value: Union[Mapping[str, Any], "DelayedScaling"]) -> None:
+ self.transformer_engine_impl.recipe = value
+
+ @property
+ def replace_layers(self) -> bool:
+ return self.transformer_engine_impl.replace_layers
+
+ @replace_layers.setter
+ def replace_layers(self, value: bool) -> None:
+ self.transformer_engine_impl.replace_layers = value
+
+ @property
+ def fallback_compute_dtype(self) -> torch.dtype:
+ return self.transformer_engine_impl.fallback_compute_dtype
+
+ @fallback_compute_dtype.setter
+ def fallback_compute_dtype(self, value: torch.dtype) -> None:
+ self.transformer_engine_impl.fallback_compute_dtype = value
@override
def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
- # avoid converting if any is found. assume the user took care of it
- if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
- if self.replace_layers is True:
- # info level because this is expected with `init_module`
- rank_zero_info(
- "`TransformerEnginePrecision(replace_layers=True)` is set but the model already contains"
- " TransformerEngine layers. Skipping"
- )
- elif self.replace_layers in (None, True):
- _convert_layers(module)
- module = module.to(dtype=self.weights_dtype)
- return module
+ return self.transformer_engine_impl.convert_module(module)
@override
def tensor_init_context(self) -> AbstractContextManager:
- return _DtypeContextManager(self.weights_dtype)
+ return self.transformer_engine_impl.tensor_init_context()
@override
def module_init_context(self) -> AbstractContextManager:
- dtype_ctx = self.tensor_init_context()
- stack = ExitStack()
- if self.replace_layers:
- import transformer_engine.pytorch as te
-
- context_manager = _ClassReplacementContextManager({
- "torch.nn.Linear": te.Linear,
- "torch.nn.LayerNorm": te.LayerNorm,
- })
- stack.enter_context(context_manager)
- stack.enter_context(dtype_ctx)
- return stack
+ return self.transformer_engine_impl.module_init_context()
@override
def forward_context(self) -> AbstractContextManager:
- dtype_ctx = _DtypeContextManager(self.weights_dtype)
- fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
- import transformer_engine.pytorch as te
-
- autocast_ctx = te.fp8_autocast(enabled=True, fp8_recipe=self.recipe)
- stack = ExitStack()
- stack.enter_context(dtype_ctx)
- # enable an outer fallback autocast for operations that do not support fp8
- stack.enter_context(fallback_autocast_ctx)
- stack.enter_context(autocast_ctx)
- return stack
+ return self.transformer_engine_impl.forward_context()
@override
def convert_input(self, data: Any) -> Any:
- return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)
+ return self.transformer_engine_impl.convert_input(data)
@override
def convert_output(self, data: Any) -> Any:
- return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
-
-
-def _convert_layers(module: torch.nn.Module) -> None:
- import transformer_engine.pytorch as te
-
- for name, child in module.named_children():
- if isinstance(child, torch.nn.Linear):
- if child.in_features % 8 != 0 or child.out_features % 16 != 0:
- # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting
- rank_zero_warn(
- "Support for FP8 in the linear layers with this plugin is currently limited to"
- " tensors with shapes where the dimensions are divisible by 8 and 16 respectively."
- f" The layer {name!r} does not fit this criteria. You might want to add padding to your inputs."
- )
- continue
- has_bias = child.bias is not None
- replacement = te.Linear(child.in_features, child.out_features, bias=has_bias)
- replacement.weight.data = child.weight.data.clone()
- if has_bias:
- replacement.bias.data = child.bias.data.clone()
- log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
- module.__setattr__(name, replacement)
- elif isinstance(child, torch.nn.LayerNorm):
- replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
- replacement.weight.data = child.weight.data.clone()
- # Check if bias exists before attempting to clone its data
- if child.bias is not None and replacement.bias is not None:
- replacement.bias.data = child.bias.data.clone()
- log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
- module.__setattr__(name, replacement)
- else:
- # there are other transformer engine layers that we could convert but require fusion. full list at:
- # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html
- _convert_layers(child)
+ return self.transformer_engine_impl.convert_output(data)
+
+ @property
+ def _desired_dtype(self) -> torch.dtype:
+ return self.transformer_engine_impl._desired_dtype
+
+ @_desired_dtype.setter
+ def _desired_dtype(self, dtype: torch.dtype) -> None:
+ self.transformer_engine_impl._desired_dtype = dtype
diff --git a/src/lightning/fabric/plugins/precision/xla.py b/src/lightning/fabric/plugins/precision/xla.py
index fdb30032b3cdd..4055d87e14af7 100644
--- a/src/lightning/fabric/plugins/precision/xla.py
+++ b/src/lightning/fabric/plugins/precision/xla.py
@@ -11,14 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
from typing import Any, Literal
import torch
-from typing_extensions import get_args, override
+from typing_extensions import override
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins.precision.precision import Precision
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import Optimizable
_PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true"]
@@ -37,24 +36,11 @@ class XLAPrecision(Precision):
"""
def __init__(self, precision: _PRECISION_INPUT) -> None:
- if not _XLA_AVAILABLE:
- raise ModuleNotFoundError(str(_XLA_AVAILABLE))
- supported_precision = get_args(_PRECISION_INPUT)
- if precision not in supported_precision:
- raise ValueError(
- f"`precision={precision!r})` is not supported in XLA."
- f" `precision` must be one of: {supported_precision}."
- )
- self.precision = precision
-
- if precision == "16-true":
- os.environ["XLA_USE_F16"] = "1"
- self._desired_dtype = torch.float16
- elif precision == "bf16-true":
- os.environ["XLA_USE_BF16"] = "1"
- self._desired_dtype = torch.bfloat16
- else:
- self._desired_dtype = torch.float32
+ super().__init__()
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.precision.xla import XLAPrecision as EnterpriseXLAPrecision
+
+ self.xla_impl = EnterpriseXLAPrecision(precision=precision)
@override
def optimizer_step(
@@ -62,12 +48,24 @@ def optimizer_step(
optimizer: Optimizable,
**kwargs: Any,
) -> Any:
- import torch_xla.core.xla_model as xm
-
- # you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True`
- return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True)
+ return self.xla_impl.optimizer_step(optimizer, **kwargs)
@override
def teardown(self) -> None:
- os.environ.pop("XLA_USE_BF16", None)
- os.environ.pop("XLA_USE_F16", None)
+ return self.xla_impl.teardown()
+
+ @property
+ def _desired_dtype(self) -> torch.dtype:
+ return self.xla_impl._desired_dtype
+
+ @_desired_dtype.setter
+ def _desired_dtype(self, dtype: torch.dtype) -> None:
+ self.xla_impl._desired_dtype = dtype
+
+ @property
+ def precision(self) -> _PRECISION_INPUT:
+ return self.xla_impl.precision
+
+ @precision.setter
+ def precision(self, precision: _PRECISION_INPUT) -> None:
+ self.xla_impl.precision = precision
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 883546fea1f2d..a477a2b940afc 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -11,45 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import argparse
-import json
import logging
-import os
-import platform
-from collections.abc import Mapping
-from contextlib import AbstractContextManager, ExitStack
+from contextlib import AbstractContextManager
from datetime import timedelta
-from itertools import chain
-from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
-from lightning_utilities.core.imports import RequirementCache
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import override
-from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
+from lightning.fabric.accelerators import Accelerator
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies.ddp import DDPStrategy
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.strategies.strategy import _Sharded
-from lightning.fabric.utilities.distributed import log
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
-from lightning.fabric.utilities.load import _move_state_into
-from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
-from lightning.fabric.utilities.seed import reset_seed
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import _PATH
if TYPE_CHECKING:
from deepspeed import DeepSpeedEngine
from torch.optim.lr_scheduler import _LRScheduler
-_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
-_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")
-
# TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced.
class DeepSpeedStrategy(DDPStrategy, _Sharded):
@@ -235,24 +220,11 @@ def __init__(
exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints.
"""
- if not _DEEPSPEED_AVAILABLE:
- raise ImportError(
- "To use the `DeepSpeedStrategy`, you must have DeepSpeed installed."
- " Install it by running `pip install -U deepspeed`."
- )
-
- if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
- # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
- # DeepSpeed added support for this behavior in version 0.16.0.
- import deepspeed
-
- deepspeed_version = deepspeed.__version__
- raise ImportError(
- f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
- f"Detected DeepSpeed version: {deepspeed_version}. "
- "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
- )
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.strategies.deepspeed import (
+ DeepSpeedStrategyFabric as EnterpriseDeepSpeedStrategy,
+ )
super().__init__(
accelerator=accelerator,
@@ -261,77 +233,68 @@ def __init__(
precision=precision,
process_group_backend=process_group_backend,
)
- self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally
- self._timeout: Optional[timedelta] = timeout
-
- self.config = self._load_config(config)
- if self.config is None:
- # User has not overridden config, set defaults
- self.config = self._create_default_config(
- zero_optimization,
- zero_allow_untested_optimizer,
- logging_batch_size_per_gpu,
- offload_optimizer=offload_optimizer,
- offload_parameters=offload_parameters,
- nvme_path=nvme_path,
- offload_params_device=offload_params_device,
- params_buffer_count=params_buffer_count,
- params_buffer_size=params_buffer_size,
- max_in_cpu=max_in_cpu,
- pin_memory=pin_memory,
- offload_optimizer_device=offload_optimizer_device,
- optimizer_buffer_count=optimizer_buffer_count,
- block_size=block_size,
- queue_depth=queue_depth,
- single_submit=single_submit,
- overlap_events=overlap_events,
- thread_count=thread_count,
- partition_activations=partition_activations,
- cpu_checkpointing=cpu_checkpointing,
- contiguous_memory_optimization=contiguous_memory_optimization,
- synchronize_checkpoint_boundary=synchronize_checkpoint_boundary,
- stage=stage,
- contiguous_gradients=contiguous_gradients,
- overlap_comm=overlap_comm,
- allgather_partitions=allgather_partitions,
- reduce_scatter=reduce_scatter,
- allgather_bucket_size=allgather_bucket_size,
- reduce_bucket_size=reduce_bucket_size,
- sub_group_size=sub_group_size,
- )
-
- import deepspeed
-
- self._config_initialized = False
- deepspeed.utils.logging.logger.setLevel(logging_level)
-
- self.remote_device = remote_device
- self.load_full_weights = load_full_weights
- self.exclude_frozen_parameters = exclude_frozen_parameters
-
- # default FP16 parameters.
- self.loss_scale = loss_scale
- self.initial_scale_power = initial_scale_power
- self.loss_scale_window = loss_scale_window
- self.hysteresis = hysteresis
- self.min_loss_scale = min_loss_scale
-
- self._deepspeed_engine: Optional[DeepSpeedEngine] = None
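+ # Config assembly, engine initialization, and checkpoint I/O are delegated to the enterprise strategy.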
+ self.deepspeed_impl = EnterpriseDeepSpeedStrategy(
+ outer_object=self,
+ accelerator=accelerator,
+ zero_optimization=zero_optimization,
+ stage=stage,
+ remote_device=remote_device,
+ offload_optimizer=offload_optimizer,
+ offload_parameters=offload_parameters,
+ offload_params_device=offload_params_device,
+ nvme_path=nvme_path,
+ params_buffer_count=params_buffer_count,
+ params_buffer_size=params_buffer_size,
+ max_in_cpu=max_in_cpu,
+ offload_optimizer_device=offload_optimizer_device,
+ optimizer_buffer_count=optimizer_buffer_count,
+ block_size=block_size,
+ queue_depth=queue_depth,
+ single_submit=single_submit,
+ overlap_events=overlap_events,
+ thread_count=thread_count,
+ pin_memory=pin_memory,
+ sub_group_size=sub_group_size,
+ contiguous_gradients=contiguous_gradients,
+ overlap_comm=overlap_comm,
+ allgather_partitions=allgather_partitions,
+ reduce_scatter=reduce_scatter,
+ allgather_bucket_size=allgather_bucket_size,
+ reduce_bucket_size=reduce_bucket_size,
+ zero_allow_untested_optimizer=zero_allow_untested_optimizer,
+ logging_batch_size_per_gpu=logging_batch_size_per_gpu,
+ config=config,
+ logging_level=logging_level,
+ parallel_devices=parallel_devices,
+ cluster_environment=cluster_environment,
+ loss_scale=loss_scale,
+ initial_scale_power=initial_scale_power,
+ loss_scale_window=loss_scale_window,
+ hysteresis=hysteresis,
+ min_loss_scale=min_loss_scale,
+ partition_activations=partition_activations,
+ cpu_checkpointing=cpu_checkpointing,
+ contiguous_memory_optimization=contiguous_memory_optimization,
+ synchronize_checkpoint_boundary=synchronize_checkpoint_boundary,
+ load_full_weights=load_full_weights,
+ precision=precision,
+ process_group_backend=process_group_backend,
+ timeout=timeout,
+ exclude_frozen_parameters=exclude_frozen_parameters,
+ )
@property
def zero_stage_3(self) -> bool:
- assert isinstance(self.config, dict)
- zero_optimization = self.config.get("zero_optimization")
- return zero_optimization is not None and zero_optimization.get("stage") == 3
+ return self.deepspeed_impl.zero_stage_3
@property
@override
def distributed_sampler_kwargs(self) -> dict[str, int]:
- return {"num_replicas": self.world_size, "rank": self.global_rank}
+ return self.deepspeed_impl.distributed_sampler_kwargs
@property
def model(self) -> "DeepSpeedEngine":
- return self._deepspeed_engine
+ return self.deepspeed_impl.model
@override
def setup_module_and_optimizers(
@@ -345,14 +308,9 @@ def setup_module_and_optimizers(
deepspeed optimizer, and an optional learning rate scheduler.
"""
- if len(optimizers) != 1:
- raise ValueError(
- f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead."
- )
-
- self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler)
- self._set_deepspeed_activation_checkpointing()
- return self._deepspeed_engine, [optimizer], scheduler
+ return self.deepspeed_impl.setup_module_and_optimizers(
+ module=module, optimizers=optimizers, scheduler=scheduler
+ )
@override
def setup_module(self, module: Module) -> "DeepSpeedEngine":
@@ -361,8 +319,7 @@ def setup_module(self, module: Module) -> "DeepSpeedEngine":
For training, see :meth:`setup_module_and_optimizers`.
"""
- self._deepspeed_engine, _, _ = self._initialize_engine(module)
- return self._deepspeed_engine
+ return self.deepspeed_impl.setup_module(module=module)
@override
def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
@@ -371,34 +328,15 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer together.
"""
- raise NotImplementedError(self._err_msg_joint_setup_required())
+ return self.deepspeed_impl.setup_optimizer(optimizer=optimizer)
@override
def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
- if self.zero_stage_3 and empty_init is False:
- raise NotImplementedError(
- f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled."
- )
- module_sharded_ctx = self.module_sharded_context()
- stack = ExitStack()
- if not self.zero_stage_3:
- stack.enter_context(super().module_init_context(empty_init=empty_init))
- stack.enter_context(module_sharded_ctx)
- return stack
+ return self.deepspeed_impl.module_init_context(empty_init=empty_init)
@override
def module_sharded_context(self) -> AbstractContextManager:
- # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context
- # manager. Later modifications through e.g. `Fabric.setup()` won't have an effect here.
-
- import deepspeed
-
- assert self._config_initialized
- return deepspeed.zero.Init(
- enabled=self.zero_stage_3,
- remote_device=self.remote_device,
- config_dict_or_path=self.config,
- )
+ return self.deepspeed_impl.module_sharded_context()
@override
def save_checkpoint(
@@ -425,46 +363,8 @@ def save_checkpoint(
:class:`deepspeed.DeepSpeedEngine` objects were found.
"""
- if storage_options is not None:
- raise TypeError(
- "`DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not supported because"
- " `DeepSpeedStrategy` does not use the `CheckpointIO`."
- )
- if filter is not None:
- raise TypeError(
- "`DeepSpeedStrategy.save_checkpoint(..., filter=...)` is not supported because"
- " `DeepSpeedStrategy` manages the state serialization internally."
- )
-
- engines = _get_deepspeed_engines_from_state(state)
- if len(engines) == 0:
- raise ValueError(
- "Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
- " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
- " you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
- )
- if len(engines) > 1:
- raise ValueError(
- "Found multiple DeepSpeed engine modules in the given state. Saving checkpoints with DeepSpeed is"
- " currently limited to a single model per checkpoint. To save multiple models, call the"
- " save method for each model separately with a different path."
- )
- engine = engines[0]
-
- # broadcast the path from rank 0 to ensure all the states are saved in a common path
- path = self.broadcast(path)
-
- # split the checkpoint into two parts:
- # 1) the deepspeed engine encapsulating both the model and optionally the optimizer(s)
- # 2) the rest of the user's state, which in deepspeed is called `client state`
- excluded_objects = (engine, engine.optimizer) if engine.optimizer is not None else (engine,)
- state = {k: v for k, v in state.items() if v not in excluded_objects}
- _validate_state_keys(state)
- # there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
- state = self._convert_stateful_objects_in_state(state, filter={})
- # use deepspeed's internal checkpointing function to handle partitioned weights across processes
- engine.save_checkpoint(
- path, client_state=state, tag="checkpoint", exclude_frozen_parameters=self.exclude_frozen_parameters
+ return self.deepspeed_impl.save_checkpoint(
+ path=path, state=state, storage_options=storage_options, filter=filter
)
@override
@@ -495,59 +395,7 @@ def load_checkpoint(
not in the expected DeepSpeed format.
"""
- if isinstance(state, (Module, Optimizer)) or self.load_full_weights and self.zero_stage_3:
- # This code path to enables loading a checkpoint from a non-deepspeed checkpoint or from
- # a consolidated checkpoint
- path = self.broadcast(path)
- return super().load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only)
-
- if not state:
- raise ValueError(
- f"Got DeepSpeedStrategy.load_checkpoint(..., state={state!r}) but a state with at least "
- f" a model instance to reload is required. Pass it in like so:"
- " DeepSpeedStrategy.load_checkpoint(..., state={'model': model, ...})"
- )
- _validate_checkpoint_directory(path)
-
- engines = _get_deepspeed_engines_from_state(state)
- if len(engines) == 0:
- raise ValueError(
- "Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
- " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
- " you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
- )
- if len(engines) > 1:
- raise ValueError(
- "Found multiple DeepSpeed engine modules in the given state. Saving and loading checkpoints"
- " with DeepSpeed is currently limited to a single model per checkpoint. To load multiple model"
- " states, call the load method for each model checkpoint separately."
- )
- engine = engines[0]
-
- from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
-
- optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())
-
- torch.cuda.empty_cache()
- _, client_state = engine.load_checkpoint(
- path,
- tag="checkpoint",
- load_optimizer_states=optimzer_state_requested,
- load_lr_scheduler_states=False,
- load_module_strict=strict,
- )
-
- if client_state is None:
- raise RuntimeError(
- "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
- " or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`."
- )
-
- # `Engine.load_checkpoint` adds useless keys 'optimizer' and 'lr_scheduler' to the client state; remove
- # them to avoid name collision with user state
- keys = set(client_state) & set(state) - {"optimizer", "lr_scheduler"}
- _move_state_into(source=client_state, destination=state, keys=keys)
- return client_state
+ return self.deepspeed_impl.load_checkpoint(path=path, state=state, strict=strict)
@override
def clip_gradients_norm(
@@ -558,19 +406,19 @@ def clip_gradients_norm(
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = True,
) -> torch.Tensor:
- raise NotImplementedError(
- "DeepSpeed handles gradient clipping automatically within the optimizer. "
- "Make sure to set the `gradient_clipping` value in your Config."
+ return self.deepspeed_impl.clip_gradients_norm(
+ module=module,
+ optimizer=optimizer,
+ max_norm=max_norm,
+ norm_type=norm_type,
+ error_if_nonfinite=error_if_nonfinite,
)
@override
def clip_gradients_value(
self, module: "DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int]
) -> None:
- raise NotImplementedError(
- "DeepSpeed handles gradient clipping automatically within the optimizer. "
- "Make sure to set the `gradient_clipping` value in your Config."
- )
+ return self.deepspeed_impl.clip_gradients_value(module=module, optimizer=optimizer, clip_val=clip_val)
@classmethod
@override
@@ -613,338 +461,18 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
offload_optimizer_device="nvme",
)
- def _initialize_engine(
- self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional["_LRScheduler"] = None
- ) -> tuple["DeepSpeedEngine", Optimizer, Any]:
- """Initialize one model and one optimizer with an optional learning rate scheduler.
-
- This calls ``deepspeed.initialize`` internally.
-
- """
- import deepspeed
-
- model_parameters = filter(lambda p: p.requires_grad, model.parameters())
- deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
- args=argparse.Namespace(device_rank=self.root_device.index),
- config=self.config,
- model=model,
- model_parameters=model_parameters,
- optimizer=optimizer,
- lr_scheduler=scheduler,
- dist_init_required=False,
- )
- return deepspeed_engine, deepspeed_optimizer, deepspeed_scheduler
-
@override
def setup_environment(self) -> None:
- if not isinstance(self.accelerator, CUDAAccelerator):
- raise RuntimeError(
- f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
- " is used."
- )
- super().setup_environment()
+ return self.deepspeed_impl.setup_environment()
@override
def _setup_distributed(self) -> None:
- assert self.parallel_devices is not None
- _validate_device_index_selection(self.parallel_devices)
- reset_seed()
- self._set_world_ranks()
- self._init_deepspeed_distributed()
- if not self._config_initialized:
- self._format_config()
- self._config_initialized = True
-
- def _init_deepspeed_distributed(self) -> None:
- import deepspeed
-
- assert self.cluster_environment is not None
- if platform.system() != "Windows":
- # do not set env variables on windows, allow deepspeed to control setup
- self._set_node_environment_variables()
- log.info(
- "initializing deepspeed distributed: "
- f"GLOBAL_RANK: {self.global_rank}, "
- f"MEMBER: {self.global_rank + 1}/{self.world_size}"
- )
- self._process_group_backend = self._get_process_group_backend()
- deepspeed.init_distributed(
- self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
- )
-
- def _set_node_environment_variables(self) -> None:
- assert self.cluster_environment is not None
- os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
- os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
- os.environ["RANK"] = str(self.global_rank)
- os.environ["WORLD_SIZE"] = str(self.world_size)
- os.environ["LOCAL_RANK"] = str(self.local_rank)
-
- def _set_deepspeed_activation_checkpointing(self) -> None:
- import deepspeed
-
- assert isinstance(self.config, dict)
- if self.config.get("activation_checkpointing"):
- checkpoint_config = self.config["activation_checkpointing"]
- deepspeed.checkpointing.configure(
- mpu_=None,
- partition_activations=checkpoint_config.get("partition_activations"),
- contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
- checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
- profile=checkpoint_config.get("profile"),
- )
-
- def _format_config(self) -> None:
- if self.config is None:
- raise ValueError(
- "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
- " See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed"
- )
-
- self.config.setdefault("train_micro_batch_size_per_gpu", 1)
- _format_precision_config(
- config=self.config,
- precision=self.precision.precision,
- loss_scale=self.loss_scale,
- loss_scale_window=self.loss_scale_window,
- min_loss_scale=self.min_loss_scale,
- initial_scale_power=self.initial_scale_power,
- hysteresis=self.hysteresis,
- )
-
- def _create_default_config(
- self,
- zero_optimization: bool,
- zero_allow_untested_optimizer: bool,
- logging_batch_size_per_gpu: Optional[int],
- partition_activations: bool,
- cpu_checkpointing: bool,
- contiguous_memory_optimization: bool,
- synchronize_checkpoint_boundary: bool,
- offload_optimizer: bool,
- offload_parameters: bool,
- nvme_path: str,
- offload_params_device: str,
- params_buffer_count: int,
- params_buffer_size: int,
- max_in_cpu: int,
- offload_optimizer_device: str,
- optimizer_buffer_count: int,
- pin_memory: bool,
- block_size: int,
- queue_depth: int,
- single_submit: bool,
- overlap_events: bool,
- thread_count: int,
- **zero_kwargs: Any,
- ) -> dict:
- cfg = {
- "activation_checkpointing": {
- "partition_activations": partition_activations,
- "cpu_checkpointing": cpu_checkpointing,
- "contiguous_memory_optimization": contiguous_memory_optimization,
- "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
- },
- "aio": {
- "block_size": block_size,
- "queue_depth": queue_depth,
- "single_submit": single_submit,
- "overlap_events": overlap_events,
- "thread_count": thread_count,
- },
- }
- if zero_optimization:
- zero_config = zero_kwargs
-
- if offload_optimizer:
- zero_config["offload_optimizer"] = {
- "device": offload_optimizer_device,
- "nvme_path": nvme_path,
- "buffer_count": optimizer_buffer_count,
- "pin_memory": pin_memory,
- }
- if offload_parameters:
- zero_config["offload_param"] = {
- "device": offload_params_device,
- "nvme_path": nvme_path,
- "buffer_count": params_buffer_count,
- "buffer_size": params_buffer_size,
- "max_in_cpu": max_in_cpu,
- "pin_memory": pin_memory,
- }
- cfg.update({
- "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
- "zero_optimization": zero_config,
- })
- if logging_batch_size_per_gpu:
- cfg["train_micro_batch_size_per_gpu"] = logging_batch_size_per_gpu
- return cfg
-
- def _restore_zero_state(self, module: Module, ckpt: Mapping[str, Any]) -> None:
- """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be sharded
- across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced
- across processes.
-
- Args:
- ckpt: The ckpt file.
-
- """
- import deepspeed
-
- def load(module: torch.nn.Module, prefix: str = "") -> None:
- missing_keys: list[str] = []
- unexpected_keys: list[str] = []
- error_msgs: list[str] = []
- state_dict = ckpt["state_dict"]
-
- # copy state_dict so _load_from_state_dict can modify it
- metadata = getattr(state_dict, "_metadata", None)
- state_dict = state_dict.copy()
- if metadata is not None:
- state_dict._metadata = metadata
-
- local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
- # because zero3 puts placeholders in model params, this context
- # manager gathers (unpartitions) the params of the current layer, then loads from
- # the state dict and then re-partitions them again
- with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
- if self.is_global_zero:
- module._load_from_state_dict(
- state_dict=state_dict,
- prefix=prefix,
- local_metadata=local_metadata,
- strict=True,
- missing_keys=missing_keys,
- unexpected_keys=unexpected_keys,
- error_msgs=error_msgs,
- )
-
- for name, child in module._modules.items():
- if child is not None:
- load(child, prefix + name + ".")
-
- load(module, prefix="")
-
- def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]:
- if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
- rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
- config = os.environ[self.DEEPSPEED_ENV_VAR]
- if isinstance(config, (str, Path)):
- if not os.path.isfile(config):
- raise FileNotFoundError(
- f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
- )
- with open(config) as f:
- config = json.load(f)
- assert isinstance(config, dict) or config is None
- return config
-
-
-def _get_deepspeed_engines_from_state(state: dict[str, Any]) -> list["DeepSpeedEngine"]:
- from deepspeed import DeepSpeedEngine
-
- modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module)))
- return [engine for engine in modules if isinstance(engine, DeepSpeedEngine)]
-
-
-def _validate_state_keys(state: dict[str, Any]) -> None:
- # DeepSpeed merges the client state into its internal engine state when saving, but it does not check for
- # colliding keys from the user. We explicitly check it here:
- deepspeed_internal_keys = {
- "module",
- "buffer_names",
- "optimizer",
- "param_shapes",
- "lr_scheduler",
- "sparse_tensor_module_names",
- "skipped_steps",
- "global_steps",
- "global_samples",
- "dp_world_size",
- "mp_world_size",
- "ds_config",
- "ds_version",
- }
- colliding_keys = deepspeed_internal_keys.intersection(state.keys())
- if colliding_keys:
- rank_zero_warn(
- "Your state has keys that collide with DeepSpeed's internal engine state. This could result in your"
- " values being overwritten by DeepSpeed. Consider changing the name of these keys to something else: "
- + ", ".join(colliding_keys)
- )
-
-
-def _validate_device_index_selection(parallel_devices: list[torch.device]) -> None:
- selected_device_indices = [device.index for device in parallel_devices]
- expected_device_indices = list(range(len(parallel_devices)))
- if selected_device_indices != expected_device_indices:
- raise RuntimeError(
- f"The selected device indices {selected_device_indices!r} don't match the local rank values of processes."
- " If you need to select GPUs at a specific index, set the `CUDA_VISIBLE_DEVICES` environment variable"
- f" instead. For example: `CUDA_VISIBLE_DEVICES={','.join(str(i) for i in selected_device_indices)}`."
- )
+ return self.deepspeed_impl._setup_distributed()
+ @property
+ def config(self) -> dict[str, Any]:
+ return self.deepspeed_impl.config
-def _is_deepspeed_checkpoint(path: Path) -> bool:
- """Heuristic check whether the path points to a top-level DeepSpeed checkpoint directory."""
- return path.is_dir() and (path / "checkpoint").is_dir()
-
-
-def _validate_checkpoint_directory(path: _PATH) -> None:
- """Validates that the path points to a DeepSpeed checkpoint directory and suggests fixes for user error."""
- # Example DeepSpeed checkpoint directory:
- #
- # epoch=5-step=10999.ckpt
- # ├── checkpoint
- # │ ├── zero_pp_rank_0_mp_rank_00_model_states.pt
- # │ ├── zero_pp_rank_0_mp_rank_00_optim_states.pt
- # │ ├── zero_pp_rank_1_mp_rank_00_model_states.pt
- # │ └── zero_pp_rank_1_mp_rank_00_optim_states.pt
- # ├── latest
- # └── zero_to_fp32.py
-
- path = Path(path)
- path_is_ds_checkpoint = _is_deepspeed_checkpoint(path)
- default_message = f"The provided path is not a valid DeepSpeed checkpoint: {path}"
-
- if not path_is_ds_checkpoint:
- # Case 1: User may have accidentally passed the subfolder "checkpoint"
- parent_is_ds_checkpoint = _is_deepspeed_checkpoint(path.parent)
- if parent_is_ds_checkpoint:
- raise FileNotFoundError(
- f"{default_message}. It looks like you passed the path to a subfolder."
- f" Try to load using this parent directory instead: {path.parent}"
- )
- # Case 2: User may have accidentally passed the path to a file inside the "checkpoint" subfolder
- parent_parent_is_ds_checkpoint = path.is_file() and _is_deepspeed_checkpoint(path.parent.parent)
- if parent_parent_is_ds_checkpoint:
- raise FileNotFoundError(
- f"{default_message}. It looks like you passed the path to a file inside a DeepSpeed checkpoint folder."
- f" Try to load using this parent directory instead: {path.parent.parent}"
- )
- raise FileNotFoundError(default_message)
-
-
-def _format_precision_config(
- config: dict[str, Any],
- precision: str,
- loss_scale: float,
- loss_scale_window: int,
- min_loss_scale: int,
- initial_scale_power: int,
- hysteresis: int,
-) -> None:
- if "fp16" not in config and precision in ("16-mixed", "16-true"):
- # FP16 is a DeepSpeed standalone AMP implementation
- rank_zero_info("Enabling DeepSpeed FP16. Model parameters and inputs will be cast to `float16`.")
- config["fp16"] = {
- "enabled": True,
- "loss_scale": loss_scale,
- "initial_scale_power": initial_scale_power,
- "loss_scale_window": loss_scale_window,
- "hysteresis": hysteresis,
- "min_loss_scale": min_loss_scale,
- }
- elif "bf16" not in config and precision in ("bf16-mixed", "bf16-true"):
- rank_zero_info("Enabling DeepSpeed BF16. Model parameters and inputs will be cast to `bfloat16`.")
- config["bf16"] = {"enabled": True}
+ @config.setter
+ def config(self, config: dict[str, Any]) -> None:
+ self.deepspeed_impl.config = config
diff --git a/src/lightning/fabric/strategies/launchers/xla.py b/src/lightning/fabric/strategies/launchers/xla.py
index 639de55805646..566312754339b 100644
--- a/src/lightning/fabric/strategies/launchers/xla.py
+++ b/src/lightning/fabric/strategies/launchers/xla.py
@@ -11,17 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import queue
-import time
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Union
-import torch.multiprocessing as mp
from typing_extensions import override
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.strategies.launchers.launcher import _Launcher
-from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot
-from lightning.fabric.utilities.apply_func import move_data_to_device
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
if TYPE_CHECKING:
from lightning.fabric.strategies import XLAFSDPStrategy, XLAStrategy
@@ -45,15 +40,16 @@ class _XLALauncher(_Launcher):
"""
def __init__(self, strategy: Union["XLAStrategy", "XLAFSDPStrategy"]) -> None:
- if not _XLA_AVAILABLE:
- raise ModuleNotFoundError(str(_XLA_AVAILABLE))
- self._strategy = strategy
- self._start_method = "fork"
+ super().__init__()
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.strategies.xla.launcher import XLALauncherFabric as EnterpriseXLALauncher
+
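+ # Process spawning and result collection are handled by the enterprise launcher implementation.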
+ self.xla_impl = EnterpriseXLALauncher(strategy=strategy)
@property
@override
def is_interactive_compatible(self) -> bool:
- return True
+ return self.xla_impl.is_interactive_compatible
@override
def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
@@ -68,61 +64,12 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
**kwargs: Optional keyword arguments to be passed to the given function.
"""
- return_queue: Union[queue.Queue, mp.SimpleQueue]
- return_queue = mp.Manager().Queue()
-
- import torch_xla.distributed.xla_multiprocessing as xmp
-
- spawn_kwargs = {}
- nprocs = self._strategy.num_processes
- if nprocs == 1:
- # avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly.
- # otherwise it will use all devices
- spawn_kwargs["nprocs"] = nprocs
-
- xmp.spawn(
- self._wrapping_function,
- args=(function, args, kwargs, return_queue),
- start_method=self._start_method,
- **spawn_kwargs,
- )
- return return_queue.get()
-
- def _wrapping_function(
- self,
- # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing
- # https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/distributed/xla_multiprocessing.py#L321
- process_idx: int,
- function: Callable,
- args: Any,
- kwargs: Any,
- return_queue: Union[mp.SimpleQueue, queue.Queue],
- global_states: Optional[_GlobalStateSnapshot] = None,
- ) -> None:
- import torch_xla.core.xla_model as xm
-
- if len(xm.get_xla_supported_devices()) > 1:
- # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for v4)
- # so when there's more than one (multithreading), objects need to be deep-copied
- import copy
-
- function, args, kwargs = copy.deepcopy((function, args, kwargs))
-
- results = function(*args, **kwargs)
-
- if self._strategy.local_rank == 0:
- return_queue.put(move_data_to_device(results, "cpu"))
-
- _rank_teardown(self._strategy.local_rank)
-
-
-def _rank_teardown(rank: int) -> None:
- import torch_xla.core.xla_model as xm
-
- # Make all processes wait for each other before joining
- # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
- xm.rendezvous("end-process")
- # Ensure that the rank 0 process is the one exiting last
- # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
- if rank == 0:
- time.sleep(1)
+ return self.xla_impl.launch(function, *args, **kwargs)
+
+ @property
+ def _start_method(self) -> str:
+ return self.xla_impl._start_method
+
+ @_start_method.setter
+ def _start_method(self, start_method: str) -> None:
+ self.xla_impl._start_method = start_method
diff --git a/src/lightning/fabric/strategies/single_xla.py b/src/lightning/fabric/strategies/single_xla.py
index ba2fce91f1146..59b38ff6157f6 100644
--- a/src/lightning/fabric/strategies/single_xla.py
+++ b/src/lightning/fabric/strategies/single_xla.py
@@ -13,15 +13,14 @@
# limitations under the License.
from typing import Optional
-import torch
from typing_extensions import override
from lightning.fabric.accelerators import Accelerator
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.single_device import SingleDeviceStrategy
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import _DEVICE
@@ -35,20 +34,16 @@ def __init__(
checkpoint_io: Optional[XLACheckpointIO] = None,
precision: Optional[XLAPrecision] = None,
):
- if not _XLA_AVAILABLE:
- raise ModuleNotFoundError(str(_XLA_AVAILABLE))
- if isinstance(device, torch.device):
- # unwrap the `torch.device` in favor of `xla_device`
- device = device.index
-
- import torch_xla.core.xla_model as xm
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.strategies.xla.single import validate_xla_strategy
super().__init__(
accelerator=accelerator,
- device=xm.xla_device(device),
+ device=device,
checkpoint_io=checkpoint_io,
precision=precision,
)
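+ # Device validation (previously done inline via xm.xla_device) is handled by the enterprise helper below.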
+ validate_xla_strategy(strategy=self, device=device)
@property
@override
diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py
index 3a571fef37f00..5a877a5c0dcc8 100644
--- a/src/lightning/fabric/strategies/xla.py
+++ b/src/lightning/fabric/strategies/xla.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import io
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
@@ -22,14 +21,13 @@
from typing_extensions import override
from lightning.fabric.accelerators import Accelerator
-from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision
from lightning.fabric.plugins.environments import XLAEnvironment
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
from lightning.fabric.strategies.launchers.xla import _XLALauncher
from lightning.fabric.strategies.strategy import TBroadcast
-from lightning.fabric.utilities.rank_zero import rank_zero_only
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import _PATH, ReduceOp
if TYPE_CHECKING:
@@ -55,22 +53,19 @@ def __init__(
checkpoint_io=checkpoint_io,
precision=precision,
)
- self._backward_sync_control = None # XLA synchronizes gradients in the optimizer.step() call
- self._launched = False
- self._sync_module_states = sync_module_states
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.strategies.xla.ddp import XLAStrategyFabric as EnterpriseXLAStrategy
+
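+ # Rank bookkeeping, collectives, and checkpoint saving are delegated to the enterprise XLA strategy.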
+ self.xla_strategy_impl = EnterpriseXLAStrategy(outer_object=self, sync_module_states=sync_module_states)
@property
@override
def root_device(self) -> torch.device:
- if not self._launched:
- raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
- import torch_xla.core.xla_model as xm
-
- return xm.xla_device()
+ return self.xla_strategy_impl.root_device
@property
def num_processes(self) -> int:
- return len(self.parallel_devices) if self.parallel_devices is not None else 0
+ return self.xla_strategy_impl.num_processes
@property
@override
@@ -107,22 +102,22 @@ def precision(self, precision: Optional[Precision]) -> None:
@property
@override
def global_rank(self) -> int:
- return super().global_rank if self._launched else 0
+ return self.xla_strategy_impl.global_rank
@property
@override
def local_rank(self) -> int:
- return super().local_rank if self._launched else 0
+ return self.xla_strategy_impl.local_rank
@property
@override
def node_rank(self) -> int:
- return super().node_rank if self._launched else 0
+ return self.xla_strategy_impl.node_rank
@property
@override
def world_size(self) -> int:
- return super().world_size if self._launched else 1
+ return self.xla_strategy_impl.world_size
@override
def _configure_launcher(self) -> None:
@@ -130,48 +125,19 @@ def _configure_launcher(self) -> None:
@override
def setup_environment(self) -> None:
- assert self.parallel_devices is not None
- if len(self.parallel_devices) == 1:
- # spawning only 1 device with PjRT is not supported:
- # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732
- raise NotImplementedError(
- f"The {type(self).__name__} does not support running on a single device with the PjRT runtime."
- " Try using all devices or the `SingleDeviceXLAStrategy` strategy"
- )
-
- self._launched = True
- rank_zero_only.rank = self.global_rank
- super().setup_environment()
+ return self.xla_strategy_impl.setup_environment()
@override
def setup_module(self, module: Module) -> Module:
- if self._sync_module_states:
- if _XLA_GREATER_EQUAL_2_1:
- from torch_xla.core.xla_model import broadcast_master_param
- else:
- from torch_xla.experimental.pjrt import broadcast_master_param
-
- broadcast_master_param(module)
-
- return module
+ return self.xla_strategy_impl.setup_module(module=module)
@override
def module_to_device(self, module: Module) -> None:
- module.to(self.root_device)
+ return self.xla_strategy_impl.module_to_device(module=module)
@override
def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
- from torch_xla.distributed.parallel_loader import MpDeviceLoader
-
- if isinstance(dataloader, MpDeviceLoader):
- # dataloader is already wrapped by MpDeviceLoader
- return dataloader
-
- dataloader = MpDeviceLoader(dataloader, self.root_device)
- # Mimic interface to torch.utils.data.DataLoader
- dataloader.dataset = dataloader._loader.dataset
- dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
- return dataloader
+ return self.xla_strategy_impl.process_dataloader(dataloader=dataloader)
@override
def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -185,92 +151,21 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
A tensor of shape (world_size, ...)
"""
- if not self._launched:
- return tensor
- if not isinstance(tensor, Tensor):
- raise NotImplementedError(
- f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
- )
- if tensor.dim() == 0:
- tensor = tensor.unsqueeze(0)
- original_device = tensor.device
- tensor = tensor.to(self.root_device)
-
- import torch_xla.core.functions as xf
- import torch_xla.core.xla_model as xm
-
- tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
- tensor = tensor.to(original_device)
- return tensor
+ return self.xla_strategy_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads)
@override
def all_reduce(
self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> Tensor:
- if not isinstance(output, Tensor):
- output = torch.tensor(output, device=self.root_device)
-
- invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
- invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
- if invalid_reduce_op or invalid_reduce_op_str:
- raise ValueError(
- "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
- f" {reduce_op}"
- )
- import torch_xla.core.xla_model as xm
-
- output = xm.mesh_reduce("reduce", output, sum)
-
- if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
- output = output / self.world_size
-
- return output
+ return self.xla_strategy_impl.all_reduce(output=output, group=group, reduce_op=reduce_op)
@override
def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
- if not self._launched:
- return
- import torch_xla.core.xla_model as xm
-
- if name is None:
- # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
- name = ""
- xm.rendezvous(name)
+ return self.xla_strategy_impl.barrier(name, *args, **kwargs)
@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
- if not self._launched:
- return obj
-
- import torch_xla.core.xla_model as xm
-
- is_tensor = isinstance(obj, Tensor)
- if is_tensor:
- if obj.dim() == 0:
- obj = obj.unsqueeze(0)
- original_device = obj.device
- # XLA distributed requires that the data is on the XLA device
- obj = obj.to(self.root_device)
- else:
- # support for arbitrary pickle-ables
- buffer = io.BytesIO()
- torch.save(obj, buffer)
- obj = torch.tensor( # type: ignore[assignment]
- bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
- )
-
- obj = [obj]
- xm.collective_broadcast(obj, root_ordinal=src)
- obj = obj[0]
-
- if not is_tensor:
- # this will preserve the dtype and device of any tensors
- buffer = io.BytesIO(obj.cpu().byte().numpy())
- obj = torch.load(buffer)
- else:
- obj = obj.to(original_device)
-
- return obj
+ return self.xla_strategy_impl.broadcast(obj=obj, src=src)
@override
def save_checkpoint(
@@ -291,12 +186,9 @@ def save_checkpoint(
boolean indicating whether the given parameter should be saved (``True``) or filtered out (``False``).
"""
- import torch_xla.core.xla_model as xm
-
- # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
- xm.mark_step()
- # save on global rank zero only
- super().save_checkpoint(path, state, storage_options=storage_options, filter=filter)
+ return self.xla_strategy_impl.save_checkpoint(
+ path=path, state=state, storage_options=storage_options, filter=filter
+ )
@classmethod
@override
diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index 51b528eff26ff..f54267ca59a72 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -11,10 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import io
-from contextlib import AbstractContextManager, ExitStack, nullcontext
-from functools import partial
-from pathlib import Path
+from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
import torch
@@ -25,22 +22,16 @@
from typing_extensions import override
from lightning.fabric.accelerators import Accelerator
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision
from lightning.fabric.plugins.environments import XLAEnvironment
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
-from lightning.fabric.strategies.fsdp import _apply_filter
from lightning.fabric.strategies.launchers.xla import _XLALauncher
from lightning.fabric.strategies.strategy import (
TBroadcast,
- _BackwardSyncControl,
_Sharded,
- _validate_keys_for_strict_loading,
)
-from lightning.fabric.utilities.cloud_io import get_filesystem
-from lightning.fabric.utilities.init import _EmptyInit
-from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import _PATH, Optimizable, ReduceOp
if TYPE_CHECKING:
@@ -93,8 +84,6 @@ def __init__(
sequential_save: bool = False,
**kwargs: Any,
) -> None:
- if not _XLA_AVAILABLE:
- raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
@@ -102,27 +91,28 @@ def __init__(
checkpoint_io=checkpoint_io,
precision=precision,
)
- self._backward_sync_control = _XLAFSDPBackwardSyncControl()
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.strategies.xla.fsdp import (
+ XLAFSDPStrategyFabric as EnterpriseXLAFSDPStrategy,
+ )
- self._auto_wrap_policy = auto_wrap_policy
- self._activation_checkpointing_policy = activation_checkpointing_policy
- self._fsdp_kwargs = kwargs
- self._state_dict_type = state_dict_type
- self._sequential_save = sequential_save
- self._launched = False
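+ # FSDP wrapping, optimizer validation, and launch/rank state are handled by the enterprise implementation.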
+ self.xla_fsdp_impl = EnterpriseXLAFSDPStrategy(
+ outer_object=self,
+ auto_wrap_policy=auto_wrap_policy,
+ activation_checkpointing_policy=activation_checkpointing_policy,
+ state_dict_type=state_dict_type,
+ sequential_save=sequential_save,
+ **kwargs,
+ )
@property
@override
def root_device(self) -> torch.device:
- if not self._launched:
- raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
- import torch_xla.core.xla_model as xm
-
- return xm.xla_device()
+ return self.xla_fsdp_impl.root_device
@property
def num_processes(self) -> int:
- return len(self.parallel_devices) if self.parallel_devices is not None else 0
+ return self.xla_fsdp_impl.num_processes
@property
@override
@@ -159,22 +149,22 @@ def precision(self, precision: Optional[Precision]) -> None:
@property
@override
def global_rank(self) -> int:
- return super().global_rank if self._launched else 0
+ return self.xla_fsdp_impl.global_rank
@property
@override
def local_rank(self) -> int:
- return super().local_rank if self._launched else 0
+ return self.xla_fsdp_impl.local_rank
@property
@override
def node_rank(self) -> int:
- return super().node_rank if self._launched else 0
+ return self.xla_fsdp_impl.node_rank
@property
@override
def world_size(self) -> int:
- return super().world_size if self._launched else 1
+ return self.xla_fsdp_impl.world_size
@override
def _configure_launcher(self) -> None:
@@ -182,108 +172,40 @@ def _configure_launcher(self) -> None:
@override
def setup_environment(self) -> None:
- assert self.parallel_devices is not None
- if len(self.parallel_devices) == 1:
- # spawning only 1 device with PjRT is not supported:
- # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732
- raise NotImplementedError(
- f"The {type(self).__name__} does not support running on a single device with the PjRT runtime."
- " Try using all devices or the `SingleDeviceXLAStrategy` strategy"
- )
-
- self._launched = True
- rank_zero_only.rank = self.global_rank
- super().setup_environment()
+ return self.xla_fsdp_impl.setup_environment()
@override
def setup_module_and_optimizers(
self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
) -> tuple[Module, list[Optimizer], Optional["_LRScheduler"]]:
- """Returns NotImplementedError since for XLAFSDP optimizer setup must happen after module setup."""
- raise NotImplementedError(
- f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."
- " Please do it in this order: Create the model, call `setup_module`, create the optimizer,"
- " call `setup_optimizer`."
- )
+ return self.xla_fsdp_impl.setup_module_and_optimizers(module=module, optimizers=optimizers, scheduler=scheduler)
@override
def setup_module(self, module: Module) -> Module:
- from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
-
- kwargs = self._parse_fsdp_kwargs()
- if any(isinstance(mod, XLAFSDP) for mod in module.modules()) and "auto_wrap_policy" in kwargs:
- rank_zero_warn(
- "A XLAFSDP `auto_wrap_policy` is set, but at least one submodule is already wrapped."
- " The policy will be ignored."
- )
- del kwargs["auto_wrap_policy"]
- # XLA FSDP requires that the root is wrapped, even if submodules are already wrapped
- if not isinstance(module, XLAFSDP):
- module = XLAFSDP(module=module, **kwargs)
- return module
+ return self.xla_fsdp_impl.setup_module(module=module)
@override
def module_to_device(self, module: Module) -> None:
- pass
+ return self.xla_fsdp_impl.module_to_device(module=module)
def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
- precision_init_ctx = self.precision.module_init_context()
- module_sharded_ctx = self.module_sharded_context()
- stack = ExitStack()
- stack.enter_context(_EmptyInit(enabled=bool(empty_init)))
- stack.enter_context(precision_init_ctx)
- stack.enter_context(module_sharded_ctx)
- return stack
+ return self.xla_fsdp_impl.module_init_context(empty_init=empty_init)
@override
def module_sharded_context(self) -> AbstractContextManager:
- return nullcontext()
+ return self.xla_fsdp_impl.module_sharded_context()
@override
def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
- from torch_xla.distributed.parallel_loader import MpDeviceLoader
-
- if isinstance(dataloader, MpDeviceLoader):
- # dataloader is already wrapped by MpDeviceLoader
- return dataloader
-
- dataloader = MpDeviceLoader(dataloader, self.root_device)
- # Mimic interface to torch.utils.data.DataLoader
- dataloader.dataset = dataloader._loader.dataset
- dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
- return dataloader
+ return self.xla_fsdp_impl.process_dataloader(dataloader=dataloader)
@override
def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
- """Set up an optimizer for a model wrapped with XLAFSDP.
-
- This setup method doesn't modify the optimizer or wrap the optimizer. The only thing it currently does is verify
- that the optimizer was created after the model was wrapped with :meth:`setup_module` with a reference to the
- flattened parameters.
-
- """
- if any(getattr(p, "_is_sharded", False) for group in optimizer.param_groups for p in group["params"]):
- return optimizer
- raise ValueError(
- "The optimizer does not seem to reference any XLAFSDP parameters. HINT: Make sure to create the optimizer"
- " after setting up the model."
- )
+ return self.xla_fsdp_impl.setup_optimizer(optimizer=optimizer)
@override
def optimizer_step(self, optimizer: Optimizable, **kwargs: Any) -> Any:
- """Overrides default tpu optimizer_step since FSDP should not call `torch_xla.core.xla_model.optimizer_step`.
- Performs the actual optimizer step.
-
- Args:
- optimizer: the optimizer performing the step
- **kwargs: Any extra arguments to ``optimizer.step``
-
- """
- loss = optimizer.step(**kwargs)
- import torch_xla.core.xla_model as xm
-
- xm.mark_step()
- return loss
+ return self.xla_fsdp_impl.optimizer_step(optimizer=optimizer, **kwargs)
@override
def clip_gradients_norm(
@@ -295,17 +217,18 @@ def clip_gradients_norm(
error_if_nonfinite: bool = True,
) -> Tensor:
"""Clip gradients by norm."""
- self.precision.unscale_gradients(optimizer)
- assert callable(module.clip_grad_norm_)
- return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type)
+ return self.xla_fsdp_impl.clip_gradients_norm(
+ module=module,
+ optimizer=optimizer,
+ max_norm=max_norm,
+ norm_type=norm_type,
+ error_if_nonfinite=error_if_nonfinite,
+ )
@override
def clip_gradients_value(self, module: Module, optimizer: Optimizer, clip_val: Union[float, int]) -> None:
"""Clip gradients by value."""
- raise NotImplementedError(
- "XLA's FSDP strategy does not support to clip gradients by value."
- " Consider clipping by norm instead or choose another strategy!"
- )
+ return self.xla_fsdp_impl.clip_gradients_value(module=module, optimizer=optimizer, clip_val=clip_val)
@override
def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -319,92 +242,21 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
A tensor of shape (world_size, ...)
"""
- if not self._launched:
- return tensor
- if not isinstance(tensor, Tensor):
- raise NotImplementedError(
- f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
- )
- if tensor.dim() == 0:
- tensor = tensor.unsqueeze(0)
- original_device = tensor.device
- tensor = tensor.to(self.root_device)
-
- import torch_xla.core.functions as xf
- import torch_xla.core.xla_model as xm
-
- tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
- tensor = tensor.to(original_device)
- return tensor
+ return self.xla_fsdp_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads)
@override
def all_reduce(
self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> Tensor:
- if not isinstance(output, Tensor):
- output = torch.tensor(output, device=self.root_device)
-
- invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
- invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
- if invalid_reduce_op or invalid_reduce_op_str:
- raise ValueError(
- "Currently, the XLAFSDPStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
- f" {reduce_op}"
- )
- import torch_xla.core.xla_model as xm
-
- output = xm.mesh_reduce("reduce", output, sum)
-
- if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
- output = output / self.world_size
-
- return output
+ return self.xla_fsdp_impl.all_reduce(output=output, group=group, reduce_op=reduce_op)
@override
def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
- if not self._launched:
- return
- import torch_xla.core.xla_model as xm
-
- if name is None:
- # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
- name = ""
- xm.rendezvous(name)
+        return self.xla_fsdp_impl.barrier(name, *args, **kwargs)
@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
- if not self._launched:
- return obj
-
- import torch_xla.core.xla_model as xm
-
- is_tensor = isinstance(obj, Tensor)
- if is_tensor:
- if obj.dim() == 0:
- obj = obj.unsqueeze(0)
- original_device = obj.device
- # XLA distributed requires that the data is on the XLA device
- obj = obj.to(self.root_device)
- else:
- # support for arbitrary pickle-ables
- buffer = io.BytesIO()
- torch.save(obj, buffer)
- obj = torch.tensor( # type: ignore[assignment]
- bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
- )
-
- obj = [obj]
- xm.collective_broadcast(obj, root_ordinal=src)
- obj = obj[0]
-
- if not is_tensor:
- # this will preserve the dtype and device of any tensors
- buffer = io.BytesIO(obj.cpu().byte().numpy())
- obj = torch.load(buffer)
- else:
- obj = obj.to(original_device)
-
- return obj
+ return self.xla_fsdp_impl.broadcast(obj=obj, src=src)
@override
def save_checkpoint(
@@ -421,93 +273,8 @@ def save_checkpoint(
consolidated checkpoint combining all of the sharded checkpoints.
"""
- # broadcast the path from rank 0 to ensure all the states are saved in a common path
- path = Path(self.broadcast(path))
- if path.is_dir() and any(path.iterdir()):
- raise FileExistsError(f"The checkpoint directory already exists and is not empty: {path}")
- from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
-
- modules = [module for module in state.values() if isinstance(module, XLAFSDP)]
- if len(modules) == 0:
- raise ValueError(
- "Could not find a XLAFSDP model in the provided checkpoint state. Please provide the model as"
- " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
- " you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
- )
- if len(modules) > 1:
- raise ValueError(
- "Found multiple XLAFSDP modules in the given state. Saving checkpoints with FSDP is"
- " currently limited to a single model per checkpoint. To save multiple models, call the"
- " save method for each model separately with a different path."
- )
- import torch_xla.core.xla_model as xm
-
- # ensure model parameters are updated
- xm.mark_step()
-
- parallel_devices = self.parallel_devices
- assert parallel_devices is not None
- if self._sequential_save:
- # each host runs this in parallel, but the ranks in the host run it sequentially
- for rank in range(len(parallel_devices)):
- if rank == self.local_rank:
- self._save_checkpoint_shard(path, state, storage_options, filter)
- self.barrier(f"wait-for-{rank}-save")
- else:
- self._save_checkpoint_shard(path, state, storage_options, filter)
-
- if self._state_dict_type == "full":
- ckpt_prefix = str(path / "checkpoint")
- ckpt_suffix = "_rank-*-of-*.pth"
- if len(parallel_devices) != self.world_size: # multihost
- raise OSError(
- "Multihost setups do not have a shared filesystem, so the checkpoint shards cannot be consolidated"
- " into a single checkpoint after saving them. Please switch to"
- " `XLAFSDPStrategy(state_dict_type='sharded')`. TIP: You can consolidate them manually by getting"
- " them together into a single directory and running `python -m"
- f" torch_xla.distributed.fsdp.consolidate_sharded_ckpts --ckpt_prefix {ckpt_prefix!r} --ckpt_suffix"
- f" {ckpt_suffix!r} --save_path 'path/to/consolidated.ckpt'`."
- )
-
- from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
-
- self.barrier("before_ckpt_consolidation")
- if self.is_global_zero:
- save_path = path.parent / "consolidated.ckpt"
- # save consolidated checkpoint separate to the shards
- consolidate_sharded_model_checkpoints(ckpt_prefix, ckpt_suffix, str(save_path))
- # remove the shards directory
- self.checkpoint_io.remove_checkpoint(path)
- # mv the consolidated checkpoint where the user would expect it
- get_filesystem(save_path).mv(str(save_path), str(path))
- self.barrier("after_ckpt_consolidation")
-
- def _save_checkpoint_shard(
- self,
- path: Path,
- state: dict[str, Union[Module, Optimizer, Any]],
- storage_options: Optional[Any],
- filter: Optional[dict[str, Callable[[str, Any], bool]]],
- ) -> None:
- from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
-
- converted_state: dict[str, Any] = {}
- for key, obj in state.items():
- # convert the state
- if isinstance(obj, Module) and isinstance(obj, XLAFSDP):
- converted = obj.state_dict()
- # add shard_metadata to state
- converted_state["shard_metadata"] = obj.get_shard_metadata()
- elif isinstance(obj, Optimizer):
- converted = obj.state_dict()
- else:
- converted = obj
- _apply_filter(key, filter or {}, converted, converted_state)
-
- self.checkpoint_io.save_checkpoint(
- converted_state,
- path / f"checkpoint_rank-{self.global_rank:08d}-of-{self.world_size:08d}.pth",
- storage_options=storage_options,
+ return self.xla_fsdp_impl.save_checkpoint(
+ path=path, state=state, storage_options=storage_options, filter=filter
)
@override
@@ -524,164 +291,9 @@ def load_checkpoint(
directory of multiple files rather than a single file.
"""
- if not state:
- raise ValueError(
- f"Got `XLAFSDPStrategy.load_checkpoint(..., state={state!r})` but a state with at least "
- " a model instance to reload is required. Pass it in like so:"
- " `FSDPStrategy.load_checkpoint(..., state={'model': model, ...})`"
- )
-
- # broadcast the path from rank 0 to ensure all the states are loaded from a common path
- path = Path(self.broadcast(path))
-
- if isinstance(state, (Module, Optimizer)):
- raise NotImplementedError(
- "Loading a single module or optimizer object from a checkpoint"
- " is not supported yet with the XLAFSDP strategy."
- )
-
- from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
-
- modules = {key: module for key, module in state.items() if isinstance(module, XLAFSDP)}
- optimizers = {key: optim for key, optim in state.items() if isinstance(optim, Optimizer)}
- if self._state_dict_type == "sharded":
- file = path / f"checkpoint_rank-{self.global_rank:08d}-of-{self.world_size:08d}.pth"
- if not file.is_file():
- raise ValueError(
- f"The path {str(file)!r} does not point to valid sharded checkpoints. Make sure the path points to"
- " a directory with XLAFSDP checkpoint shards."
- )
- if len(modules) == 0:
- raise ValueError(
- "Could not find a XLAFSDP model in the provided checkpoint state. Please provide the model as"
- " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
- " you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
- )
- if len(modules) > 1:
- raise ValueError(
- "Found multiple XLAFSDP modules in the given state. Loading checkpoints with FSDP is"
- " currently limited to a single model per checkpoint. To load multiple models, call the"
- " load method for each model separately with a different path."
- )
-
- _, module = list(modules.items())[0]
- sharded_ckpt = torch.load(file)
-
- module.load_state_dict(sharded_ckpt["model"], strict=strict)
- for opt_key, opt in optimizers.items():
- opt.load_state_dict(sharded_ckpt[opt_key])
-
- # Load anything leftover from sharded_ckpt
- loaded_metadata_keys = sharded_ckpt.keys() - modules.keys() - optimizers.keys()
- requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
- _validate_keys_for_strict_loading(requested_metadata_keys, loaded_metadata_keys, strict=strict)
- for key in requested_metadata_keys:
- if key in loaded_metadata_keys:
- state[key] = sharded_ckpt[key]
- loaded_metadata_keys.remove(key)
-
- metadata = {}
- if len(loaded_metadata_keys):
- for key in loaded_metadata_keys:
- metadata[key] = sharded_ckpt[key]
-
- # remove "shard_metadata" that is loaded in
- if "shard_metadata" in metadata:
- metadata.pop("shard_metadata")
-
- return metadata
-
- if self._state_dict_type == "full":
- if not path.is_file():
- raise ValueError(
- f"The path {str(path)!r} does not point to a valid full checkpoint. Make sure the path points to a"
- " directory with a full XLAFSDP checkpoint."
- )
- if len(optimizers) > 0 or len(state.keys() - modules.keys() - optimizers.keys()) > 0:
- rank_zero_warn(
- "Loading a full checkpoint will only load the full model."
- " The optimizer and any additional metadata are not included."
- )
- if len(modules) > 0:
- raise ValueError(
- "Found a XLAFSDP model in the provided checkpoint state."
- " Please provide the model without any XLAFSDP wrapper."
- )
- if "model" not in state or not isinstance(model := state["model"], torch.nn.Module):
- raise NotImplementedError("XLAFSDP only supports a single model instance with 'model' as the key.")
- full_ckpt = torch.load(path, weights_only=weights_only)
- model.load_state_dict(full_ckpt.pop("model"), strict=strict)
- return full_ckpt
-
- raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")
+ return self.xla_fsdp_impl.load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only)
@classmethod
@override
def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
strategy_registry.register("xla_fsdp", cls, description=cls.__name__)
-
- def _parse_fsdp_kwargs(self) -> dict:
- # this needs to be delayed because `self.precision` isn't available at init
- kwargs = self._fsdp_kwargs.copy()
- precision = self.precision
- if isinstance(precision, XLAPrecision):
- # the `compute_dtype` will be passed to the `auto_wrapper_callable` automatically, so we don't need to pass
- # it when creating it
- kwargs.setdefault("compute_dtype", precision._desired_dtype)
- kwargs = _auto_wrap_policy_kwargs(self._auto_wrap_policy, kwargs)
- return _activation_checkpointing_kwargs(self._activation_checkpointing_policy, kwargs)
-
-
-def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: dict) -> dict:
- if policy is None:
- return kwargs
- if isinstance(policy, set):
- from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
-
- # this is not transformer specific despite the name
- policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=policy)
- kwargs["auto_wrap_policy"] = policy
- return kwargs
-
-
-def _activation_checkpointing_auto_wrapper(policy: _POLICY_SET, module: Module, *args: Any, **kwargs: Any) -> Module:
- from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
- from torch_xla.distributed.fsdp import checkpoint_module
-
- module = checkpoint_module(module) if isinstance(module, tuple(policy)) else module
- return XLAFSDP(module, *args, **kwargs)
-
-
-def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: dict) -> dict:
- if not policy:
- return kwargs
- if "auto_wrapper_callable" in kwargs:
- raise ValueError(
- "You cannot set both `auto_wrapper_callable` and `activation_checkpointing_policy`. Choose one"
- )
- if not isinstance(policy, set):
- raise TypeError(
- f"`activation_checkpointing_policy` must be a set, found {policy}. You can try defining and"
- " passing `auto_wrapper_callable` instead."
- )
- auto_wrapper_callable = partial(_activation_checkpointing_auto_wrapper, policy)
- kwargs["auto_wrapper_callable"] = auto_wrapper_callable
- return kwargs
-
-
-class _XLAFSDPBackwardSyncControl(_BackwardSyncControl):
- @override
- def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
- """Blocks gradient synchronization inside the :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel`
- wrapper."""
- if not enabled:
- return nullcontext()
- from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
-
- if not isinstance(module, XLAFSDP):
- raise TypeError(
- "Blocking backward sync is only possible if the module passed to"
- f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `XlaFullyShardedDataParallel`."
- f" Got: {module.__class__.__name__}."
- )
- return module.no_sync()
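The hunks above turn `XLAFSDPStrategy` into a thin facade: the public class keeps its constructor signature but forwards every call to an `xla_fsdp_impl` object built from `pytorch_lightning_enterprise`. A minimal, self-contained sketch of that delegation shape follows; the `_Impl` class is illustrative only and is not the actual enterprise API.

from typing import Any


class _Impl:
    # Stand-in for the class imported from pytorch_lightning_enterprise.
    def __init__(self, outer_object: Any, **kwargs: Any) -> None:
        self._outer = outer_object
        self._kwargs = kwargs

    def barrier(self, name: str = "") -> None:
        pass  # the real collective lives in the enterprise package


class Facade:
    # Keeps the public signature; every method body forwards to the impl.
    def __init__(self, **kwargs: Any) -> None:
        self.impl = _Impl(outer_object=self, **kwargs)

    def barrier(self, name: str = "") -> None:
        return self.impl.barrier(name)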
diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py
index ac706cc403a5d..dd815d9dc8f5f 100644
--- a/src/lightning/fabric/utilities/imports.py
+++ b/src/lightning/fabric/utilities/imports.py
@@ -40,3 +40,22 @@
_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
+
+_WANDB_AVAILABLE = RequirementCache("wandb>=0.12.10")
+_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4")
+_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0")
+_MLFLOW_SYNCHRONOUS_AVAILABLE = RequirementCache("mlflow>=2.8.0")
+_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0")
+
+_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
+_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")
+_ENTERPRISE_AVAILABLE = RequirementCache("pytorch_lightning_enterprise")
+_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0")
+
+
+def _raise_enterprise_not_available() -> None:
+ if not _ENTERPRISE_AVAILABLE:
+ raise ModuleNotFoundError(
+ "pytorch_lightning_enterprise is required to use the XLA accelerator. "
+ "Install it with `pip install pytorch-lightning-enterprise`"
+ )
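A quick check of the new guard, assuming only the two names added in this hunk and this branch of the code; whether it raises depends on whether the optional package is installed:

from lightning.fabric.utilities.imports import (
    _ENTERPRISE_AVAILABLE,
    _raise_enterprise_not_available,
)

try:
    _raise_enterprise_not_available()  # no-op when the package is present
    print("enterprise extras available:", bool(_ENTERPRISE_AVAILABLE))
except ModuleNotFoundError as err:
    print(f"missing optional dependency: {err}")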
diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py
index d085e4138d742..92521362a2885 100644
--- a/src/lightning/fabric/utilities/testing/_runif.py
+++ b/src/lightning/fabric/utilities/testing/_runif.py
@@ -23,8 +23,7 @@
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.accelerators.cuda import num_cuda_devices
from lightning.fabric.accelerators.mps import MPSAccelerator
-from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
+from lightning.fabric.utilities.imports import _DEEPSPEED_AVAILABLE, _TORCH_GREATER_EQUAL_2_4
def _runif_reasons(
diff --git a/src/lightning/pytorch/accelerators/xla.py b/src/lightning/pytorch/accelerators/xla.py
index 700a29c2e3e7b..c8b322f477d4d 100644
--- a/src/lightning/pytorch/accelerators/xla.py
+++ b/src/lightning/pytorch/accelerators/xla.py
@@ -39,15 +39,7 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
A dictionary mapping the metrics (free memory and peak memory) to their values.
"""
- import torch_xla.core.xla_model as xm
-
- memory_info = xm.get_memory_info(device)
- free_memory = memory_info["kb_free"]
- peak_memory = memory_info["kb_total"] - free_memory
- return {
- "avg. free memory (MB)": free_memory,
- "avg. peak memory (MB)": peak_memory,
- }
+ return self.accelerator_impl.get_device_stats(device)
@staticmethod
@override
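For context, `get_device_stats` is what the `DeviceStatsMonitor` callback consumes, so the delegation above is transparent to user code. A hedged usage sketch; the accelerator string and device count are placeholders and require actual XLA hardware plus the enterprise package:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import DeviceStatsMonitor

# DeviceStatsMonitor calls Accelerator.get_device_stats under the hood; with
# the change above, the reported values come from the enterprise XLA backend.
trainer = Trainer(accelerator="tpu", devices=8, callbacks=[DeviceStatsMonitor()])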
diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py
index b544212e755e2..5e57e38dc4072 100644
--- a/src/lightning/pytorch/loggers/comet.py
+++ b/src/lightning/pytorch/loggers/comet.py
@@ -17,18 +17,15 @@
"""
import logging
-import os
from argparse import Namespace
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
-from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from torch.nn import Module
from typing_extensions import override
-from lightning.fabric.utilities.logger import _convert_params
-from lightning.fabric.utilities.rank_zero import _get_rank
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities.rank_zero import rank_zero_only
@@ -36,7 +33,6 @@
from comet_ml import ExistingExperiment, Experiment, OfflineExperiment
log = logging.getLogger(__name__)
-_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4", module="comet_ml")
FRAMEWORK_NAME = "pytorch-lightning"
comet_experiment = Union["Experiment", "ExistingExperiment", "OfflineExperiment"]
@@ -208,99 +204,23 @@ def __init__(
prefix: Optional[str] = None,
**kwargs: Any,
):
- if not _COMET_AVAILABLE:
- raise ModuleNotFoundError(str(_COMET_AVAILABLE))
+ _raise_enterprise_not_available()
super().__init__()
- ##################################################
- # HANDLE PASSED OLD TYPE PARAMS
-
- # handle old "experiment_name" param
- if "experiment_name" in kwargs:
- log.warning("The parameter `experiment_name` is deprecated, please use `name` instead.")
- experiment_name = kwargs.pop("experiment_name")
-
- if "name" not in kwargs:
- kwargs["name"] = experiment_name
- else:
- log.warning("You specified both `experiment_name` and `name` parameters, please use `name` only")
-
- # handle old "project_name" param
- if "project_name" in kwargs:
- log.warning("The parameter `project_name` is deprecated, please use `project` instead.")
- if project is None:
- project = kwargs.pop("project_name")
- else:
- log.warning("You specified both `project_name` and `project` parameters, please use `project` only")
-
- # handle old "offline" experiment flag
- if "offline" in kwargs:
- log.warning("The parameter `offline is deprecated, please use `online` instead.")
- if online is None:
- online = kwargs.pop("offline")
- else:
- log.warning("You specified both `offline` and `online` parameters, please use `online` only")
-
- # handle old "save_dir" param
- if "save_dir" in kwargs:
- log.warning("The parameter `save_dir` is deprecated, please use `offline_directory` instead.")
- if "offline_directory" not in kwargs:
- kwargs["offline_directory"] = kwargs.pop("save_dir")
- else:
- log.warning(
- "You specified both `save_dir` and `offline_directory` parameters, "
- "please use `offline_directory` only"
- )
- ##################################################
-
- self._api_key: Optional[str] = api_key
- self._experiment: Optional[comet_experiment] = None
- self._workspace: Optional[str] = workspace
- self._mode: Optional[Literal["get_or_create", "get", "create"]] = mode
- self._online: Optional[bool] = online
- self._project_name: Optional[str] = project
- self._experiment_key: Optional[str] = experiment_key
- self._prefix: Optional[str] = prefix
- self._kwargs: dict[str, Any] = kwargs
-
- # needs to be set before the first `comet_ml` import
- # because comet_ml imported after another machine learning libraries (Torch)
- os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
-
- import comet_ml
-
- config_kwargs = self._kwargs.copy()
- if online is False:
- config_kwargs["disabled"] = True
- self._comet_config = comet_ml.ExperimentConfig(**config_kwargs)
-
- # create real experiment only on main node/process (when strategy=auto/ddp)
- if _get_rank() is not None and _get_rank() != 0:
- return
-
- self._create_experiment()
-
- def _create_experiment(self) -> None:
- import comet_ml
-
- self._experiment = comet_ml.start(
- api_key=self._api_key,
- workspace=self._workspace,
- project=self._project_name,
- experiment_key=self._experiment_key,
- mode=self._mode,
- online=self._online,
- experiment_config=self._comet_config,
+ from pytorch_lightning_enterprise.loggers.comet import CometLogger as EnterpriseCometLogger
+
+ self.logger_impl = EnterpriseCometLogger(
+ api_key=api_key,
+ workspace=workspace,
+ project=project,
+ experiment_key=experiment_key,
+ mode=mode,
+ online=online,
+ prefix=prefix,
+ **kwargs,
)
- if self._experiment is None:
- raise comet_ml.exceptions.ExperimentNotFound("Failed to create Comet experiment.")
-
- self._experiment_key = self._experiment.get_key()
- self._project_name = self._experiment.project_name
- self._experiment.log_other("Created from", FRAMEWORK_NAME)
-
@property
@rank_zero_experiment
def experiment(self) -> comet_experiment:
@@ -313,55 +233,24 @@ def experiment(self) -> comet_experiment:
"""
- # if by some chance there is no experiment created yet (for example, when strategy=ddp_spawn)
- # then we will create a new one
- if not self._experiment:
- self._create_experiment()
-
- return self._experiment
+ return self.logger_impl.experiment
@override
@rank_zero_only
def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
- params = _convert_params(params)
- self.experiment.__internal_api__log_parameters__(
- parameters=params,
- framework=FRAMEWORK_NAME,
- flatten_nested=True,
- source="manual",
- )
+ return self.logger_impl.log_hyperparams(params)
@override
@rank_zero_only
def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
- assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
- # Comet.com expects metrics to be a dictionary of detached tensors on CPU
- metrics_without_epoch = metrics.copy()
- for key, val in metrics_without_epoch.items():
- if isinstance(val, Tensor):
- metrics_without_epoch[key] = val.cpu().detach()
-
- epoch = metrics_without_epoch.pop("epoch", None)
- self.experiment.__internal_api__log_metrics__(
- metrics_without_epoch,
- step=step,
- epoch=epoch,
- prefix=self._prefix,
- framework=FRAMEWORK_NAME,
- )
+ return self.logger_impl.log_metrics(metrics, step)
@override
@rank_zero_only
def finalize(self, status: str) -> None:
"""We will not end experiment (will not call self._experiment.end()) here to have an ability to continue using
it after training is complete but instead of ending we will upload/save all the data."""
- if self._experiment is None:
- # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been
- # initialized there
- return
-
- # just save the data
- self.experiment.flush()
+ return self.logger_impl.finalize(status)
@property
@override
@@ -372,7 +261,7 @@ def save_dir(self) -> Optional[str]:
The path to the save directory.
"""
- return self._comet_config.offline_directory
+ return self.logger_impl.save_dir
@property
@override
@@ -383,7 +272,7 @@ def name(self) -> Optional[str]:
The project name if it is specified.
"""
- return self._project_name
+ return self.logger_impl.name
@property
@override
@@ -395,27 +284,8 @@ def version(self) -> Optional[str]:
"""
# Don't create an experiment if we don't have one
- if self._experiment is not None:
- return self._experiment.get_key()
-
- def __getstate__(self) -> dict[str, Any]:
- state = self.__dict__.copy()
-
- # Save the experiment id in case an experiment object already exists,
- # this way we could create an ExistingExperiment pointing to the same
- # experiment
- state["_experiment_key"] = self._experiment.get_key() if self._experiment is not None else None
-
- # Remove the experiment object as it contains hard to pickle objects
- # (like network connections), the experiment object will be recreated if
- # needed later
- state["_experiment"] = None
- return state
+ return self.logger_impl.version
@override
def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
- if self._experiment is not None:
- self._experiment.__internal_api__set_model_graph__(
- graph=model,
- framework=FRAMEWORK_NAME,
- )
+ return self.logger_impl.log_graph(model, input_array)
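Because only the backing implementation moves, the documented way to use the Comet logger is unchanged. A minimal sketch, assuming comet-ml and pytorch-lightning-enterprise are installed; the project name is a placeholder:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CometLogger

# online=False keeps the run local/disabled, mirroring an offline experiment.
logger = CometLogger(project="placeholder-project", online=False)
trainer = Trainer(logger=logger, max_epochs=1)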
diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index ff9b2b0d7e542..725fd2cb8bfbb 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -16,35 +16,21 @@
-------------
"""
-import logging
import os
-import re
-import tempfile
from argparse import Namespace
from collections.abc import Mapping
-from pathlib import Path
-from time import time
-from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
-import yaml
-from lightning_utilities.core.imports import RequirementCache
-from torch import Tensor
from typing_extensions import override
-from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
-from lightning.pytorch.loggers.utilities import _scan_checkpoints
-from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
+from lightning.pytorch.utilities.rank_zero import rank_zero_only
if TYPE_CHECKING:
from mlflow.tracking import MlflowClient
-log = logging.getLogger(__name__)
-LOCAL_FILE_URI_PREFIX = "file:"
-_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow")
-_MLFLOW_SYNCHRONOUS_AVAILABLE = RequirementCache("mlflow>=2.8.0", "mlflow")
-
class MLFlowLogger(Logger):
"""Log using `MLflow `_.
@@ -126,31 +112,22 @@ def __init__(
run_id: Optional[str] = None,
synchronous: Optional[bool] = None,
):
- if not _MLFLOW_AVAILABLE:
- raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE))
- if synchronous is not None and not _MLFLOW_SYNCHRONOUS_AVAILABLE:
- raise ModuleNotFoundError("`synchronous` requires mlflow>=2.8.0")
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.loggers.mlflow import MLFlowLogger as EnterpriseMLFlowLogger
+
super().__init__()
- if not tracking_uri:
- tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"
-
- self._experiment_name = experiment_name
- self._experiment_id: Optional[str] = None
- self._tracking_uri = tracking_uri
- self._run_name = run_name
- self._run_id = run_id
- self.tags = tags
- self._log_model = log_model
- self._logged_model_time: dict[str, float] = {}
- self._checkpoint_callback: Optional[ModelCheckpoint] = None
- self._prefix = prefix
- self._artifact_location = artifact_location
- self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
- self._initialized = False
-
- from mlflow.tracking import MlflowClient
-
- self._mlflow_client = MlflowClient(tracking_uri)
+ self.logger_impl = EnterpriseMLFlowLogger(
+ experiment_name=experiment_name,
+ run_name=run_name,
+ tracking_uri=tracking_uri,
+ tags=tags,
+ save_dir=save_dir,
+ log_model=log_model,
+ prefix=prefix,
+ artifact_location=artifact_location,
+ run_id=run_id,
+ synchronous=synchronous,
+ )
@property
@rank_zero_experiment
@@ -163,46 +140,7 @@ def experiment(self) -> "MlflowClient":
self.logger.experiment.some_mlflow_function()
"""
- import mlflow
-
- if self._initialized:
- return self._mlflow_client
-
- mlflow.set_tracking_uri(self._tracking_uri)
-
- if self._run_id is not None:
- run = self._mlflow_client.get_run(self._run_id)
- self._experiment_id = run.info.experiment_id
- self._initialized = True
- return self._mlflow_client
-
- if self._experiment_id is None:
- expt = self._mlflow_client.get_experiment_by_name(self._experiment_name)
- if expt is not None and expt.lifecycle_stage != "deleted":
- self._experiment_id = expt.experiment_id
- else:
- log.warning(f"Experiment with name {self._experiment_name} not found. Creating it.")
- self._experiment_id = self._mlflow_client.create_experiment(
- name=self._experiment_name, artifact_location=self._artifact_location
- )
-
- if self._run_id is None:
- if self._run_name is not None:
- self.tags = self.tags or {}
-
- from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
-
- if MLFLOW_RUN_NAME in self.tags:
- log.warning(
- f"The tag {MLFLOW_RUN_NAME} is found in tags. The value will be overridden by {self._run_name}."
- )
- self.tags[MLFLOW_RUN_NAME] = self._run_name
-
- resolve_tags = _get_resolve_tags()
- run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
- self._run_id = run.info.run_id
- self._initialized = True
- return self._mlflow_client
+ return self.logger_impl.experiment
@property
def run_id(self) -> Optional[str]:
@@ -212,8 +150,7 @@ def run_id(self) -> Optional[str]:
The run id.
"""
- _ = self.experiment
- return self._run_id
+ return self.logger_impl.run_id
@property
def experiment_id(self) -> Optional[str]:
@@ -223,71 +160,22 @@ def experiment_id(self) -> Optional[str]:
The experiment id.
"""
- _ = self.experiment
- return self._experiment_id
+ return self.logger_impl.experiment_id
@override
@rank_zero_only
def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
- params = _convert_params(params)
- params = _flatten_dict(params)
-
- from mlflow.entities import Param
-
- # Truncate parameter values to 250 characters.
- # TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
- params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
-
- # Log in chunks of 100 parameters (the maximum allowed by MLflow).
- for idx in range(0, len(params_list), 100):
- self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100], **self._log_batch_kwargs)
+ return self.logger_impl.log_hyperparams(params)
@override
@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
- assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
-
- from mlflow.entities import Metric
-
- metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
- metrics_list: list[Metric] = []
-
- timestamp_ms = int(time() * 1000)
- for k, v in metrics.items():
- if isinstance(v, str):
- log.warning(f"Discarding metric with string value {k}={v}.")
- continue
-
- new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
- if k != new_k:
- rank_zero_warn(
- "MLFlow only allows '_', '/', '.' and ' ' special characters in metric name."
- f" Replacing {k} with {new_k}.",
- category=RuntimeWarning,
- )
- k = new_k
- metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))
-
- self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list, **self._log_batch_kwargs)
+ return self.logger_impl.log_metrics(metrics, step)
@override
@rank_zero_only
def finalize(self, status: str = "success") -> None:
- if not self._initialized:
- return
- if status == "success":
- status = "FINISHED"
- elif status == "failed":
- status = "FAILED"
- elif status == "finished":
- status = "FINISHED"
-
- # log checkpoints as artifacts
- if self._checkpoint_callback:
- self._scan_and_log_checkpoints(self._checkpoint_callback)
-
- if self.experiment.get_run(self.run_id):
- self.experiment.set_terminated(self.run_id, status)
+ return self.logger_impl.finalize(status)
@property
@override
@@ -299,9 +187,7 @@ def save_dir(self) -> Optional[str]:
Otherwise returns `None`.
"""
- if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
- return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :]
- return None
+ return self.logger_impl.save_dir
@property
@override
@@ -312,7 +198,7 @@ def name(self) -> Optional[str]:
The experiment id.
"""
- return self.experiment_id
+ return self.logger_impl.name
@property
@override
@@ -323,76 +209,8 @@ def version(self) -> Optional[str]:
The run id.
"""
- return self.run_id
+ return self.logger_impl.version
@override
def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
- # log checkpoints as artifacts
- if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
- self._scan_and_log_checkpoints(checkpoint_callback)
- elif self._log_model is True:
- self._checkpoint_callback = checkpoint_callback
-
- def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
- # get checkpoints to be saved with associated score
- checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)
-
- # log iteratively all new checkpoints
- for t, p, s, tag in checkpoints:
- metadata = {
- # Ensure .item() is called to store Tensor contents
- "score": s.item() if isinstance(s, Tensor) else s,
- "original_filename": Path(p).name,
- "Checkpoint": {
- k: getattr(checkpoint_callback, k)
- for k in [
- "monitor",
- "mode",
- "save_last",
- "save_top_k",
- "save_weights_only",
- "_every_n_train_steps",
- "_every_n_val_epochs",
- ]
- # ensure it does not break if `Checkpoint` args change
- if hasattr(checkpoint_callback, k)
- },
- }
- aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
-
- # Artifact path on mlflow
- artifact_path = Path(p).stem
-
- # Log the checkpoint
- self.experiment.log_artifact(self._run_id, p, artifact_path)
-
- # Create a temporary directory to log on mlflow
- with tempfile.TemporaryDirectory() as tmp_dir:
- # Log the metadata
- with open(f"{tmp_dir}/metadata.yaml", "w") as tmp_file_metadata:
- yaml.dump(metadata, tmp_file_metadata, default_flow_style=False)
-
- # Log the aliases
- with open(f"{tmp_dir}/aliases.txt", "w") as tmp_file_aliases:
- tmp_file_aliases.write(str(aliases))
-
- # Log the metadata and aliases
- self.experiment.log_artifacts(self._run_id, tmp_dir, artifact_path)
-
- # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
- self._logged_model_time[p] = t
-
-
-def _get_resolve_tags() -> Callable:
- from mlflow.tracking import context
-
- # before v1.1.0
- if hasattr(context, "resolve_tags"):
- from mlflow.tracking.context import resolve_tags
- # since v1.1.0
- elif hasattr(context, "registry"):
- from mlflow.tracking.context.registry import resolve_tags
- else:
- resolve_tags = lambda tags: tags
-
- return resolve_tags
+ return self.logger_impl.after_save_checkpoint(checkpoint_callback)
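The same holds for MLflow: construction and Trainer wiring are untouched by the delegation. A minimal sketch, assuming mlflow and the enterprise package are installed; the experiment name and save directory are placeholders:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import MLFlowLogger

# Local file-based tracking; delegation to the enterprise implementation is
# invisible to calling code.
logger = MLFlowLogger(experiment_name="lightning_logs", save_dir="./mlruns")
trainer = Trainer(logger=logger, max_epochs=1)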
diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py
index bf9669c824784..c2dd2a9b5a142 100644
--- a/src/lightning/pytorch/loggers/neptune.py
+++ b/src/lightning/pytorch/loggers/neptune.py
@@ -16,23 +16,17 @@
--------------
"""
-import contextlib
import logging
-import os
from argparse import Namespace
-from collections.abc import Generator
-from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
-from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import override
import lightning.pytorch as pl
-from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.pytorch.callbacks import Checkpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
-from lightning.pytorch.utilities.model_summary import ModelSummary
from lightning.pytorch.utilities.rank_zero import rank_zero_only
if TYPE_CHECKING:
@@ -41,27 +35,6 @@
log = logging.getLogger(__name__)
-# Neptune is available with two names on PyPI : `neptune` and `neptune-client`
-# `neptune` was introduced as a name transition of neptune-client and the long-term target is to get
-# rid of Neptune-client package completely someday. It was introduced as a part of breaking-changes with a release
-# of neptune-client==1.0. neptune-client>=1.0 is just an alias of neptune package and have some breaking-changes
-# in compare to neptune-client<1.0.0.
-_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0")
-_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning"
-
-
-# Neptune client throws `InactiveRunException` when trying to log to an inactive run.
-# This may happen when the run was stopped through the UI and the logger is still trying to log to it.
-def _catch_inactive(func: Callable) -> Callable:
- @wraps(func)
- def wrapper(*args: Any, **kwargs: Any) -> Any:
- from neptune.exceptions import InactiveRunException
-
- with contextlib.suppress(InactiveRunException):
- return func(*args, **kwargs)
-
- return wrapper
-
class NeptuneLogger(Logger):
r"""Log using `Neptune `_.
@@ -242,113 +215,19 @@ def __init__(
prefix: str = "training",
**neptune_run_kwargs: Any,
):
- if not _NEPTUNE_AVAILABLE:
- raise ModuleNotFoundError(str(_NEPTUNE_AVAILABLE))
-
- # verify if user passed proper init arguments
- self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs)
+ _raise_enterprise_not_available()
super().__init__()
- self._log_model_checkpoints = log_model_checkpoints
- self._prefix = prefix
- self._run_name = name
- self._project_name = project
- self._api_key = api_key
- self._run_instance = run
- self._neptune_run_kwargs = neptune_run_kwargs
- self._run_short_id: Optional[str] = None
-
- if self._run_instance is not None:
- self._retrieve_run_data()
-
- from neptune.handler import Handler
-
- # make sure that we've log integration version for outside `Run` instances
- root_obj = self._run_instance
- if isinstance(root_obj, Handler):
- root_obj = root_obj.get_root_object()
-
- root_obj[_INTEGRATION_VERSION_KEY] = pl.__version__
-
- def _retrieve_run_data(self) -> None:
- from neptune.handler import Handler
-
- assert self._run_instance is not None
- root_obj = self._run_instance
- if isinstance(root_obj, Handler):
- root_obj = root_obj.get_root_object()
-
- root_obj.wait()
-
- if root_obj.exists("sys/id"):
- self._run_short_id = root_obj["sys/id"].fetch()
- self._run_name = root_obj["sys/name"].fetch()
- else:
- self._run_short_id = "OFFLINE"
- self._run_name = "offline-name"
-
- @property
- def _neptune_init_args(self) -> dict:
- args: dict = {}
- # Backward compatibility in case of previous version retrieval
- with contextlib.suppress(AttributeError):
- args = self._neptune_run_kwargs
-
- if self._project_name is not None:
- args["project"] = self._project_name
-
- if self._api_key is not None:
- args["api_token"] = self._api_key
-
- if self._run_short_id is not None:
- args["run"] = self._run_short_id
-
- # Backward compatibility in case of previous version retrieval
- with contextlib.suppress(AttributeError):
- if self._run_name is not None:
- args["name"] = self._run_name
-
- return args
-
- def _construct_path_with_prefix(self, *keys: str) -> str:
- """Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
- if self._prefix:
- return self.LOGGER_JOIN_CHAR.join([self._prefix, *keys])
- return self.LOGGER_JOIN_CHAR.join(keys)
-
- @staticmethod
- def _verify_input_arguments(
- api_key: Optional[str],
- project: Optional[str],
- name: Optional[str],
- run: Optional[Union["Run", "Handler"]],
- neptune_run_kwargs: dict,
- ) -> None:
- from neptune import Run
- from neptune.handler import Handler
-
- # check if user passed the client `Run`/`Handler` object
- if run is not None and not isinstance(run, (Run, Handler)):
- raise ValueError("Run parameter expected to be of type `neptune.Run`, or `neptune.handler.Handler`.")
-
- # check if user passed redundant neptune.init_run arguments when passed run
- any_neptune_init_arg_passed = any(arg is not None for arg in [api_key, project, name]) or neptune_run_kwargs
- if run is not None and any_neptune_init_arg_passed:
- raise ValueError(
- "When an already initialized run object is provided, you can't provide other `neptune.init_run()`"
- " parameters."
- )
-
- def __getstate__(self) -> dict[str, Any]:
- state = self.__dict__.copy()
- # Run instance can't be pickled
- state["_run_instance"] = None
- return state
-
- def __setstate__(self, state: dict[str, Any]) -> None:
- import neptune
-
- self.__dict__ = state
- self._run_instance = neptune.init_run(**self._neptune_init_args)
+ from pytorch_lightning_enterprise.loggers.neptune import NeptuneLogger as EnterpriseNeptuneLogger
+
+ self.logger_impl = EnterpriseNeptuneLogger(
+ api_key=api_key,
+ project=project,
+ name=name,
+ run=run,
+ log_model_checkpoints=log_model_checkpoints,
+ prefix=prefix,
+ **neptune_run_kwargs,
+ )
@property
@rank_zero_experiment
@@ -378,72 +257,20 @@ def training_step(self, batch, batch_idx):
with NeptuneLogger.
"""
- return self.run
+ return self.logger_impl.experiment
@property
@rank_zero_experiment
def run(self) -> "Run":
- import neptune
-
- if not self._run_instance:
- self._run_instance = neptune.init_run(**self._neptune_init_args)
- self._retrieve_run_data()
- # make sure that we've log integration version for newly created
- self._run_instance[_INTEGRATION_VERSION_KEY] = pl.__version__
-
- return self._run_instance
+ return self.logger_impl.run
@override
@rank_zero_only
- @_catch_inactive
def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
- r"""Log hyperparameters to the run.
-
- Hyperparameters will be logged under the "/hyperparams" namespace.
-
- Note:
-
- You can also log parameters by directly using the logger instance:
- ``neptune_logger.experiment["model/hyper-parameters"] = params_dict``.
-
- In this way you can keep hierarchical structure of the parameters.
-
- Args:
- params: `dict`.
- Python dictionary structure with parameters.
-
- Example::
-
- from lightning.pytorch.loggers import NeptuneLogger
- import neptune
-
- PARAMS = {
- "batch_size": 64,
- "lr": 0.07,
- "decay_factor": 0.97,
- }
-
- neptune_logger = NeptuneLogger(
- api_key=neptune.ANONYMOUS_API_TOKEN,
- project="common/pytorch-lightning-integration"
- )
-
- neptune_logger.log_hyperparams(PARAMS)
-
- """
- from neptune.utils import stringify_unsupported
-
- params = _convert_params(params)
- params = _sanitize_callable_params(params)
-
- parameters_key = self.PARAMETERS_KEY
- parameters_key = self._construct_path_with_prefix(parameters_key)
-
- self.run[parameters_key] = stringify_unsupported(params)
+ return self.logger_impl.log_hyperparams(params)
@override
@rank_zero_only
- @_catch_inactive
def log_metrics(self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
"""Log metrics (numeric values) in Neptune runs.
@@ -452,26 +279,12 @@ def log_metrics(self, metrics: dict[str, Union[Tensor, float]], step: Optional[i
step: Step number at which the metrics should be recorded
"""
- if rank_zero_only.rank != 0:
- raise ValueError("run tried to log from global_rank != 0")
-
- metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
-
- for key, val in metrics.items():
- self.run[key].append(val, step=step)
+ return self.logger_impl.log_metrics(metrics, step)
@override
@rank_zero_only
- @_catch_inactive
def finalize(self, status: str) -> None:
- if not self._run_instance:
- # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been
- # initialized there
- return
- if status:
- self.run[self._construct_path_with_prefix("status")] = status
-
- super().finalize(status)
+ return self.logger_impl.finalize(status)
@property
@override
@@ -483,21 +296,14 @@ def save_dir(self) -> Optional[str]:
the root directory where experiment logs get saved
"""
- return os.path.join(os.getcwd(), ".neptune")
+ return self.logger_impl.save_dir
@rank_zero_only
- @_catch_inactive
def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None:
- from neptune.types import File
-
- model_str = str(ModelSummary(model=model, max_depth=max_depth))
- self.run[self._construct_path_with_prefix("model/summary")] = File.from_content(
- content=model_str, extension="txt"
- )
+ return self.logger_impl.log_model_summary(model, max_depth)
@override
@rank_zero_only
- @_catch_inactive
def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
@@ -505,83 +311,13 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
checkpoint_callback: the model checkpoint callback instance
"""
- if not self._log_model_checkpoints:
- return
-
- file_names = set()
- checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")
-
- # save last model
- if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
- model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
- file_names.add(model_last_name)
- self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
-
- # save best k models
- if hasattr(checkpoint_callback, "best_k_models"):
- for key in checkpoint_callback.best_k_models:
- model_name = self._get_full_model_name(key, checkpoint_callback)
- file_names.add(model_name)
- self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)
-
- # log best model path and checkpoint
- if hasattr(checkpoint_callback, "best_model_path") and checkpoint_callback.best_model_path:
- self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path
-
- model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
- file_names.add(model_name)
- self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
-
- # remove old models logged to experiment if they are not part of best k models at this point
- if self.run.exists(checkpoints_namespace):
- exp_structure = self.run.get_structure()
- uploaded_model_names = self._get_full_model_names_from_exp_structure(exp_structure, checkpoints_namespace)
-
- for file_to_drop in list(uploaded_model_names - file_names):
- del self.run[f"{checkpoints_namespace}/{file_to_drop}"]
-
- # log best model score
- if hasattr(checkpoint_callback, "best_model_score") and checkpoint_callback.best_model_score:
- self.run[self._construct_path_with_prefix("model/best_model_score")] = (
- checkpoint_callback.best_model_score.cpu().detach().numpy()
- )
-
- @staticmethod
- def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> str:
- """Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`."""
- if hasattr(checkpoint_callback, "dirpath"):
- model_path = os.path.normpath(model_path)
- expected_model_path = os.path.normpath(checkpoint_callback.dirpath)
- if not model_path.startswith(expected_model_path):
- raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
- # Remove extension from filepath
- filepath, _ = os.path.splitext(model_path[len(expected_model_path) + 1 :])
- return filepath.replace(os.sep, "/")
- return model_path.replace(os.sep, "/")
-
- @classmethod
- def _get_full_model_names_from_exp_structure(cls, exp_structure: dict[str, Any], namespace: str) -> set[str]:
- """Returns all paths to properties which were already logged in `namespace`"""
- structure_keys: list[str] = namespace.split(cls.LOGGER_JOIN_CHAR)
- for key in structure_keys:
- exp_structure = exp_structure[key]
- uploaded_models_dict = exp_structure
- return set(cls._dict_paths(uploaded_models_dict))
-
- @classmethod
- def _dict_paths(cls, d: dict[str, Any], path_in_build: Optional[str] = None) -> Generator:
- for k, v in d.items():
- path = f"{path_in_build}/{k}" if path_in_build is not None else k
- if not isinstance(v, dict):
- yield path
- else:
- yield from cls._dict_paths(v, path)
+ return self.logger_impl.after_save_checkpoint(checkpoint_callback)
@property
@override
def name(self) -> Optional[str]:
"""Return the experiment name or 'offline-name' when exp is run in offline mode."""
- return self._run_name
+ return self.logger_impl.name
@property
@override
@@ -591,4 +327,4 @@ def version(self) -> Optional[str]:
It's Neptune Run's short_id
"""
- return self._run_short_id
+ return self.logger_impl.version
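Neptune follows the same pattern. A minimal sketch, assuming neptune and the enterprise package are installed and credentials are supplied via the NEPTUNE_API_TOKEN environment variable; the project string is a placeholder:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import NeptuneLogger

# Checkpoint upload is disabled here to keep the example side-effect free.
logger = NeptuneLogger(project="workspace/placeholder", log_model_checkpoints=False)
trainer = Trainer(logger=logger, max_epochs=1)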
diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py
index 37ca362fa40c1..41d38b460d0f0 100644
--- a/src/lightning/pytorch/loggers/wandb.py
+++ b/src/lightning/pytorch/loggers/wandb.py
@@ -16,37 +16,24 @@
-------------------------
"""
-import os
from argparse import Namespace
from collections.abc import Mapping
-from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import torch.nn as nn
-from lightning_utilities.core.imports import RequirementCache
-from torch import Tensor
from typing_extensions import override
-from lightning.fabric.utilities.logger import (
- _add_prefix,
- _convert_json_serializable,
- _convert_params,
- _sanitize_callable_params,
-)
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
-from lightning.pytorch.loggers.utilities import _scan_checkpoints
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
+from lightning.pytorch.utilities.rank_zero import rank_zero_only
if TYPE_CHECKING:
from wandb import Artifact
from wandb.sdk.lib import RunDisabled
from wandb.wandb_run import Run
-_WANDB_AVAILABLE = RequirementCache("wandb>=0.12.10")
-
class WandbLogger(Logger):
r"""Log using `Weights and Biases `_.
@@ -308,70 +295,27 @@ def __init__(
add_file_policy: Literal["mutable", "immutable"] = "mutable",
**kwargs: Any,
) -> None:
- if not _WANDB_AVAILABLE:
- raise ModuleNotFoundError(str(_WANDB_AVAILABLE))
-
- if offline and log_model:
- raise MisconfigurationException(
- f"Providing log_model={log_model} and offline={offline} is an invalid configuration"
- " since model checkpoints cannot be uploaded in offline mode.\n"
- "Hint: Set `offline=False` to log your model."
- )
+ _raise_enterprise_not_available()
super().__init__()
- self._offline = offline
- self._log_model = log_model
- self._prefix = prefix
- self._experiment = experiment
- self._logged_model_time: dict[str, float] = {}
- self._checkpoint_callbacks: dict[int, ModelCheckpoint] = {}
- self.add_file_policy = add_file_policy
-
- # paths are processed as strings
- if save_dir is not None:
- save_dir = os.fspath(save_dir)
- elif dir is not None:
- dir = os.fspath(dir)
-
- project = project or os.environ.get("WANDB_PROJECT", "lightning_logs")
-
- # set wandb init arguments
- self._wandb_init: dict[str, Any] = {
- "name": name,
- "project": project,
- "dir": save_dir or dir,
- "id": version or id,
- "resume": "allow",
- "anonymous": ("allow" if anonymous else None),
- }
- self._wandb_init.update(**kwargs)
- # extract parameters
- self._project = self._wandb_init.get("project")
- self._save_dir = self._wandb_init.get("dir")
- self._name = self._wandb_init.get("name")
- self._id = self._wandb_init.get("id")
- self._checkpoint_name = checkpoint_name
-
- def __getstate__(self) -> dict[str, Any]:
- import wandb
-
- # Hack: If the 'spawn' launch method is used, the logger will get pickled and this `__getstate__` gets called.
- # We create an experiment here in the main process, and attach to it in the worker process.
- # Using wandb-service, we persist the same experiment even if multiple `Trainer.fit/test/validate` calls
- # are made.
- wandb.require("service")
- _ = self.experiment
-
- state = self.__dict__.copy()
- # args needed to reload correct experiment
- if self._experiment is not None:
- state["_id"] = getattr(self._experiment, "id", None)
- state["_attach_id"] = getattr(self._experiment, "_attach_id", None)
- state["_name"] = self._experiment.name
-
- # cannot be pickled
- state["_experiment"] = None
- return state
+ from pytorch_lightning_enterprise.loggers.wandb import WandbLogger as EnterpriseWandbLogger
+
+ self.logger_impl = EnterpriseWandbLogger(
+ name=name,
+ save_dir=save_dir,
+ version=version,
+ offline=offline,
+ dir=dir,
+ id=id,
+ anonymous=anonymous,
+ project=project,
+ log_model=log_model,
+ experiment=experiment,
+ prefix=prefix,
+ checkpoint_name=checkpoint_name,
+ add_file_policy=add_file_policy,
+ **kwargs,
+ )
@property
@rank_zero_experiment
@@ -386,40 +330,7 @@ def experiment(self) -> Union["Run", "RunDisabled"]:
self.logger.experiment.some_wandb_function()
"""
- import wandb
- from wandb.sdk.lib import RunDisabled
- from wandb.wandb_run import Run
-
- if self._experiment is None:
- if self._offline:
- os.environ["WANDB_MODE"] = "dryrun"
-
- attach_id = getattr(self, "_attach_id", None)
- if wandb.run is not None:
- # wandb process already created in this instance
- rank_zero_warn(
- "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
- " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
- )
- self._experiment = wandb.run
- elif attach_id is not None and hasattr(wandb, "_attach"):
- # attach to wandb process referenced
- self._experiment = wandb._attach(attach_id)
- else:
- # create new wandb process
- self._experiment = wandb.init(**self._wandb_init)
-
- # define default x-axis
- if isinstance(self._experiment, (Run, RunDisabled)) and getattr(
- self._experiment, "define_metric", None
- ):
- if self._wandb_init.get("sync_tensorboard"):
- self._experiment.define_metric("*", step_metric="global_step")
- else:
- self._experiment.define_metric("trainer/global_step")
- self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
-
- return self._experiment
+ return self.logger_impl.experiment
def watch(
self, model: nn.Module, log: Optional[str] = "gradients", log_freq: int = 100, log_graph: bool = True
@@ -429,21 +340,12 @@ def watch(
@override
@rank_zero_only
def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
- params = _convert_params(params)
- params = _sanitize_callable_params(params)
- params = _convert_json_serializable(params)
- self.experiment.config.update(params, allow_val_change=True)
+ return self.logger_impl.log_hyperparams(params)
@override
@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
- assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
-
- metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
- if step is not None and not self._wandb_init.get("sync_tensorboard"):
- self.experiment.log(dict(metrics, **{"trainer/global_step": step}))
- else:
- self.experiment.log(metrics)
+ return self.logger_impl.log_metrics(metrics, step)
@rank_zero_only
def log_table(
@@ -459,10 +361,7 @@ def log_table(
Can be defined either with `columns` and `data` or with `dataframe`.
"""
- import wandb
-
- metrics = {key: wandb.Table(columns=columns, data=data, dataframe=dataframe)}
- self.log_metrics(metrics, step)
+ return self.logger_impl.log_table(key, columns, data, dataframe, step)
@rank_zero_only
def log_text(
@@ -478,8 +377,7 @@ def log_text(
Can be defined either with `columns` and `data` or with `dataframe`.
"""
-
- self.log_table(key, columns, data, dataframe, step)
+ return self.logger_impl.log_text(key, columns, data, dataframe, step)
@rank_zero_only
def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **kwargs: Any) -> None:
@@ -488,18 +386,7 @@ def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **kwargs: Any) -> None:
Optional kwargs are lists passed to each image (ex: caption, masks, boxes).
"""
- if not isinstance(images, list):
- raise TypeError(f'Expected a list as "images", found {type(images)}')
- n = len(images)
- for k, v in kwargs.items():
- if len(v) != n:
- raise ValueError(f"Expected {n} items but only found {len(v)} for {k}")
- kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)]
-
- import wandb
-
- metrics = {key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]}
- self.log_metrics(metrics, step) # type: ignore[arg-type]
+ return self.logger_impl.log_image(key, images, step, **kwargs)
@rank_zero_only
def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **kwargs: Any) -> None:
@@ -514,18 +401,7 @@ def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **kwargs: Any) -> None:
Optional kwargs are lists passed to each audio (ex: caption, sample_rate).
"""
- if not isinstance(audios, list):
- raise TypeError(f'Expected a list as "audios", found {type(audios)}')
- n = len(audios)
- for k, v in kwargs.items():
- if len(v) != n:
- raise ValueError(f"Expected {n} items but only found {len(v)} for {k}")
- kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)]
-
- import wandb
-
- metrics = {key: [wandb.Audio(audio, **kwarg) for audio, kwarg in zip(audios, kwarg_list)]}
- self.log_metrics(metrics, step) # type: ignore[arg-type]
+ return self.logger_impl.log_audio(key, audios, step, **kwargs)
@rank_zero_only
def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **kwargs: Any) -> None:
@@ -540,18 +416,7 @@ def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **kwargs: Any) -> None:
Optional kwargs are lists passed to each video (ex: caption, fps, format).
"""
- if not isinstance(videos, list):
- raise TypeError(f'Expected a list as "videos", found {type(videos)}')
- n = len(videos)
- for k, v in kwargs.items():
- if len(v) != n:
- raise ValueError(f"Expected {n} items but only found {len(v)} for {k}")
- kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)]
-
- import wandb
-
- metrics = {key: [wandb.Video(video, **kwarg) for video, kwarg in zip(videos, kwarg_list)]}
- self.log_metrics(metrics, step) # type: ignore[arg-type]
+ return self.logger_impl.log_video(key, videos, step, **kwargs)
@property
@override
@@ -562,7 +427,7 @@ def save_dir(self) -> Optional[str]:
The path to the save directory.
"""
- return self._save_dir
+ return self.logger_impl.save_dir
@property
@override
@@ -574,7 +439,7 @@ def name(self) -> Optional[str]:
name. To access wandb's internal experiment name, use ``logger.experiment.name`` instead.
"""
- return self._project
+ return self.logger_impl.name
@property
@override
@@ -586,15 +451,12 @@ def version(self) -> Optional[str]:
"""
# don't create an experiment if we don't have one
- return self._experiment.id if self._experiment else self._id
+ return self.logger_impl.version
@override
def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
# log checkpoints as artifacts
- if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
- self._scan_and_log_checkpoints(checkpoint_callback)
- elif self._log_model is True:
- self._checkpoint_callbacks[id(checkpoint_callback)] = checkpoint_callback
+ return self.logger_impl.after_save_checkpoint(checkpoint_callback)
@staticmethod
@rank_zero_only
@@ -616,16 +478,10 @@ def download_artifact(
The path to the downloaded artifact.
"""
- import wandb
-
- if wandb.run is not None and use_artifact:
- artifact = wandb.run.use_artifact(artifact)
- else:
- api = wandb.Api()
- artifact = api.artifact(artifact, type=artifact_type)
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.loggers.wandb import WandbLogger as EnterpriseWandbLogger
- save_dir = None if save_dir is None else os.fspath(save_dir)
- return artifact.download(root=save_dir)
+ return EnterpriseWandbLogger.download_artifact(artifact, save_dir, artifact_type, use_artifact)
def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "Artifact":
"""Logs to the wandb dashboard that the mentioned artifact is used by the run.
@@ -638,49 +494,9 @@ def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "Artifact":
wandb Artifact object for the artifact.
"""
- return self.experiment.use_artifact(artifact, type=artifact_type)
+ return self.logger_impl.use_artifact(artifact, artifact_type)
@override
@rank_zero_only
def finalize(self, status: str) -> None:
- if status != "success":
- # Currently, checkpoints only get logged on success
- return
- # log checkpoints as artifacts
- if self._experiment is not None:
- for checkpoint_callback in self._checkpoint_callbacks.values():
- self._scan_and_log_checkpoints(checkpoint_callback)
-
- def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
- import wandb
-
- # get checkpoints to be saved with associated score
- checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)
-
- # log iteratively all new checkpoints
- for t, p, s, tag in checkpoints:
- metadata = {
- "score": s.item() if isinstance(s, Tensor) else s,
- "original_filename": Path(p).name,
- checkpoint_callback.__class__.__name__: {
- k: getattr(checkpoint_callback, k)
- for k in [
- "monitor",
- "mode",
- "save_last",
- "save_top_k",
- "save_weights_only",
- "_every_n_train_steps",
- ]
- # ensure it does not break if `ModelCheckpoint` args change
- if hasattr(checkpoint_callback, k)
- },
- }
- if not self._checkpoint_name:
- self._checkpoint_name = f"model-{self.experiment.id}"
- artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
- artifact.add_file(p, name="model.ckpt", policy=self.add_file_policy)
- aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
- self.experiment.log_artifact(artifact, aliases=aliases)
- # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
- self._logged_model_time[p] = t
+ return self.logger_impl.finalize(status)
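All of the logger diffs above follow the same facade pattern: the open-source class keeps its public signature, builds an enterprise implementation in `__init__`, and forwards every method and property to it. Below is a minimal, self-contained sketch of that pattern; the class and attribute names are illustrative stand-ins, not the actual enterprise API.

```python
from typing import Any, Optional


class _EnterpriseLoggerImpl:
    # Stand-in for the real implementation shipped in pytorch_lightning_enterprise.
    def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None:
        self._name = name

    def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None:
        print(f"[{self._name}] step={step}: {metrics}")

    @property
    def name(self) -> Optional[str]:
        return self._name


class PublicLogger:
    # Open-source facade: unchanged constructor signature, all work delegated.
    def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None:
        self.logger_impl = _EnterpriseLoggerImpl(name=name, **kwargs)

    def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None:
        return self.logger_impl.log_metrics(metrics, step)

    @property
    def name(self) -> Optional[str]:
        return self.logger_impl.name


if __name__ == "__main__":
    logger = PublicLogger(name="demo")
    logger.log_metrics({"loss": 0.25}, step=10)
```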
diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py
index 9225e3bb9e7be..98580ee61e112 100644
--- a/src/lightning/pytorch/plugins/precision/deepspeed.py
+++ b/src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -11,30 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from contextlib import AbstractContextManager, nullcontext
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from contextlib import AbstractContextManager
+from typing import Any, Callable, Optional, Union
import torch
-from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import Module
-from torch.optim import LBFGS, Optimizer
-from typing_extensions import get_args, override
+from torch.optim import Optimizer
+from typing_extensions import override
import lightning.pytorch as pl
from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT
-from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import Steppable
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities import GradClipAlgorithmType
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import WarningCache
-
-if TYPE_CHECKING:
- import deepspeed
-
-warning_cache = WarningCache()
class DeepSpeedPrecision(Precision):
@@ -53,41 +44,29 @@ class DeepSpeedPrecision(Precision):
"""
def __init__(self, precision: _PRECISION_INPUT) -> None:
- supported_precision = get_args(_PRECISION_INPUT)
- if precision not in supported_precision:
- raise ValueError(
- f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported."
- f" `precision` must be one of: {supported_precision}."
- )
- self.precision = precision
- precision_to_type = {
- "bf16-mixed": torch.bfloat16,
- "16-mixed": torch.float16,
- "bf16-true": torch.bfloat16,
- "16-true": torch.float16,
- "32-true": torch.float32,
- }
- self._desired_dtype = precision_to_type[self.precision]
+ super().__init__()
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.precision.deepspeed import (
+ DeepSpeedPrecisionTrainer as EnterpriseDeepSpeedPrecision,
+ )
+
+ self.deepspeed_precision_impl = EnterpriseDeepSpeedPrecision(outer_object=self, precision=precision)
@override
def convert_module(self, module: Module) -> Module:
- if "true" in self.precision:
- return module.to(dtype=self._desired_dtype)
- return module
+ return self.deepspeed_precision_impl.convert_module(module=module)
@override
def convert_input(self, data: Any) -> Any:
- return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
+ return self.deepspeed_precision_impl.convert_input(data=data)
@override
def tensor_init_context(self) -> AbstractContextManager:
- if "true" not in self.precision:
- return nullcontext()
- return _DtypeContextManager(self._desired_dtype)
+ return self.deepspeed_precision_impl.tensor_init_context()
@override
def module_init_context(self) -> AbstractContextManager:
- return self.tensor_init_context()
+ return self.deepspeed_precision_impl.module_init_context()
@override
def backward( # type: ignore[override]
@@ -98,7 +77,7 @@ def backward( # type: ignore[override]
*args: Any,
**kwargs: Any,
) -> None:
- r"""Performs back-propagation using DeepSpeed's engine.
+ r"""Performs back-propagation.
Args:
tensor: the loss tensor
@@ -108,13 +87,7 @@ def backward( # type: ignore[override]
\**kwargs: additional keyword arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call
"""
- if is_overridden("backward", model):
- warning_cache.warn(
- "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
- " the backward logic internally."
- )
- deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model
- deepspeed_engine.backward(tensor, *args, **kwargs)
+ return self.deepspeed_precision_impl.backward(tensor=tensor, model=model, optimizer=optimizer, *args, **kwargs)
@override
def optimizer_step( # type: ignore[override]
@@ -124,19 +97,7 @@ def optimizer_step( # type: ignore[override]
closure: Callable[[], Any],
**kwargs: Any,
) -> Any:
- if isinstance(optimizer, LBFGS):
- raise MisconfigurationException("DeepSpeed and the LBFGS optimizer are not compatible.")
- closure_result = closure()
- self._after_closure(model, optimizer)
- skipped_backward = closure_result is None
- # in manual optimization, the closure does not return a value
- if model.automatic_optimization and skipped_backward:
- raise MisconfigurationException(
- "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
- )
- # DeepSpeed handles the optimizer step internally
- deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model
- return deepspeed_engine.step(**kwargs)
+ return self.deepspeed_precision_impl.optimizer_step(optimizer=optimizer, model=model, closure=closure, **kwargs)
@override
def clip_gradients(
@@ -145,4 +106,22 @@ def clip_gradients(
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
- """DeepSpeed handles gradient clipping internally."""
+ return self.deepspeed_precision_impl.clip_gradients(
+ optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+ )
+
+ @property
+ def precision(self) -> str:
+ return self.deepspeed_precision_impl.precision
+
+ @precision.setter
+ def precision(self, value: str) -> None:
+ self.deepspeed_precision_impl.precision = value
+
+ @property
+ def _desired_dtype(self) -> torch.dtype:
+ return self.deepspeed_precision_impl._desired_dtype
+
+ @_desired_dtype.setter
+ def _desired_dtype(self, value: torch.dtype) -> None:
+ self.deepspeed_precision_impl._desired_dtype = value
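The `precision` and `_desired_dtype` properties appended at the end of this file exist so that code which reads or mutates these attributes on the plugin keeps working: reads and writes pass straight through to the wrapped enterprise object. A small sketch of that write-through idiom, with purely illustrative names:

```python
class _Impl:
    # Stand-in for the enterprise precision plugin.
    def __init__(self, precision: str) -> None:
        self.precision = precision


class Facade:
    def __init__(self, precision: str) -> None:
        self._impl = _Impl(precision)

    @property
    def precision(self) -> str:
        return self._impl.precision

    @precision.setter
    def precision(self, value: str) -> None:
        # Writing through the facade keeps the wrapped object in sync.
        self._impl.precision = value


plugin = Facade("16-mixed")
plugin.precision = "bf16-mixed"
assert plugin._impl.precision == "bf16-mixed"
```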
diff --git a/src/lightning/pytorch/plugins/precision/xla.py b/src/lightning/pytorch/plugins/precision/xla.py
index 6890cc4c1d825..5f84399401fe3 100644
--- a/src/lightning/pytorch/plugins/precision/xla.py
+++ b/src/lightning/pytorch/plugins/precision/xla.py
@@ -11,19 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-from functools import partial
from typing import Any, Callable
import torch
-from typing_extensions import get_args, override
+from typing_extensions import override
import lightning.pytorch as pl
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins.precision.xla import _PRECISION_INPUT
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision import Precision
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
class XLAPrecision(Precision):
@@ -39,25 +36,12 @@ class XLAPrecision(Precision):
"""
def __init__(self, precision: _PRECISION_INPUT = "32-true") -> None:
- if not _XLA_AVAILABLE:
- raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-
- supported_precision = get_args(_PRECISION_INPUT)
- if precision not in supported_precision:
- raise ValueError(
- f"`precision={precision!r})` is not supported in XLA."
- f" `precision` must be one of: {supported_precision}."
- )
- self.precision = precision
-
- if precision == "16-true":
- os.environ["XLA_USE_F16"] = "1"
- self._desired_dtype = torch.float16
- elif precision == "bf16-true":
- os.environ["XLA_USE_BF16"] = "1"
- self._desired_dtype = torch.bfloat16
- else:
- self._desired_dtype = torch.float32
+ super().__init__()
+
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.plugins.precision.xla import XLAPrecision as EnterpriseXLAPrecision
+
+ self.xla_impl = EnterpriseXLAPrecision(precision)
@override
def optimizer_step( # type: ignore[override]
@@ -67,31 +51,24 @@ def optimizer_step( # type: ignore[override]
closure: Callable[[], Any],
**kwargs: Any,
) -> Any:
- import torch_xla.core.xla_model as xm
-
- closure = partial(self._xla_wrap_closure, optimizer, closure)
- closure = partial(self._wrap_closure, model, optimizer, closure)
- closure_result = optimizer.step(closure=closure, **kwargs)
- xm.mark_step()
- skipped_backward = closure_result is None
- # in manual optimization, the closure does not return a value
- if model.automatic_optimization and skipped_backward:
- # we lack coverage here so disable this - something to explore if there's demand
- raise MisconfigurationException(
- "Skipping backward by returning `None` from your `training_step` is not implemented with XLA."
- " Please, open an issue in `https://github.com/Lightning-AI/pytorch-lightning/issues`"
- " requesting this feature."
- )
- return closure_result
+ return self.xla_impl.optimizer_step(optimizer, model, closure, **kwargs)
- @override
- def teardown(self) -> None:
- os.environ.pop("XLA_USE_BF16", None)
- os.environ.pop("XLA_USE_F16", None)
+ @property
+ def precision(self) -> _PRECISION_INPUT:
+ return self.xla_impl.precision
- def _xla_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
- import torch_xla.core.xla_model as xm
+ @precision.setter
+ def precision(self, precision: _PRECISION_INPUT) -> None:
+ self.xla_impl.precision = precision
- closure_result = closure()
- xm.reduce_gradients(optimizer)
- return closure_result
+ @property
+ def _desired_dtype(self) -> torch.dtype:
+ return self.xla_impl._desired_dtype
+
+ @_desired_dtype.setter
+ def _desired_dtype(self, dtype: torch.dtype) -> None:
+ self.xla_impl._desired_dtype = dtype
+
+ @override
+ def teardown(self) -> None:
+ return self.xla_impl.teardown()
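Every rewritten constructor in this PR begins with `_raise_enterprise_not_available()`. Its implementation is not part of this diff; the sketch below only illustrates the kind of guard those call sites rely on and is an assumption, not the actual helper in `lightning.fabric.utilities.imports`.

```python
from lightning_utilities.core.imports import RequirementCache

# Assumed: the guard checks whether the enterprise distribution is importable.
_ENTERPRISE_AVAILABLE = RequirementCache("pytorch-lightning-enterprise")


def _raise_enterprise_not_available() -> None:
    if not bool(_ENTERPRISE_AVAILABLE):
        raise ModuleNotFoundError(
            "This feature requires the `pytorch-lightning-enterprise` package."
            " Install it with `pip install pytorch-lightning-enterprise`."
        )
```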
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 1fce7b06887cd..856760815b698 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -11,52 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import argparse
-import json
+
import logging
-import os
-import platform
from collections import OrderedDict
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from datetime import timedelta
-from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
from torch.nn import Module
from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau
from typing_extensions import override
import lightning.pytorch as pl
from lightning.fabric.plugins import ClusterEnvironment
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.strategies import _StrategyRegistry
-from lightning.fabric.strategies.deepspeed import (
- _DEEPSPEED_AVAILABLE,
- _DEEPSPEED_GREATER_EQUAL_0_16,
- _format_precision_config,
- _validate_checkpoint_directory,
- _validate_device_index_selection,
-)
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
-from lightning.fabric.utilities.optimizer import _optimizers_to_device
-from lightning.fabric.utilities.seed import reset_seed
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import _PATH
-from lightning.pytorch.accelerators.cuda import CUDAAccelerator
-from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers
from lightning.pytorch.plugins.precision import Precision
from lightning.pytorch.strategies.ddp import DDPStrategy
-from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities import GradClipAlgorithmType
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn
-from lightning.pytorch.utilities.types import LRSchedulerConfig
-
-log = logging.getLogger(__name__)
-warning_cache = WarningCache()
if TYPE_CHECKING:
import deepspeed
@@ -258,25 +233,6 @@ def __init__(
exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints.
"""
- if not _DEEPSPEED_AVAILABLE:
- raise MisconfigurationException(
- "To use the `DeepSpeedStrategy`, you must have DeepSpeed installed."
- " Install it by running `pip install -U deepspeed`."
- )
-
- if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
- # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
- # DeepSpeed added support for this behavior in version 0.16.0.
- import deepspeed
-
- deepspeed_version = deepspeed.__version__
-
- raise ImportError(
- f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
- f"Detected DeepSpeed version: {deepspeed_version}. "
- "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
- )
-
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
@@ -284,124 +240,77 @@ def __init__(
precision_plugin=precision_plugin,
process_group_backend=process_group_backend,
)
- self._timeout: Optional[timedelta] = timeout
-
- self.config = self._load_config(config)
- if self.config is None:
- # User has not overridden config, set defaults
- self.config = self._create_default_config(
- zero_optimization,
- zero_allow_untested_optimizer,
- logging_batch_size_per_gpu,
- offload_optimizer=offload_optimizer,
- offload_parameters=offload_parameters,
- nvme_path=nvme_path,
- offload_params_device=offload_params_device,
- params_buffer_count=params_buffer_count,
- params_buffer_size=params_buffer_size,
- max_in_cpu=max_in_cpu,
- pin_memory=pin_memory,
- offload_optimizer_device=offload_optimizer_device,
- optimizer_buffer_count=optimizer_buffer_count,
- block_size=block_size,
- queue_depth=queue_depth,
- single_submit=single_submit,
- overlap_events=overlap_events,
- thread_count=thread_count,
- partition_activations=partition_activations,
- cpu_checkpointing=cpu_checkpointing,
- contiguous_memory_optimization=contiguous_memory_optimization,
- synchronize_checkpoint_boundary=synchronize_checkpoint_boundary,
- stage=stage,
- contiguous_gradients=contiguous_gradients,
- overlap_comm=overlap_comm,
- allgather_partitions=allgather_partitions,
- reduce_scatter=reduce_scatter,
- allgather_bucket_size=allgather_bucket_size,
- reduce_bucket_size=reduce_bucket_size,
- sub_group_size=sub_group_size,
- )
- import deepspeed
-
- self._config_initialized = False
- deepspeed.utils.logging.logger.setLevel(logging_level)
-
- self.remote_device = remote_device
- self.load_full_weights = load_full_weights
- self.exclude_frozen_parameters = exclude_frozen_parameters
-
- # default FP16 parameters.
- self.loss_scale = loss_scale
- self.initial_scale_power = initial_scale_power
- self.loss_scale_window = loss_scale_window
- self.hysteresis = hysteresis
- self.min_loss_scale = min_loss_scale
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.strategies.deepspeed import (
+ DeepSpeedStrategyTrainer as EnterpriseDeepSpeedStrategy,
+ )
+
+ self.deepspeed_strategy_impl = EnterpriseDeepSpeedStrategy(
+ outer_object=self,
+ accelerator=accelerator,
+ zero_optimization=zero_optimization,
+ stage=stage,
+ remote_device=remote_device,
+ offload_optimizer=offload_optimizer,
+ offload_parameters=offload_parameters,
+ offload_params_device=offload_params_device,
+ nvme_path=nvme_path,
+ params_buffer_count=params_buffer_count,
+ params_buffer_size=params_buffer_size,
+ max_in_cpu=max_in_cpu,
+ offload_optimizer_device=offload_optimizer_device,
+ optimizer_buffer_count=optimizer_buffer_count,
+ block_size=block_size,
+ queue_depth=queue_depth,
+ single_submit=single_submit,
+ overlap_events=overlap_events,
+ thread_count=thread_count,
+ pin_memory=pin_memory,
+ sub_group_size=sub_group_size,
+ contiguous_gradients=contiguous_gradients,
+ overlap_comm=overlap_comm,
+ allgather_partitions=allgather_partitions,
+ reduce_scatter=reduce_scatter,
+ allgather_bucket_size=allgather_bucket_size,
+ reduce_bucket_size=reduce_bucket_size,
+ zero_allow_untested_optimizer=zero_allow_untested_optimizer,
+ logging_batch_size_per_gpu=logging_batch_size_per_gpu,
+ config=config,
+ logging_level=logging_level,
+ parallel_devices=parallel_devices,
+ cluster_environment=cluster_environment,
+ loss_scale=loss_scale,
+ initial_scale_power=initial_scale_power,
+ loss_scale_window=loss_scale_window,
+ hysteresis=hysteresis,
+ min_loss_scale=min_loss_scale,
+ partition_activations=partition_activations,
+ cpu_checkpointing=cpu_checkpointing,
+ contiguous_memory_optimization=contiguous_memory_optimization,
+ synchronize_checkpoint_boundary=synchronize_checkpoint_boundary,
+ load_full_weights=load_full_weights,
+ precision_plugin=precision_plugin,
+ process_group_backend=process_group_backend,
+ timeout=timeout,
+ exclude_frozen_parameters=exclude_frozen_parameters,
+ )
@override
def setup_environment(self) -> None:
- if not isinstance(self.accelerator, CUDAAccelerator):
- raise RuntimeError(
- f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
- " is used."
- )
- super().setup_environment()
+ return self.deepspeed_strategy_impl.setup_environment()
@override
def setup_distributed(self) -> None:
- assert self.parallel_devices is not None
- _validate_device_index_selection(self.parallel_devices)
- reset_seed()
- self.set_world_ranks()
- self._init_deepspeed_distributed()
+ return self.deepspeed_strategy_impl.setup_distributed()
@override
def setup(self, trainer: "pl.Trainer") -> None:
- self._init_config_if_needed()
- assert self.accelerator is not None
- self.accelerator.setup(trainer)
-
- assert self.model is not None
- self.model = self.precision_plugin.convert_module(self.model)
- self.model = self._setup_model(self.model)
-
- if trainer.state.fn == TrainerFn.FITTING:
- self.setup_optimizers(trainer)
- self.setup_precision_plugin()
- if trainer.state.fn == TrainerFn.FITTING:
- _optimizers_to_device(self.optimizers, self.root_device)
-
- self.init_deepspeed()
- self.barrier()
-
- def _init_deepspeed_distributed(self) -> None:
- import deepspeed
-
- assert self.cluster_environment is not None
- if platform.system() != "Windows":
- # do not set env variables on windows, allow deepspeed to control setup
- self._set_node_environment_variables()
- log.info(
- "initializing deepspeed distributed: "
- f"GLOBAL_RANK: {self.global_rank}, "
- f"MEMBER: {self.global_rank + 1}/{self.world_size}"
- )
- self._process_group_backend = self._get_process_group_backend()
- deepspeed.init_distributed(
- self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
- )
-
- def _set_node_environment_variables(self) -> None:
- assert self.cluster_environment is not None
- os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
- os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
- os.environ["RANK"] = str(self.global_rank)
- os.environ["WORLD_SIZE"] = str(self.world_size)
- os.environ["LOCAL_RANK"] = str(self.local_rank)
+ return self.deepspeed_strategy_impl.setup(trainer=trainer)
@property
@override
def restore_checkpoint_after_setup(self) -> bool:
- return True
+ return self.deepspeed_strategy_impl.restore_checkpoint_after_setup
@override
def _setup_model_and_optimizers(
@@ -416,188 +325,28 @@ def _setup_model_and_optimizers(
deepspeed optimizer.
"""
- if len(optimizers) != 1:
- raise ValueError(
- f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead."
- )
-
- # train_micro_batch_size_per_gpu is used for throughput logging purposes
- # normally we set this to the batch size, but it is not available here unless the user provides it
- # as part of the config
- assert self.config is not None
- self.config.setdefault("train_micro_batch_size_per_gpu", 1)
- self.model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
- self._set_deepspeed_activation_checkpointing()
- return self.model, [optimizer]
-
- def _setup_model_and_optimizer(
- self,
- model: Module,
- optimizer: Optional[Optimizer],
- lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None,
- ) -> tuple["deepspeed.DeepSpeedEngine", Optimizer]:
- """Initialize one model and one optimizer with an optional learning rate scheduler.
-
- This calls ``deepspeed.initialize`` internally.
-
- """
- import deepspeed
-
- model_parameters = filter(lambda p: p.requires_grad, model.parameters())
- deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
- args=argparse.Namespace(device_rank=self.root_device.index),
- config=self.config,
- model=model,
- model_parameters=model_parameters,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- dist_init_required=False,
- )
- return deepspeed_engine, deepspeed_optimizer
-
- def init_deepspeed(self) -> None:
- assert self.lightning_module is not None
- # deepspeed handles gradient clipping internally
- if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
- rank_zero_warn(
- "Since DeepSpeed handles gradient clipping internally, the default"
- " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients."
- " The hook will still be called. Consider setting"
- " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
- " which will use the internal mechanism."
- )
-
- if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
- raise MisconfigurationException("DeepSpeed does not support clipping gradients by value.")
-
- assert isinstance(self.model, pl.LightningModule)
- if self.lightning_module.trainer and self.lightning_module.trainer.training:
- self._initialize_deepspeed_train(self.model)
- else:
- self._initialize_deepspeed_inference(self.model)
-
- def _init_optimizers(self) -> tuple[Optimizer, Optional[LRSchedulerConfig]]:
- assert self.lightning_module is not None
- optimizers, lr_schedulers = _init_optimizers_and_lr_schedulers(self.lightning_module)
- if len(optimizers) > 1 or len(lr_schedulers) > 1:
- raise MisconfigurationException(
- "DeepSpeed currently only supports single optimizer, single optional scheduler."
- )
- return optimizers[0], lr_schedulers[0] if lr_schedulers else None
+ return self.deepspeed_strategy_impl._setup_model_and_optimizers(model=model, optimizers=optimizers)
@property
def zero_stage_3(self) -> bool:
- assert isinstance(self.config, dict)
- zero_optimization = self.config.get("zero_optimization")
- return zero_optimization is not None and zero_optimization.get("stage") == 3
-
- def _initialize_deepspeed_train(self, model: Module) -> None:
- optimizer, scheduler = None, None
- assert isinstance(self.config, dict)
- if "optimizer" in self.config:
- rank_zero_info(
- "You have specified an optimizer and/or scheduler within the DeepSpeed config."
- " It is recommended to define it in `LightningModule.configure_optimizers`."
- )
- lr_scheduler = None
- else:
- (
- optimizer,
- lr_scheduler,
- ) = self._init_optimizers()
- if lr_scheduler is not None:
- scheduler = lr_scheduler.scheduler
-
- model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
- self._set_deepspeed_activation_checkpointing()
-
- # although we set these here, deepspeed manages the specific optimizer logic
- self.optimizers = [deepspeed_optimizer]
-
- deepspeed_scheduler = model.lr_scheduler
- if deepspeed_scheduler is not None:
- # disable deepspeed lr scheduling as lightning manages scheduling
- model.lr_scheduler = None
- if lr_scheduler is None:
- lr_scheduler = LRSchedulerConfig(deepspeed_scheduler, interval="step")
- else:
- lr_scheduler.scheduler = deepspeed_scheduler
- self.lr_scheduler_configs = [lr_scheduler]
- self.model = model
+ return self.deepspeed_strategy_impl.zero_stage_3
@contextmanager
@override
def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
- if self.zero_stage_3:
- if empty_init is False:
- raise NotImplementedError(
- f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled."
- )
- yield
- return
- with super().tensor_init_context(empty_init=empty_init):
+ with self.deepspeed_strategy_impl.tensor_init_context(empty_init=empty_init):
yield
@contextmanager
@override
def model_sharded_context(self) -> Generator[None, None, None]:
- import deepspeed
-
- self._init_config_if_needed()
- with deepspeed.zero.Init(
- enabled=self.zero_stage_3,
- remote_device=self.remote_device,
- config_dict_or_path=self.config,
- ):
+ with self.deepspeed_strategy_impl.model_sharded_context():
yield
- def _set_deepspeed_activation_checkpointing(self) -> None:
- import deepspeed
-
- assert isinstance(self.config, dict)
- if self.config.get("activation_checkpointing"):
- checkpoint_config = self.config["activation_checkpointing"]
- deepspeed.checkpointing.configure(
- mpu_=None,
- partition_activations=checkpoint_config.get("partition_activations"),
- contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
- checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
- profile=checkpoint_config.get("profile"),
- )
-
- def _initialize_deepspeed_inference(self, model: Module) -> None:
- import deepspeed
-
- assert isinstance(self.config, dict)
-
- # todo: this is required for DeepSpeed throughput timers
- inference_config = {"train_micro_batch_size_per_gpu": 1}
- if "fp16" in self.config:
- inference_config.update({"fp16": self.config["fp16"]})
- if "bf16" in self.config:
- inference_config.update({"bf16": self.config["bf16"]})
- if self.zero_stage_3:
- inference_config.update({
- "zero_allow_untested_optimizer": self.config["zero_allow_untested_optimizer"],
- "zero_optimization": self.config["zero_optimization"],
- })
- # Remove all module hooks before initializing new model
- remove_module_hooks(model)
- model, _, _, _ = deepspeed.initialize(
- args=argparse.Namespace(device_rank=self.root_device.index),
- config=inference_config,
- model=model,
- optimizer=None,
- lr_scheduler=None,
- model_parameters=[],
- dist_init_required=False,
- )
- self.model = model
-
@property
@override
def distributed_sampler_kwargs(self) -> dict[str, int]:
- return {"num_replicas": self.world_size, "rank": self.global_rank}
+ return self.deepspeed_strategy_impl.distributed_sampler_kwargs()
@override
def setup_optimizers(self, trainer: "pl.Trainer") -> None:
@@ -607,29 +356,21 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
trainer: the Trainer, these optimizers should be connected to
"""
- # Skip initializing optimizers here as DeepSpeed handles optimizers via config.
- # User may have specified config options instead in configure_optimizers, but this is handled
- # via `_initialize_deepspeed_train`
- # empty optimizers, schedulers
- self.optimizers = []
- self.lr_scheduler_configs = []
-
- def _setup_model(self, model: Module) -> Module: # type: ignore[override]
- return model
+ return self.deepspeed_strategy_impl.setup_optimizers(trainer=trainer)
@property
@override
def handles_gradient_accumulation(self) -> bool:
"""Whether the strategy handles gradient accumulation internally."""
- return True
+ return self.deepspeed_strategy_impl.handles_gradient_accumulation
@property
def deepspeed_engine(self) -> "deepspeed.DeepSpeedEngine":
- return self.model
+ return self.deepspeed_strategy_impl.deepspeed_engine
@property
def _multi_device(self) -> bool:
- return self.num_processes > 1 or self.num_nodes > 1
+ return self.deepspeed_strategy_impl._multi_device
@override
def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None:
@@ -645,135 +386,27 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None:
If ``storage_options`` arg is passed in
"""
- # broadcast the filepath from rank 0 to ensure all the states are saved in a common filepath
- filepath = self.broadcast(filepath)
-
- if storage_options is not None:
- raise TypeError(
- "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
- f" is not supported for `{self.__class__.__name__}` as `CheckpointIO` is not used."
- )
-
- if self.zero_stage_3 and self._multi_device and self.is_global_zero:
- warning_cache.warn(
- "When saving the DeepSpeed Stage 3 checkpoint, "
- "each worker will save a shard of the checkpoint within a directory. "
- "If a single file is required after training, "
- "see https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#"
- "deepspeed-zero-stage-3-single-file for instructions."
- )
- # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
- # dump states as a checkpoint dictionary object
- _exclude_keys = ["state_dict", "optimizer_states"]
- checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
- self.deepspeed_engine.save_checkpoint(
- filepath,
- client_state=checkpoint,
- tag="checkpoint",
- exclude_frozen_parameters=self.exclude_frozen_parameters,
+ return self.deepspeed_strategy_impl.save_checkpoint(
+ checkpoint=checkpoint, filepath=filepath, storage_options=storage_options
)
@override
def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
- if self.load_full_weights and self.zero_stage_3:
- # Broadcast to ensure we load from the rank 0 checkpoint
- # This doesn't have to be the case when using deepspeed sharded checkpointing
- checkpoint_path = self.broadcast(checkpoint_path)
- return super().load_checkpoint(checkpoint_path, weights_only)
-
- _validate_checkpoint_directory(checkpoint_path)
-
- # Rely on deepspeed to load the checkpoint and necessary information
- assert self.lightning_module is not None
-
- from lightning.pytorch.trainer.states import TrainerFn
-
- is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
-
- _, client_state = self.deepspeed_engine.load_checkpoint(
- checkpoint_path,
- load_optimizer_states=is_fitting,
- load_lr_scheduler_states=False,
- load_module_strict=self.lightning_module.strict_loading,
- )
- if client_state is None:
- raise MisconfigurationException(
- "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint "
- "or a single checkpoint file with `Trainer(strategy=DeepSpeedStrategy(load_full_weights=True))`."
- )
- return client_state
+ return self.deepspeed_strategy_impl.load_checkpoint(checkpoint_path=checkpoint_path, weights_only=weights_only)
@property
@override
def lightning_restore_optimizer(self) -> bool:
- assert self.lightning_module is not None
- # managed by DeepSpeed
- if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
- rank_zero_warn(
- "A single checkpoint file has been given. This means optimizer states cannot be restored."
- " If you'd like to restore these states, you must provide a path to the originally saved DeepSpeed"
- " checkpoint. When using ZeRO 3, the original path should be a directory."
- )
- return False
+ return self.deepspeed_strategy_impl.lightning_restore_optimizer
@override
def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
- if self.load_full_weights and self.zero_stage_3:
- self.model_to_device()
- self._restore_zero_state(checkpoint, strict=strict)
-
- def _restore_zero_state(self, ckpt: Mapping[str, Any], strict: bool) -> None:
- """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be sharded
- across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced
- across processes.
-
- Args:
- ckpt: The ckpt file.
-
- """
- import deepspeed
-
- assert self.lightning_module is not None
-
- def load(module: torch.nn.Module, prefix: str = "") -> None:
- missing_keys: list[str] = []
- unexpected_keys: list[str] = []
- error_msgs: list[str] = []
- state_dict = ckpt["state_dict"]
-
- # copy state_dict so _load_from_state_dict can modify it
- metadata = getattr(state_dict, "_metadata", None)
- state_dict = state_dict.copy()
- if metadata is not None:
- state_dict._metadata = metadata
-
- local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
- # because zero3 puts placeholders in model params, this context
- # manager gathers (unpartitions) the params of the current layer, then loads from
- # the state dict and then re-partitions them again
- with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
- if self.is_global_zero:
- module._load_from_state_dict(
- state_dict=state_dict,
- prefix=prefix,
- local_metadata=local_metadata,
- strict=strict,
- missing_keys=missing_keys,
- unexpected_keys=unexpected_keys,
- error_msgs=error_msgs,
- )
-
- for name, child in module._modules.items():
- if child is not None:
- load(child, prefix + name + ".")
-
- load(self.lightning_module, prefix="")
+ return self.deepspeed_strategy_impl.load_model_state_dict(checkpoint=checkpoint, strict=strict)
@override
def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
- # Override to do nothing, the deepspeed engine already loaded the states in `load_checkpoint()`
- pass
+ return self.deepspeed_strategy_impl.load_optimizer_state_dict(checkpoint=checkpoint)
@classmethod
@override
@@ -809,137 +442,6 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
offload_optimizer_device="nvme",
)
- def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]:
- if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
- rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
- config = os.environ[self.DEEPSPEED_ENV_VAR]
- if isinstance(config, (str, Path)):
- if not os.path.isfile(config):
- raise MisconfigurationException(
- f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
- )
- with open(config) as f:
- config = json.load(f)
- assert isinstance(config, dict) or config is None
- return config
-
- def _init_config_if_needed(self) -> None:
- if not self._config_initialized:
- self._format_config()
- self._config_initialized = True
-
- def _format_config(self) -> None:
- if self.config is None:
- raise MisconfigurationException(
- "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
- " See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed"
- )
- self._format_batch_size_and_grad_accum_config()
- _format_precision_config(
- config=self.config,
- precision=self.precision_plugin.precision,
- loss_scale=self.loss_scale,
- loss_scale_window=self.loss_scale_window,
- min_loss_scale=self.min_loss_scale,
- initial_scale_power=self.initial_scale_power,
- hysteresis=self.hysteresis,
- )
-
- def _create_default_config(
- self,
- zero_optimization: bool,
- zero_allow_untested_optimizer: bool,
- logging_batch_size_per_gpu: Union[str, int],
- partition_activations: bool,
- cpu_checkpointing: bool,
- contiguous_memory_optimization: bool,
- synchronize_checkpoint_boundary: bool,
- offload_optimizer: bool,
- offload_parameters: bool,
- nvme_path: str,
- offload_params_device: str,
- params_buffer_count: int,
- params_buffer_size: int,
- max_in_cpu: int,
- offload_optimizer_device: str,
- optimizer_buffer_count: int,
- pin_memory: bool,
- block_size: int,
- queue_depth: int,
- single_submit: bool,
- overlap_events: bool,
- thread_count: int,
- **zero_kwargs: Any,
- ) -> dict:
- cfg = {
- "activation_checkpointing": {
- "partition_activations": partition_activations,
- "cpu_checkpointing": cpu_checkpointing,
- "contiguous_memory_optimization": contiguous_memory_optimization,
- "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
- },
- "aio": {
- "block_size": block_size,
- "queue_depth": queue_depth,
- "single_submit": single_submit,
- "overlap_events": overlap_events,
- "thread_count": thread_count,
- },
- }
- if zero_optimization:
- zero_config = zero_kwargs
-
- if offload_optimizer:
- zero_config["offload_optimizer"] = {
- "device": offload_optimizer_device,
- "nvme_path": nvme_path,
- "buffer_count": optimizer_buffer_count,
- "pin_memory": pin_memory,
- }
- if offload_parameters:
- zero_config["offload_param"] = {
- "device": offload_params_device,
- "nvme_path": nvme_path,
- "buffer_count": params_buffer_count,
- "buffer_size": params_buffer_size,
- "max_in_cpu": max_in_cpu,
- "pin_memory": pin_memory,
- }
- cfg = {
- "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
- "zero_optimization": zero_config,
- **cfg,
- }
- if logging_batch_size_per_gpu != "auto":
- cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
- return cfg
-
- def _format_batch_size_and_grad_accum_config(self) -> None:
- # TODO: Using Fabric, we do not support these variables within the config
- assert isinstance(self.config, dict)
- if self.lightning_module is None:
- return
-
- if "gradient_accumulation_steps" in self.config:
- raise MisconfigurationException(
- "Do not set `gradient_accumulation_steps` in the DeepSpeed config"
- " as this will be set with the `accumulate_grad_batches` argument passed via the Lightning Trainer."
- )
- self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
- if "train_micro_batch_size_per_gpu" not in self.config:
- batch_size = self._auto_select_batch_size()
- self.config["train_micro_batch_size_per_gpu"] = batch_size
- if "gradient_clipping" not in self.config:
- self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0
-
- def _auto_select_batch_size(self) -> int:
- # train_micro_batch_size_per_gpu is used for throughput logging purposes
- # by default we try to use the batch size of the loader
- assert self.lightning_module is not None
- batch_size = 1
- data_source = self.lightning_module.trainer.fit_loop._data_source
- if data_source.is_defined():
- train_dataloader = data_source.dataloader()
- if hasattr(train_dataloader, "batch_sampler"):
- batch_size = train_dataloader.batch_sampler.batch_size
- return batch_size
+ @property
+ def config(self) -> dict[str, Any]:
+ return self.deepspeed_strategy_impl.config
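None of this changes how the strategy is used: a Trainer is configured exactly as before, and the delegation to the enterprise package happens inside `DeepSpeedStrategy.__init__`. A typical invocation, assuming a CUDA machine with `deepspeed` and `pytorch-lightning-enterprise` installed:

```python
import lightning.pytorch as pl
from lightning.pytorch.strategies import DeepSpeedStrategy

trainer = pl.Trainer(
    accelerator="cuda",
    devices=4,
    precision="16-mixed",
    strategy=DeepSpeedStrategy(stage=3, offload_optimizer=True),
)
# trainer.fit(model, datamodule=datamodule)  # model/datamodule defined elsewhere
```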
diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py
index 066fecc79f208..63842d5557d7e 100644
--- a/src/lightning/pytorch/strategies/launchers/xla.py
+++ b/src/lightning/pytorch/strategies/launchers/xla.py
@@ -11,23 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import queue
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch.multiprocessing as mp
from typing_extensions import override
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
-from lightning.fabric.strategies.launchers.xla import _rank_teardown
-from lightning.fabric.utilities import move_data_to_device
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.pytorch.strategies.launchers.multiprocessing import (
_GlobalStateSnapshot,
_MultiProcessingLauncher,
_WorkerOutput,
)
-from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.rank_zero import rank_zero_debug
if TYPE_CHECKING:
import lightning.pytorch as pl
@@ -51,14 +46,16 @@ class _XLALauncher(_MultiProcessingLauncher):
"""
def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None:
- if not _XLA_AVAILABLE:
- raise ModuleNotFoundError(str(_XLA_AVAILABLE))
- super().__init__(strategy=strategy, start_method="fork")
+ super().__init__(strategy)
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.strategies.xla.launcher import _XLALauncherTrainer as EnterpriseXLALauncher
+
+ self.xla_launcher_impl = EnterpriseXLALauncher(strategy)
@property
@override
def is_interactive_compatible(self) -> bool:
- return True
+ return self.xla_launcher_impl.is_interactive_compatible
@override
def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
@@ -75,46 +72,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
**kwargs: Optional keyword arguments to be passed to the given function.
"""
- if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING:
- # resolving https://github.com/Lightning-AI/pytorch-lightning/issues/18775 will lift this restriction
- raise NotImplementedError(
- "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
- " supported. You can work around this by creating a new Trainer instance and passing the"
- " `fit(ckpt_path=...)` argument."
- )
-
- # pjrt requires that the queue is serializable
- return_queue = mp.Manager().Queue()
-
- import torch_xla.distributed.xla_multiprocessing as xmp
-
- spawn_kwargs = {}
- nprocs = self._strategy.num_processes
- if nprocs == 1:
- # avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly.
- # otherwise it will use all devices
- spawn_kwargs["nprocs"] = nprocs
-
- process_context = xmp.spawn(
- self._wrapping_function,
- args=(trainer, function, args, kwargs, return_queue),
- start_method=self._start_method,
- join=False, # we will join ourselves to get the process references
- **spawn_kwargs,
- )
- # xla will not actually create processes if only 1 device
- if process_context is not None:
- self.procs = process_context.processes
- while not process_context.join():
- pass
-
- worker_output = return_queue.get()
- if trainer is None:
- return worker_output
-
- self._already_fit |= trainer.state.fn == TrainerFn.FITTING
- self._recover_results_in_main_process(worker_output, trainer)
- return worker_output.trainer_results
+ return self.xla_launcher_impl.launch(function, *args, trainer=trainer, **kwargs)
@override
def _wrapping_function(
@@ -129,48 +87,10 @@ def _wrapping_function(
return_queue: Union[mp.SimpleQueue, queue.Queue],
global_states: Optional[_GlobalStateSnapshot] = None,
) -> None:
- import torch_xla.core.xla_model as xm
-
- if len(xm.get_xla_supported_devices()) > 1:
- # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for v4)
- # so when there's more than one (multithreading), objects need to be deep-copied
- import copy
-
- trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs))
-
- results = function(*args, **kwargs)
-
- if trainer is not None:
- results = self._collect_rank_zero_results(trainer, results)
-
- if self._strategy.local_rank == 0:
- return_queue.put(move_data_to_device(results, "cpu"))
-
- _rank_teardown(self._strategy.local_rank)
+ return self.xla_launcher_impl._wrapping_function(
+ process_idx, trainer, function, args, kwargs, return_queue, global_states
+ )
@override
def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
- rank_zero_debug("Collecting results from rank 0 process.")
- checkpoint_callback = trainer.checkpoint_callback
- best_model_path = (
- checkpoint_callback.best_model_path
- if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
- else None
- )
-
- # save the last weights
- weights_path = None
- if trainer.state.fn == TrainerFn.FITTING:
- # requires to compute the state_dict on all processes in case Metrics are present
- state_dict = self._strategy.lightning_module_state_dict()
- weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
- self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)
-
- # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
- if self._strategy.local_rank != 0:
- return None
-
- # add extra result data from trainer to send to main process
- extra = self.get_extra_results(trainer)
-
- return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)
+ return self.xla_launcher_impl._collect_rank_zero_results(trainer, results)
diff --git a/src/lightning/pytorch/strategies/single_xla.py b/src/lightning/pytorch/strategies/single_xla.py
index 2a5e2f3a85b96..6e687ac815639 100644
--- a/src/lightning/pytorch/strategies/single_xla.py
+++ b/src/lightning/pytorch/strategies/single_xla.py
@@ -11,23 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
from typing import Optional, Union
-import torch
from typing_extensions import override
import lightning.pytorch as pl
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO
from lightning.fabric.strategies import _StrategyRegistry
-from lightning.fabric.utilities.optimizer import _optimizers_to_device
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
from lightning.pytorch.plugins.precision.xla import XLAPrecision
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
-from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters
class SingleDeviceXLAStrategy(SingleDeviceStrategy):
@@ -41,20 +36,18 @@ def __init__(
precision_plugin: Optional[XLAPrecision] = None,
debug: bool = False,
):
- if not _XLA_AVAILABLE:
- raise ModuleNotFoundError(str(_XLA_AVAILABLE))
- if isinstance(device, torch.device):
- # unwrap the `torch.device` in favor of `xla_device`
- device = device.index
- import torch_xla.core.xla_model as xm
-
super().__init__(
accelerator=accelerator,
- device=xm.xla_device(device),
+ device=device,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
- self.debug = debug
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.strategies.xla.single import (
+ SingleDeviceXLAStrategyTrainer as EnterpriseSingleDeviceXLAStrategy,
+ )
+
+ self.single_xla_strategy_impl = EnterpriseSingleDeviceXLAStrategy(outer_object=self, device=device, debug=debug)
@property
@override
@@ -90,26 +83,7 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
@override
def setup(self, trainer: "pl.Trainer") -> None:
- if self.debug:
- os.environ["PT_XLA_DEBUG"] = str(1)
-
- assert self.accelerator is not None
- self.accelerator.setup(trainer)
-
- assert self.model is not None
- self.precision_plugin.convert_module(self.model)
-
- shared_params = find_shared_parameters(self.model)
- self.model_to_device()
- set_shared_parameters(self.model, shared_params)
-
- self.model = self._setup_model(self.model)
-
- if trainer.state.fn == TrainerFn.FITTING:
- self.setup_optimizers(trainer)
- self.setup_precision_plugin()
- if trainer.state.fn == TrainerFn.FITTING:
- _optimizers_to_device(self.optimizers, self.root_device)
+ return self.single_xla_strategy_impl.setup(trainer=trainer)
@classmethod
@override
@@ -118,5 +92,4 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
@override
def teardown(self) -> None:
- super().teardown()
- os.environ.pop("PT_XLA_DEBUG", None)
+ return self.single_xla_strategy_impl.teardown()
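
`_raise_enterprise_not_available()` replaces the old per-module `_XLA_AVAILABLE` checks in the constructors above. Its implementation is not shown in this diff; a plausible sketch, assuming a RequirementCache-style check keyed on the importable `pytorch_lightning_enterprise` module (the exact check in lightning.fabric.utilities.imports may differ):

from lightning_utilities.core.imports import RequirementCache

# module-based check; the PyPI distribution name is not assumed here
_ENTERPRISE_AVAILABLE = RequirementCache(module="pytorch_lightning_enterprise")


def _raise_enterprise_not_available() -> None:
    # str(RequirementCache) carries a human-readable hint about what is missing
    if not _ENTERPRISE_AVAILABLE:
        raise ModuleNotFoundError(str(_ENTERPRISE_AVAILABLE))
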
diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py
index cbdc890a1ca32..e57c627f374d7 100644
--- a/src/lightning/pytorch/strategies/xla.py
+++ b/src/lightning/pytorch/strategies/xla.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import io
-import os
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
@@ -21,20 +19,16 @@
from typing_extensions import override
import lightning.pytorch as pl
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1
from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO
from lightning.fabric.plugins.environments import XLAEnvironment
from lightning.fabric.strategies import _StrategyRegistry
-from lightning.fabric.utilities.optimizer import _optimizers_to_device
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import _PATH, ReduceOp
from lightning.pytorch.plugins import XLAPrecision
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.strategies.launchers.xla import _XLALauncher
from lightning.pytorch.strategies.strategy import TBroadcast
-from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters
-from lightning.pytorch.utilities.rank_zero import rank_zero_only
if TYPE_CHECKING:
from torch_xla.distributed.parallel_loader import MpDeviceLoader
@@ -56,8 +50,6 @@ def __init__(
sync_module_states: bool = True,
**_: Any,
) -> None:
- if not _XLA_AVAILABLE:
- raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
@@ -66,9 +58,12 @@ def __init__(
precision_plugin=precision_plugin,
start_method="fork",
)
- self.debug = debug
- self._launched = False
- self._sync_module_states = sync_module_states
+ _raise_enterprise_not_available()
+ from pytorch_lightning_enterprise.strategies.xla.ddp import XLAStrategyTrainer as EnterpriseXLAStrategy
+
+ self.xla_strategy_impl = EnterpriseXLAStrategy(
+ outer_object=self, debug=debug, sync_module_states=sync_module_states
+ )
@property
@override
@@ -105,31 +100,27 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
@property
@override
def root_device(self) -> torch.device:
- if not self._launched:
- raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
- import torch_xla.core.xla_model as xm
-
- return xm.xla_device()
+ return self.xla_strategy_impl.root_device
@property
@override
def global_rank(self) -> int:
- return super().global_rank if self._launched else 0
+ return self.xla_strategy_impl.global_rank
@property
@override
def local_rank(self) -> int:
- return super().local_rank if self._launched else 0
+ return self.xla_strategy_impl.local_rank
@property
@override
def node_rank(self) -> int:
- return super().node_rank if self._launched else 0
+ return self.xla_strategy_impl.node_rank
@property
@override
def world_size(self) -> int:
- return super().world_size if self._launched else 1
+ return self.xla_strategy_impl.world_size
@override
def _configure_launcher(self) -> None:
@@ -137,113 +128,36 @@ def _configure_launcher(self) -> None:
@override
def setup(self, trainer: "pl.Trainer") -> None:
- assert self.accelerator is not None
- self.accelerator.setup(trainer)
-
- if self.debug:
- os.environ["PT_XLA_DEBUG"] = "1"
-
- assert self.model is not None
- self.precision_plugin.convert_module(self.model)
-
- shared_params = find_shared_parameters(self.model)
- self.model_to_device()
- set_shared_parameters(self.model, shared_params)
-
- self.model = self._setup_model(self.model)
-
- if self._sync_module_states:
- if _XLA_GREATER_EQUAL_2_1:
- from torch_xla.core.xla_model import broadcast_master_param
- else:
- from torch_xla.experimental.pjrt import broadcast_master_param
-
- broadcast_master_param(self.model)
-
- if trainer.state.fn == TrainerFn.FITTING:
- self.setup_optimizers(trainer)
- self.setup_precision_plugin()
- if trainer.state.fn == TrainerFn.FITTING:
- _optimizers_to_device(self.optimizers, self.root_device)
+ return self.xla_strategy_impl.setup(trainer=trainer)
@override
def _setup_model(self, model: Module) -> Module: # type: ignore
- return model
+ return self.xla_strategy_impl._setup_model(model=model)
@property
@override
def distributed_sampler_kwargs(self) -> dict[str, int]:
- return {"num_replicas": self.world_size, "rank": self.global_rank}
+ return self.xla_strategy_impl.distributed_sampler_kwargs
@override
def process_dataloader(self, dataloader: object) -> "MpDeviceLoader":
- from torch_xla.distributed.parallel_loader import MpDeviceLoader
-
- if isinstance(dataloader, MpDeviceLoader):
- # dataloader is already wrapped by MpDeviceLoader
- return dataloader
-
- dataloader = MpDeviceLoader(dataloader, self.root_device)
- # Mimic interface to torch.utils.data.DataLoader
- dataloader.dataset = dataloader._loader.dataset
- dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
- return dataloader
+ return self.xla_strategy_impl.process_dataloader(dataloader=dataloader)
@override
def configure_ddp(self) -> None:
- pass
+ return self.xla_strategy_impl.configure_ddp()
@override
def model_to_device(self) -> None:
- assert self.model is not None
- self.model = self.model.to(self.root_device)
+ return self.xla_strategy_impl.model_to_device()
@override
def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
- if not self._launched:
- return
-
- import torch_xla.core.xla_model as xm
-
- if name is None:
- # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
- name = ""
- xm.rendezvous(name)
+ return self.xla_strategy_impl.barrier(name, *args, **kwargs)
@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
- if not self._launched:
- return obj
-
- import torch_xla.core.xla_model as xm
-
- is_tensor = isinstance(obj, Tensor)
- if is_tensor:
- if obj.dim() == 0:
- obj = obj.unsqueeze(0)
- original_device = obj.device
- # XLA distributed requires that the data is on the XLA device
- obj = obj.to(self.root_device)
- else:
- # support for arbitrary pickle-ables
- buffer = io.BytesIO()
- torch.save(obj, buffer)
- obj = torch.tensor( # type: ignore[assignment]
- bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
- )
-
- obj = [obj]
- xm.collective_broadcast(obj, root_ordinal=src)
- obj = obj[0]
-
- if not is_tensor:
- # this will preserve the dtype and device of any tensors
- buffer = io.BytesIO(obj.cpu().byte().numpy())
- obj = torch.load(buffer)
- else:
- obj = obj.to(original_device)
-
- return obj
+ return self.xla_strategy_impl.broadcast(obj=obj, src=src)
@override
def reduce(
@@ -252,60 +166,27 @@ def reduce(
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = "mean",
) -> Tensor:
- if not isinstance(output, Tensor):
- output = torch.tensor(output, device=self.root_device)
-
- invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
- invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
- if invalid_reduce_op or invalid_reduce_op_str:
- raise ValueError(
- "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
- f" {reduce_op}"
- )
-
- import torch_xla.core.xla_model as xm
-
- output = xm.mesh_reduce("reduce", output, sum)
-
- if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
- output = output / self.world_size
-
- return output
+ return self.xla_strategy_impl.reduce(output=output, group=group, reduce_op=reduce_op)
@override
def setup_environment(self) -> None:
- self._launched = True
- super().setup_environment()
+ return self.xla_strategy_impl.setup_environment()
@override
def setup_distributed(self) -> None:
- assert self.parallel_devices is not None
- if len(self.parallel_devices) == 1:
- # spawning only 1 device with PjRT is not supported:
- # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732
- raise NotImplementedError(
- "The `XLAStrategy` does not support running on a single device with the PjRT runtime."
- " Try using all devices or the `SingleDeviceXLAStrategy` strategy"
- )
- rank_zero_only.rank = self.global_rank
+ return self.xla_strategy_impl.setup_distributed()
@override
def set_world_ranks(self) -> None:
- # accessing global_rank will initialize the XLA computation client. since this is called outside of the spawned
- # processes (by the accelerator connector), we cannot run the code that would normally be here.
- # instead it's done in `setup_distributed`
- pass
+ return self.xla_strategy_impl.set_world_ranks()
@override
def save_checkpoint(
self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
) -> None:
- import torch_xla.core.xla_model as xm
-
- # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
- xm.mark_step()
- # save on global rank zero only
- super().save_checkpoint(checkpoint, filepath, storage_options=storage_options)
+ return self.xla_strategy_impl.save_checkpoint(
+ checkpoint=checkpoint, filepath=filepath, storage_options=storage_options
+ )
@override
def remove_checkpoint(self, filepath: _PATH) -> None:
@@ -315,8 +196,7 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
filepath: Path to checkpoint
"""
- if self.local_rank == 0:
- self.checkpoint_io.remove_checkpoint(filepath)
+ return self.xla_strategy_impl.remove_checkpoint(filepath=filepath)
@override
def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -330,29 +210,11 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
A tensor of shape (world_size, ...)
"""
- if not self._launched:
- return tensor
- if not isinstance(tensor, Tensor):
- raise NotImplementedError(
- f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
- )
- if tensor.dim() == 0:
- tensor = tensor.unsqueeze(0)
- original_device = tensor.device
- tensor = tensor.to(self.root_device)
-
- import torch_xla.core.functions as xf
- import torch_xla.core.xla_model as xm
-
- tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
- tensor = tensor.to(original_device)
- return tensor
+ return self.xla_strategy_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads)
@override
def teardown(self) -> None:
- super().teardown()
- self._launched = False # after the Trainer finishes, we aren't inside the spawned region
- os.environ.pop("PT_XLA_DEBUG", None)
+ return self.xla_strategy_impl.teardown()
@classmethod
@override
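
Note on the `barrier` override above: `name` is forwarded positionally because forwarding it as a keyword together with `*args` (i.e. `impl.barrier(name=name, *args)`) raises `TypeError: got multiple values for argument 'name'` whenever a caller supplies extra positional arguments. A tiny self-contained illustration:

def barrier(name=None, *args, **kwargs):
    # minimal stand-in with the same signature shape as Strategy.barrier
    return name


barrier("sync")        # ok: name passed positionally
barrier(name="sync")   # ok: no extra positional arguments
try:
    # positional args bind to `name` first, then the keyword collides with it
    barrier(name="sync", *("extra",))
except TypeError as err:
    print(err)
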
diff --git a/src/lightning/pytorch/utilities/deepspeed.py b/src/lightning/pytorch/utilities/deepspeed.py
index 20b418437c681..21a9190552f43 100644
--- a/src/lightning/pytorch/utilities/deepspeed.py
+++ b/src/lightning/pytorch/utilities/deepspeed.py
@@ -20,8 +20,8 @@
import torch
+from lightning.fabric.utilities.imports import _DEEPSPEED_AVAILABLE
from lightning.fabric.utilities.types import _PATH
-from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE
CPU_DEVICE = torch.device("cpu")
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index 9d4a0b9462f2e..1f4dc37c1a6b7 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -19,9 +19,13 @@
from unittest.mock import Mock
import pytest
+import pytorch_lightning_enterprise.utils.imports
import torch.distributed
import lightning.fabric
+import lightning.fabric.plugins.environments.xla
+import lightning.fabric.plugins.io.xla
+import lightning.fabric.plugins.precision.xla
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver
from lightning.fabric.utilities.distributed import _destroy_dist_connection
@@ -144,17 +148,27 @@ def reset_cudnn_benchmark():
def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
- monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.fabric.plugins.precision.xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.fabric.strategies.single_xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.fabric.strategies.xla_fsdp, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)
+ # First, mock torch_xla modules in sys.modules so imports succeed
monkeypatch.setitem(sys.modules, "torch_xla", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla._internal", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla._internal.tpu", Mock())
+
+ # Then patch the _XLA_AVAILABLE flags in various modules
+ monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_1", value)
+ monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_5", value)
+ monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", value)
+ monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_5", value)
+ # Also patch the flags in the modules that imported them, since "from ... import" binds a copy
+ monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", value)
+ monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_1", value)
+ monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_5", value)
+ monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_AVAILABLE", value)
+ monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", value)
@pytest.fixture
@@ -166,6 +180,14 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
mock_xla_available(monkeypatch, value)
monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
+ # Also mock the enterprise XLAAccelerator methods
+ import pytorch_lightning_enterprise.accelerators.xla
+
+ monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
+ monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
+ monkeypatch.setitem(sys.modules, "torch_xla", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
@pytest.fixture
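
The fixture changes above combine two steps: plant fake `torch_xla` modules in `sys.modules` so that imports inside the enterprise package succeed, then flip the availability flags both where they are defined and in every module that re-imported them. A small generic sketch of the same technique (the module and flag names are placeholders, not real packages):

import sys
import types
from unittest.mock import Mock

import pytest


def _mock_optional_dep(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
    # 1) make `import fake_dep` (and submodules) succeed without the real package
    monkeypatch.setitem(sys.modules, "fake_dep", types.ModuleType("fake_dep"))
    monkeypatch.setitem(sys.modules, "fake_dep.core", Mock())
    # 2) flip the availability flag; a dotted string target lets monkeypatch
    #    resolve the module, and raising=False tolerates a missing attribute
    monkeypatch.setattr("fake_dep._DEP_AVAILABLE", value, raising=False)
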
diff --git a/tests/tests_fabric/graveyard/test_tpu.py b/tests/tests_fabric/graveyard/test_tpu.py
index 5b72d60491df7..7aae8ab5d2297 100644
--- a/tests/tests_fabric/graveyard/test_tpu.py
+++ b/tests/tests_fabric/graveyard/test_tpu.py
@@ -15,7 +15,7 @@ def test_graveyard_single_tpu(import_path, name):
module = import_module(import_path)
cls = getattr(module, name)
device = torch.device("cpu")
- with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"):
+ with pytest.deprecated_call(match="is deprecated"):
cls(device)
@@ -37,5 +37,5 @@ def test_graveyard_single_tpu(import_path, name):
def test_graveyard_no_device(import_path, name):
module = import_module(import_path)
cls = getattr(module, name)
- with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"):
+ with pytest.deprecated_call(match="is deprecated"):
cls()
diff --git a/tests/tests_fabric/plugins/environments/test_kubeflow.py b/tests/tests_fabric/plugins/environments/test_kubeflow.py
index 3436adc9ce2aa..72fb527d3facb 100644
--- a/tests/tests_fabric/plugins/environments/test_kubeflow.py
+++ b/tests/tests_fabric/plugins/environments/test_kubeflow.py
@@ -61,14 +61,14 @@ def test_attributes_from_environment_variables(caplog):
assert env.local_rank() == 0
assert env.node_rank() == 1
# setter should be no-op
- with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
+ with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.kubeflow"):
env.set_global_rank(100)
assert env.global_rank() == 1
assert "setting global rank is not allowed" in caplog.text
caplog.clear()
- with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
+ with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.kubeflow"):
env.set_world_size(100)
assert env.world_size() == 20
assert "setting world size is not allowed" in caplog.text
diff --git a/tests/tests_fabric/plugins/environments/test_slurm.py b/tests/tests_fabric/plugins/environments/test_slurm.py
index b907c287faa5f..742fc44855083 100644
--- a/tests/tests_fabric/plugins/environments/test_slurm.py
+++ b/tests/tests_fabric/plugins/environments/test_slurm.py
@@ -21,7 +21,6 @@
from lightning_utilities.test.warning import no_warning_call
from lightning.fabric.plugins.environments import SLURMEnvironment
-from lightning.fabric.utilities.warnings import PossibleUserWarning
from tests_fabric.helpers.runif import RunIf
@@ -72,14 +71,14 @@ def test_attributes_from_environment_variables(caplog):
assert env.node_rank() == 3
assert env.job_name() == "JOB"
# setter should be no-op
- with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
+ with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.slurm"):
env.set_global_rank(100)
assert env.global_rank() == 1
assert "setting global rank is not allowed" in caplog.text
caplog.clear()
- with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
+ with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.slurm"):
env.set_world_size(100)
assert env.world_size() == 20
assert "setting world size is not allowed" in caplog.text
@@ -135,18 +134,18 @@ def test_detect():
def test_srun_available_and_not_used(monkeypatch):
"""Test that a warning is emitted if Lightning suspects the user forgot to run their script with `srun`."""
monkeypatch.setattr(sys, "argv", ["train.py", "--lr", "0.01"])
- expected = "`srun` .* available .* but is not used. HINT: .* srun python train.py --lr 0.01"
+ expected = r"`srun` .* available .* but is not used. HINT: .* srun python\d* train.py --lr 0.01"
# pretend `srun` is available
with mock.patch("lightning.fabric.plugins.environments.slurm.shutil.which", return_value="/usr/bin/srun"):
- with pytest.warns(PossibleUserWarning, match=expected):
+ with pytest.warns(UserWarning, match=expected):
SLURMEnvironment()
- with pytest.warns(PossibleUserWarning, match=expected):
+ with pytest.warns(UserWarning, match=expected):
SLURMEnvironment.detect()
# no warning if `srun` is unavailable
- with no_warning_call(PossibleUserWarning, match=expected):
+ with no_warning_call(UserWarning, match=expected):
SLURMEnvironment()
assert not SLURMEnvironment.detect()
@@ -176,7 +175,7 @@ def test_validate_user_settings():
# in interactive mode, validation is skipped because processes get launched by Fabric/Trainer, not SLURM
with mock.patch(
- "lightning.fabric.plugins.environments.slurm.SLURMEnvironment.job_name", return_value="interactive"
+ "pytorch_lightning_enterprise.plugins.environments.slurm.SLURMEnvironment.job_name", return_value="interactive"
):
env = SLURMEnvironment()
env.validate_settings(num_devices=4, num_nodes=1) # no error
diff --git a/tests/tests_fabric/plugins/environments/test_torchelastic.py b/tests/tests_fabric/plugins/environments/test_torchelastic.py
index 161d42894df30..cfd196c231036 100644
--- a/tests/tests_fabric/plugins/environments/test_torchelastic.py
+++ b/tests/tests_fabric/plugins/environments/test_torchelastic.py
@@ -58,14 +58,14 @@ def test_attributes_from_environment_variables(caplog):
assert env.local_rank() == 2
assert env.node_rank() == 3
# setter should be no-op
- with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
+ with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.torchelastic"):
env.set_global_rank(100)
assert env.global_rank() == 1
assert "setting global rank is not allowed" in caplog.text
caplog.clear()
- with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
+ with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.torchelastic"):
env.set_world_size(100)
assert env.world_size() == 20
assert "setting world size is not allowed" in caplog.text
diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py
index 76fd18012ee94..0557f5620e8e1 100644
--- a/tests/tests_fabric/plugins/environments/test_xla.py
+++ b/tests/tests_fabric/plugins/environments/test_xla.py
@@ -100,8 +100,9 @@ def test_detect(monkeypatch):
@mock.patch.dict(os.environ, {}, clear=True)
+@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True)
+@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True)
@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
-@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
def test_world_size_from_xla_runtime_greater_2_1(xla_available):
"""Test that world_size uses torch_xla.runtime when XLA >= 2.1."""
env = XLAEnvironment()
@@ -113,8 +114,9 @@ def test_world_size_from_xla_runtime_greater_2_1(xla_available):
@mock.patch.dict(os.environ, {}, clear=True)
+@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True)
+@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True)
@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
-@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
def test_global_rank_from_xla_runtime_greater_2_1(xla_available):
"""Test that global_rank uses torch_xla.runtime when XLA >= 2.1."""
env = XLAEnvironment()
@@ -126,8 +128,9 @@ def test_global_rank_from_xla_runtime_greater_2_1(xla_available):
@mock.patch.dict(os.environ, {}, clear=True)
+@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True)
+@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True)
@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
-@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
def test_local_rank_from_xla_runtime_greater_2_1(xla_available):
"""Test that local_rank uses torch_xla.runtime when XLA >= 2.1."""
env = XLAEnvironment()
@@ -139,8 +142,9 @@ def test_local_rank_from_xla_runtime_greater_2_1(xla_available):
@mock.patch.dict(os.environ, {}, clear=True)
+@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True)
+@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True)
@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
-@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
def test_setters_readonly_when_xla_runtime_greater_2_1(xla_available):
"""Test that set_world_size and set_global_rank don't affect values when using XLA runtime >= 2.1."""
env = XLAEnvironment()
diff --git a/tests/tests_fabric/plugins/precision/test_transformer_engine.py b/tests/tests_fabric/plugins/precision/test_transformer_engine.py
index ed7c984b1ae64..22908e107131b 100644
--- a/tests/tests_fabric/plugins/precision/test_transformer_engine.py
+++ b/tests/tests_fabric/plugins/precision/test_transformer_engine.py
@@ -15,19 +15,22 @@
from unittest.mock import Mock
import pytest
+import pytorch_lightning_enterprise.utils.imports
import torch
import torch.distributed
-import lightning.fabric
from lightning.fabric.connector import _Connector
from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision
def test_transformer_engine_plugin(monkeypatch):
- module = lightning.fabric.plugins.precision.transformer_engine
+ module = pytorch_lightning_enterprise.utils.imports
if module._TRANSFORMER_ENGINE_AVAILABLE:
pytest.skip("Assumes transformer_engine is unavailable")
- monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
+ monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", True)
+ monkeypatch.setattr(
+ "pytorch_lightning_enterprise.plugins.precision.transformer_engine._TRANSFORMER_ENGINE_AVAILABLE", True
+ )
transformer_engine_mock = Mock()
monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock)
monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock())
@@ -118,8 +121,11 @@ class TELayerNormMock(Mock): ...
def test_convert_module_handles_linear_without_bias(monkeypatch):
- module = lightning.fabric.plugins.precision.transformer_engine # Set up mock transformer_engine
- monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
+ module = pytorch_lightning_enterprise.utils.imports # Set up mock transformer_engine
+ monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", True)
+ monkeypatch.setattr(
+ "pytorch_lightning_enterprise.plugins.precision.transformer_engine._TRANSFORMER_ENGINE_AVAILABLE", True
+ )
transformer_engine_mock = Mock()
monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock)
diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py
index 0194c7b87820a..022e1320068ff 100644
--- a/tests/tests_fabric/strategies/test_deepspeed.py
+++ b/tests/tests_fabric/strategies/test_deepspeed.py
@@ -75,7 +75,7 @@ def test_deepspeed_defaults():
strategy = DeepSpeedStrategy()
assert strategy.config is not None
assert isinstance(strategy.config["zero_optimization"], dict)
- assert strategy._backward_sync_control is None
+ assert strategy.deepspeed_impl._backward_sync_control is None
@RunIf(deepspeed=True)
@@ -230,7 +230,7 @@ def test_deepspeed_save_checkpoint_exclude_frozen_parameters(exclude_frozen_para
from deepspeed import DeepSpeedEngine
strategy = DeepSpeedStrategy(exclude_frozen_parameters=exclude_frozen_parameters)
- assert strategy.exclude_frozen_parameters is exclude_frozen_parameters
+ assert strategy.deepspeed_impl.exclude_frozen_parameters is exclude_frozen_parameters
model = Mock(spec=DeepSpeedEngine, optimizer=None)
model.modules.return_value = [model]
@@ -275,7 +275,7 @@ def test_deepspeed_load_checkpoint_no_state(tmp_path):
@RunIf(deepspeed=True)
-@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
+@mock.patch("pytorch_lightning_enterprise.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(_, tmp_path):
"""Test that the DeepSpeed strategy can only load one DeepSpeedEngine per checkpoint."""
from deepspeed import DeepSpeedEngine
@@ -317,7 +317,7 @@ def test_deepspeed_load_checkpoint_client_state_missing(tmp_path):
@RunIf(deepspeed=True)
-@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
+@mock.patch("pytorch_lightning_enterprise.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
def test_deepspeed_load_checkpoint_state_updated_with_client_state(_, tmp_path):
"""Test that the DeepSpeed strategy properly updates the state variables and returns additional metadata."""
from deepspeed import DeepSpeedEngine
@@ -342,7 +342,7 @@ def test_deepspeed_load_checkpoint_state_updated_with_client_state(_, tmp_path):
@RunIf(deepspeed=True)
@pytest.mark.parametrize("optimzer_state_requested", [True, False])
-@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
+@mock.patch("pytorch_lightning_enterprise.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
def test_deepspeed_load_checkpoint_optimzer_state_requested(_, optimzer_state_requested, tmp_path):
"""Test that the DeepSpeed strategy loads the optimizer state only when requested."""
from deepspeed import DeepSpeedEngine
@@ -427,6 +427,7 @@ def test_validate_parallel_devices_indices(device_indices):
"""
accelerator = Mock(spec=CUDAAccelerator)
+ accelerator.name.return_value = "cuda"
strategy = DeepSpeedStrategy(
accelerator=accelerator, parallel_devices=[torch.device("cuda", i) for i in device_indices]
)
diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py
index 2a04c6b27dded..72adf161fb5ce 100644
--- a/tests/tests_fabric/strategies/test_deepspeed_integration.py
+++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py
@@ -251,7 +251,7 @@ def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform):
strategy = fabric._strategy
assert isinstance(strategy, DeepSpeedStrategy)
with mock.patch("platform.system", return_value=platform) as platform_mock:
- strategy._init_deepspeed_distributed()
+ strategy.deepspeed_impl._init_deepspeed_distributed()
deepspeed_dist_mock.assert_called()
platform_mock.assert_called()
if platform == "Windows":
diff --git a/tests/tests_fabric/strategies/test_xla.py b/tests/tests_fabric/strategies/test_xla.py
index 9ca21d8d8b894..1884208dc2126 100644
--- a/tests/tests_fabric/strategies/test_xla.py
+++ b/tests/tests_fabric/strategies/test_xla.py
@@ -194,14 +194,14 @@ def test_rank_properties_access(xla_available):
strategy.cluster_environment = Mock()
# we're in the main process, no processes have been launched yet
- assert not strategy._launched
+ assert not strategy.xla_strategy_impl._launched
assert strategy.global_rank == 0
assert strategy.local_rank == 0
assert strategy.node_rank == 0
assert strategy.world_size == 1
# simulate we're in a worker process
- strategy._launched = True
+ strategy.xla_strategy_impl._launched = True
assert strategy.global_rank == strategy.cluster_environment.global_rank()
assert strategy.local_rank == strategy.cluster_environment.local_rank()
assert strategy.node_rank == strategy.cluster_environment.node_rank()
diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py
index c2634283ad110..fa1021402590e 100644
--- a/tests/tests_fabric/strategies/test_xla_fsdp.py
+++ b/tests/tests_fabric/strategies/test_xla_fsdp.py
@@ -23,7 +23,6 @@
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.plugins import XLAPrecision
from lightning.fabric.strategies import XLAFSDPStrategy
-from lightning.fabric.strategies.xla_fsdp import _activation_checkpointing_auto_wrapper, _XLAFSDPBackwardSyncControl
from tests_fabric.helpers.runif import RunIf
@@ -43,6 +42,7 @@ def test_xla_fsdp_setup_optimizer_validation():
def test_xla_fsdp_no_backward_sync():
"""Test that the backward sync control calls `.no_sync()`, and only on a module wrapped in
XlaFullyShardedDataParallel."""
+ from pytorch_lightning_enterprise.strategies.xla.fsdp import _XLAFSDPBackwardSyncControl
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel
strategy = XLAFSDPStrategy()
@@ -75,41 +75,45 @@ def test_xla_fsdp_grad_clipping_value_error():
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+@mock.patch("pytorch_lightning_enterprise.strategies.xla.fsdp._XLA_AVAILABLE", True)
def test_rank_properties_access(xla_available):
"""Test that the strategy returns the expected values depending on whether we're in the main process or not."""
strategy = XLAFSDPStrategy()
strategy.cluster_environment = Mock()
# we're in the main process, no processes have been launched yet
- assert not strategy._launched
+ assert not strategy.xla_fsdp_impl._launched
assert strategy.global_rank == 0
assert strategy.local_rank == 0
assert strategy.node_rank == 0
assert strategy.world_size == 1
# simulate we're in a worker process
- strategy._launched = True
+ strategy.xla_fsdp_impl._launched = True
assert strategy.global_rank == strategy.cluster_environment.global_rank()
assert strategy.local_rank == strategy.cluster_environment.local_rank()
assert strategy.node_rank == strategy.cluster_environment.node_rank()
assert strategy.world_size == strategy.cluster_environment.world_size()
+@mock.patch("pytorch_lightning_enterprise.strategies.xla.fsdp._XLA_AVAILABLE", True)
def test_xla_fsdp_policy(xla_available):
+ from pytorch_lightning_enterprise.strategies.xla import fsdp as fsdp_module
+
strategy = XLAFSDPStrategy(foo=1)
- assert strategy._fsdp_kwargs == {"foo": 1}
+ assert strategy.xla_fsdp_impl._fsdp_kwargs == {"foo": 1}
strategy = XLAFSDPStrategy(auto_wrap_policy={torch.nn.Linear})
- kwargs = strategy._parse_fsdp_kwargs()
+ kwargs = strategy.xla_fsdp_impl._parse_fsdp_kwargs()
assert set(kwargs) == {"auto_wrap_policy", "compute_dtype"}
assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"
assert kwargs["compute_dtype"] is torch.float32
strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
- _ = strategy._parse_fsdp_kwargs()
- kwargs = strategy._parse_fsdp_kwargs() # ensure it's idempotent
+ _ = strategy.xla_fsdp_impl._parse_fsdp_kwargs()
+ kwargs = strategy.xla_fsdp_impl._parse_fsdp_kwargs() # ensure it's idempotent
assert set(kwargs) == {"auto_wrapper_callable", "compute_dtype"}
- assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper
+ assert kwargs["auto_wrapper_callable"].func is fsdp_module._activation_checkpointing_auto_wrapper
assert kwargs["compute_dtype"] is torch.float32
strategy = XLAFSDPStrategy(
@@ -118,17 +122,17 @@ def test_xla_fsdp_policy(xla_available):
activation_checkpointing_policy={torch.nn.Linear},
precision=XLAPrecision("bf16-true"),
)
- kwargs = strategy._parse_fsdp_kwargs()
+ kwargs = strategy.xla_fsdp_impl._parse_fsdp_kwargs()
assert set(kwargs) == {"auto_wrap_policy", "auto_wrapper_callable", "compute_dtype"}
assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"
- assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper
+ assert kwargs["auto_wrapper_callable"].func is fsdp_module._activation_checkpointing_auto_wrapper
assert kwargs["compute_dtype"] is torch.bfloat16
strategy.teardown()
strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}, auto_wrapper_callable="foo")
with pytest.raises(ValueError, match="cannot set both"):
- strategy._parse_fsdp_kwargs()
+ strategy.xla_fsdp_impl._parse_fsdp_kwargs()
strategy = XLAFSDPStrategy(activation_checkpointing_policy="foo")
with pytest.raises(TypeError, match="must be a set"):
- strategy._parse_fsdp_kwargs()
+ strategy.xla_fsdp_impl._parse_fsdp_kwargs()
diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py
index 1074789e71055..faf22bf3a06b5 100644
--- a/tests/tests_fabric/test_connector.py
+++ b/tests/tests_fabric/test_connector.py
@@ -20,6 +20,7 @@
from unittest.mock import Mock
import pytest
+import pytorch_lightning_enterprise.utils.imports
import torch
import torch.distributed
from lightning_utilities.test.warning import no_warning_call
@@ -242,8 +243,10 @@ class TestStrategy(DDPStrategy):
),
],
)
-@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._read_hosts", return_value=["node0", "node1"])
-@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0)
+@mock.patch(
+ "pytorch_lightning_enterprise.plugins.environments.lsf.LSFEnvironment._read_hosts", return_value=["node0", "node1"]
+)
+@mock.patch("pytorch_lightning_enterprise.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0)
def test_fallback_from_ddp_spawn_to_ddp_on_cluster(_, __, env_vars, expected_environment):
with mock.patch.dict(os.environ, env_vars, clear=True):
connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2)
@@ -341,8 +344,8 @@ def test_cuda_accelerator_can_not_run_on_system(_):
@pytest.mark.skipif(XLAAccelerator.is_available(), reason="test requires missing TPU")
-@mock.patch("lightning.fabric.accelerators.xla._XLA_AVAILABLE", True)
-@mock.patch("lightning.fabric.accelerators.xla._using_pjrt", return_value=True)
+@mock.patch("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", True)
+@mock.patch("pytorch_lightning_enterprise.accelerators.xla._using_pjrt", return_value=True)
def test_tpu_accelerator_can_not_run_on_system(_):
with pytest.raises(RuntimeError, match="XLAAccelerator` can not run on your system"):
_Connector(accelerator="tpu", devices=8)
@@ -888,7 +891,7 @@ def test_precision_selection_model_parallel(_, precision, raises):
def test_bitsandbytes_precision_cuda_required(monkeypatch):
- monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
+ monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_BITSANDBYTES_AVAILABLE", True)
monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"):
_Connector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))
diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py
index a175fa97fd444..785ec0a3ed2e9 100644
--- a/tests/tests_fabric/utilities/test_throughput.py
+++ b/tests/tests_fabric/utilities/test_throughput.py
@@ -45,7 +45,13 @@ def test_get_available_flops(xla_available):
):
assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None
- from torch_xla.experimental import tpu
+ # Import from the right module based on _XLA_GREATER_EQUAL_2_1
+ from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
+
+ if _XLA_GREATER_EQUAL_2_1:
+ from torch_xla._internal import tpu
+ else:
+ from torch_xla.experimental import tpu
assert isinstance(tpu, Mock)
diff --git a/tests/tests_pytorch/accelerators/test_xla.py b/tests/tests_pytorch/accelerators/test_xla.py
index 5e56d5c585c88..52434ee3a19d5 100644
--- a/tests/tests_pytorch/accelerators/test_xla.py
+++ b/tests/tests_pytorch/accelerators/test_xla.py
@@ -22,7 +22,6 @@
from torch import nn
from torch.utils.data import DataLoader
-import lightning.fabric
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import CPUAccelerator, XLAAccelerator
@@ -312,7 +311,7 @@ def test_warning_if_tpus_not_used(tpu_available):
)
@RunIf(min_python="3.9") # mocking issue
def test_trainer_config_device_ids(devices, expected_device_ids, tpu_available, monkeypatch):
- monkeypatch.setattr(lightning.fabric.accelerators.xla, "_using_pjrt", lambda: True)
+ monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._using_pjrt", lambda: True)
mock = DeviceMock()
monkeypatch.setattr(torch, "device", mock)
diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py
index 878298c6bfd94..ddee9998f9f4a 100644
--- a/tests/tests_pytorch/conftest.py
+++ b/tests/tests_pytorch/conftest.py
@@ -22,6 +22,7 @@
from unittest.mock import Mock
import pytest
+import pytorch_lightning_enterprise.utils.imports
import torch.distributed
from tqdm import TMonitor
@@ -220,14 +221,41 @@ def mps_count_1(monkeypatch):
def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
- monkeypatch.setattr(lightning.pytorch.strategies.xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.pytorch.strategies.single_xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.pytorch.plugins.precision.xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.pytorch.strategies.launchers.xla, "_XLA_AVAILABLE", value)
+ # First, mock torch_xla modules in sys.modules so imports succeed
+ monkeypatch.setitem(sys.modules, "torch_xla", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.core", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.core.functions", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.core.xla_env_vars", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.experimental.pjrt", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.experimental.tpu", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.distributed", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.distributed.parallel_loader", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.distributed.xla_multiprocessing", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.runtime", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.utils", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.utils.utils", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.debug", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla.debug.profiler", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla._internal", Mock())
+ monkeypatch.setitem(sys.modules, "torch_xla._internal.tpu", Mock())
+
+ # Then patch the _XLA_AVAILABLE flags in various modules
+ monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_1", value)
+ monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_5", value)
monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", value)
+ monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_5", value)
+ # Also patch the flags in the modules that imported them, since "from ... import" binds a copy
+ monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", value)
+ monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_1", value)
+ monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_5", value)
+ monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_AVAILABLE", value)
+ monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", value)
@pytest.fixture
@@ -237,10 +265,14 @@ def xla_available(monkeypatch: pytest.MonkeyPatch) -> None:
def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
mock_xla_available(monkeypatch, value)
- monkeypatch.setattr(lightning.pytorch.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
- monkeypatch.setattr(lightning.pytorch.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
+ # Also mock the enterprise XLAAccelerator methods
+ import pytorch_lightning_enterprise.accelerators.xla
+
+ monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
+ monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
+
monkeypatch.setitem(sys.modules, "torch_xla", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
index 0b79b638534fa..a24d53af58fa7 100644
--- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
+++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
@@ -2,6 +2,7 @@
from unittest.mock import Mock
import pytest
+import pytorch_lightning_enterprise.utils.imports
import torch.nn
import lightning.fabric
@@ -62,6 +63,7 @@ def test_fsdp_precision_plugin():
def test_bitsandbytes_precision_plugin(monkeypatch):
monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
+ monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_BITSANDBYTES_AVAILABLE", True)
bitsandbytes_mock = Mock()
monkeypatch.setitem(sys.modules, "bitsandbytes", bitsandbytes_mock)
@@ -107,7 +109,8 @@ def test_precision_plugin():
def test_transformer_engine_precision_plugin(monkeypatch):
- monkeypatch.setattr(lightning.fabric.plugins.precision.transformer_engine, "_TRANSFORMER_ENGINE_AVAILABLE", True)
+ monkeypatch.setattr(lightning.fabric.utilities.imports, "_TRANSFORMER_ENGINE_AVAILABLE", True)
+ monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_TRANSFORMER_ENGINE_AVAILABLE", True)
transformer_engine_mock = Mock()
monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock)
monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock())
diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py
index 98819eb080eb8..f5a0c41fc97a4 100644
--- a/tests/tests_pytorch/loggers/conftest.py
+++ b/tests/tests_pytorch/loggers/conftest.py
@@ -38,8 +38,8 @@ def mlflow_mock(monkeypatch):
mlflow.tracking = mlflow_tracking
mlflow.entities = mlflow_entities
- monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", True)
- monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_SYNCHRONOUS_AVAILABLE", True)
+ monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._MLFLOW_AVAILABLE", True)
+ monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._MLFLOW_SYNCHRONOUS_AVAILABLE", True)
return mlflow
@@ -86,7 +86,7 @@ class RunType: # to make isinstance checks pass
wandb.sdk.lib = wandb_sdk_lib
wandb.wandb_run = wandb_wandb_run
- monkeypatch.setattr("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True)
+ monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._WANDB_AVAILABLE", True)
return wandb
@@ -109,7 +109,7 @@ def comet_mock(monkeypatch):
comet.start = Mock(name="comet_ml.start", return_value=comet.Experiment())
comet.config = Mock()
- monkeypatch.setattr("lightning.pytorch.loggers.comet._COMET_AVAILABLE", True)
+ monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._COMET_AVAILABLE", True)
return comet
@@ -153,5 +153,5 @@ def __setitem__(self, key, value):
neptune.types = neptune_types
neptune.utils = neptune_utils
- monkeypatch.setattr("lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE", True)
+ monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._NEPTUNE_AVAILABLE", True)
return neptune
diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py
index a9cf9af79185b..22b7632d5bea1 100644
--- a/tests/tests_pytorch/loggers/test_all.py
+++ b/tests/tests_pytorch/loggers/test_all.py
@@ -70,7 +70,7 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs):
@mock.patch.dict(os.environ, {})
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES)
def test_loggers_fit_test_all(logger_class, mlflow_mock, wandb_mock, comet_mock, neptune_mock, tmp_path, monkeypatch):
"""Verify that basic functionality of all loggers."""
@@ -335,7 +335,7 @@ def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, neptune_moc
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_logger_default_name(mlflow_mock, monkeypatch, tmp_path):
"""Test that the default logger name is lightning_logs."""
# CSV
diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py
index dae8f617b873e..4108d3c3b641c 100644
--- a/tests/tests_pytorch/loggers/test_comet.py
+++ b/tests/tests_pytorch/loggers/test_comet.py
@@ -89,10 +89,10 @@ def test_comet_experiment_is_still_alive_after_training_complete(comet_mock):
logger.finalize("ended")
# Assert that data was saved to comet.com
- logger._experiment.flush.assert_called_once()
+ logger.logger_impl._experiment.flush.assert_called_once()
# Assert that was not ended
- logger._experiment.end.assert_not_called()
+ logger.logger_impl._experiment.end.assert_not_called()
@mock.patch.dict(os.environ, {})
@@ -115,8 +115,8 @@ def test_comet_logger_experiment_name(comet_mock):
experiment_config=comet_mock.ExperimentConfig(),
)
# check that we saved "experiment name" in kwargs as new "name" arg
- assert logger._kwargs["name"] == experiment_name
- assert "experiment_name" not in logger._kwargs
+ assert logger.logger_impl._kwargs["name"] == experiment_name
+ assert "experiment_name" not in logger.logger_impl._kwargs
# check that "experiment name" was passed to experiment config correctly
assert call(experiment_name=experiment_name) not in comet_mock.ExperimentConfig.call_args_list
@@ -130,10 +130,10 @@ def test_comet_version(comet_mock):
experiment_name = "My Name"
logger = CometLogger(api_key=api_key, name=experiment_name)
- assert logger._experiment is not None
+ assert logger.logger_impl._experiment is not None
_ = logger.version
- logger._experiment.get_key.assert_called()
+ logger.logger_impl._experiment.get_key.assert_called()
@mock.patch.dict(os.environ, {})
@@ -146,7 +146,7 @@ def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch):
{"test": 1},
epoch=1,
step=123,
- prefix=logger._prefix,
+ prefix=logger.logger_impl._prefix,
framework="pytorch-lightning",
)
diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py
index c7f9dbe1fe2c6..c45f5ee5b8295 100644
--- a/tests/tests_pytorch/loggers/test_mlflow.py
+++ b/tests/tests_pytorch/loggers/test_mlflow.py
@@ -16,13 +16,12 @@
from unittest.mock import MagicMock, Mock
import pytest
+from pytorch_lightning_enterprise.utils.imports import _MLFLOW_AVAILABLE
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers.mlflow import (
- _MLFLOW_AVAILABLE,
MLFlowLogger,
- _get_resolve_tags,
)
@@ -30,13 +29,13 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r
"""Helper function to simulate mlflow client creating a new (or existing) experiment."""
run = MagicMock()
run.info.run_id = run_id
- logger._mlflow_client.get_experiment_by_name = MagicMock(return_value=experiment_name)
- logger._mlflow_client.create_experiment = MagicMock(return_value=experiment_id)
- logger._mlflow_client.create_run = MagicMock(return_value=run)
+ logger.logger_impl._mlflow_client.get_experiment_by_name = MagicMock(return_value=experiment_name)
+ logger.logger_impl._mlflow_client.create_experiment = MagicMock(return_value=experiment_id)
+ logger.logger_impl._mlflow_client.create_run = MagicMock(return_value=run)
return logger
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_exists(mlflow_mock, tmp_path):
"""Test launching three independent loggers with either same or different experiment name."""
client = mlflow_mock.tracking.MlflowClient
@@ -57,8 +56,8 @@ def test_mlflow_logger_exists(mlflow_mock, tmp_path):
client.return_value.create_run = MagicMock(return_value=run1)
logger = MLFlowLogger("test", save_dir=str(tmp_path))
- assert logger._experiment_id is None
- assert logger._run_id is None
+ assert logger.logger_impl._experiment_id is None
+ assert logger.logger_impl._run_id is None
_ = logger.experiment
assert logger.experiment_id == "exp-id-1"
assert logger.run_id == "run-id-1"
@@ -96,13 +95,14 @@ def test_mlflow_run_name_setting(tmp_path):
pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
+ from pytorch_lightning_enterprise.loggers.mlflow import _get_resolve_tags
resolve_tags = _get_resolve_tags()
tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"})
# run_name is appended to tags
logger = MLFlowLogger("test", run_name="run-name-1", save_dir=str(tmp_path))
- logger._mlflow_client = client = Mock()
+ logger.logger_impl._mlflow_client = client = Mock()
logger = mock_mlflow_run_creation(logger, experiment_id="exp-id")
_ = logger.experiment
@@ -122,7 +122,7 @@ def test_mlflow_run_name_setting(tmp_path):
client.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_run_id_setting(mlflow_mock, tmp_path):
"""Test that the run_id argument uses the provided run_id."""
client = mlflow_mock.tracking.MlflowClient
@@ -143,7 +143,7 @@ def test_mlflow_run_id_setting(mlflow_mock, tmp_path):
client.reset_mock(return_value=True)
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_log_dir(mlflow_mock, tmp_path):
"""Test that the trainer saves checkpoints in the logger's save dir."""
client = mlflow_mock.tracking.MlflowClient
@@ -211,8 +211,8 @@ def on_train_epoch_end(self, *args, **kwargs):
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"]
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.utils.imports._MLFLOW_AVAILABLE", return_value=True)
def test_mlflow_experiment_id_retrieved_once(_, mlflow_mock, tmp_path):
"""Test that the logger experiment_id retrieved only once."""
logger = MLFlowLogger("test", save_dir=str(tmp_path))
@@ -222,7 +222,7 @@ def test_mlflow_experiment_id_retrieved_once(_, mlflow_mock, tmp_path):
assert logger.experiment.get_experiment_by_name.call_count == 1
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_with_unexpected_characters(mlflow_mock, tmp_path):
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
logger = MLFlowLogger("test", save_dir=str(tmp_path))
@@ -232,7 +232,7 @@ def test_mlflow_logger_with_unexpected_characters(mlflow_mock, tmp_path):
logger.log_metrics(metrics)
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
"""Test that the logger calls methods on the mlflow experiment correctly."""
time = mlflow_mock.entities.time
@@ -242,7 +242,7 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
time.return_value = 1
logger = MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location")
- logger._mlflow_client.get_experiment_by_name.return_value = None
+ logger.logger_impl._mlflow_client.get_experiment_by_name.return_value = None
params = {"test": "test_param"}
logger.log_hyperparams(params)
@@ -260,13 +260,13 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
)
metric.assert_called_with(key="some_metric", value=10, timestamp=1000, step=0)
- logger._mlflow_client.create_experiment.assert_called_once_with(
+ logger.logger_impl._mlflow_client.create_experiment.assert_called_once_with(
name="test", artifact_location="my_artifact_location"
)
@pytest.mark.parametrize("synchronous", [False, True])
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_experiment_calls_with_synchronous(mlflow_mock, tmp_path, synchronous):
"""Test that the logger calls methods on the mlflow experiment with the specified synchronous flag."""
@@ -302,8 +302,8 @@ def test_mlflow_logger_experiment_calls_with_synchronous(mlflow_mock, tmp_path,
mlflow_client.create_experiment.assert_called_once_with(name="test", artifact_location="my_artifact_location")
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
-@mock.patch.dict("lightning.pytorch.loggers.mlflow.__dict__", {"_MLFLOW_SYNCHRONOUS_AVAILABLE": False})
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.utils.imports._MLFLOW_SYNCHRONOUS_AVAILABLE", False)
def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path):
"""Test that the logger does not support synchronous flag."""
time = mlflow_mock.entities.time
@@ -315,7 +315,7 @@ def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path):
MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location", synchronous=True)
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_with_long_param_value(mlflow_mock, tmp_path):
"""Test that long parameter values are truncated to 250 characters."""
@@ -333,7 +333,7 @@ def _check_value_length(value, *args, **kwargs):
logger.experiment.log_batch.assert_called_once()
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_with_many_params(mlflow_mock, tmp_path):
"""Test that when logging more than 100 parameters, it will be split into batches of at most 100 parameters."""
logger = MLFlowLogger("test", save_dir=str(tmp_path))
@@ -352,37 +352,37 @@ def test_mlflow_logger_with_many_params(mlflow_mock, tmp_path):
("finished", "FINISHED"),
],
)
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_finalize(mlflow_mock, status, expected):
logger = MLFlowLogger("test")
# Pretend we are in a worker process and finalizing
_ = logger.experiment
- assert logger._initialized
+ assert logger.logger_impl._initialized
logger.finalize(status)
logger.experiment.set_terminated.assert_called_once_with(logger.run_id, expected)
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_finalize_when_exception(mlflow_mock):
logger = MLFlowLogger("test")
# Pretend we are on the main process and failing
- assert logger._mlflow_client
- assert not logger._initialized
+ assert logger.logger_impl._mlflow_client
+ assert not logger.logger_impl._initialized
logger.finalize("failed")
logger.experiment.set_terminated.assert_not_called()
# Pretend we are in a worker process and failing
_ = logger.experiment
- assert logger._initialized
+ assert logger.logger_impl._initialized
logger.finalize("failed")
logger.experiment.set_terminated.assert_called_once_with(logger.run_id, "FAILED")
@pytest.mark.parametrize("log_model", ["all", True, False])
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_log_model(mlflow_mock, log_model, tmp_path):
"""Test that the logger creates the folders and files in the right place."""
client = mlflow_mock.tracking.MlflowClient
@@ -420,7 +420,7 @@ def test_mlflow_log_model(mlflow_mock, log_model, tmp_path):
assert not client.return_value.log_artifacts.called
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
def test_set_tracking_uri(mlflow_mock):
"""Test that the tracking uri is set for logging artifacts to MLFlow server."""
logger = MLFlowLogger(tracking_uri="the_tracking_uri")
diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py
index 572b85715da86..e50bcf967ef9a 100644
--- a/tests/tests_pytorch/loggers/test_neptune.py
+++ b/tests/tests_pytorch/loggers/test_neptune.py
@@ -13,7 +13,6 @@
# limitations under the License.
import os
import pickle
-from collections import namedtuple
from unittest import mock
from unittest.mock import MagicMock, call
@@ -45,10 +44,10 @@ def _fit_and_test(logger, model, tmp_path):
def _get_logger_with_mocks(**kwargs):
logger = NeptuneLogger(**kwargs)
run_instance_mock = MagicMock()
- logger._run_instance = run_instance_mock
- logger._run_instance.__getitem__.return_value.fetch.return_value = "exp-name"
+ logger.logger_impl._run_instance = run_instance_mock
+ logger.logger_impl._run_instance.__getitem__.return_value.fetch.return_value = "exp-name"
run_attr_mock = MagicMock()
- logger._run_instance.__getitem__.return_value = run_attr_mock
+ logger.logger_impl._run_instance.__getitem__.return_value = run_attr_mock
return logger, run_instance_mock, run_attr_mock
@@ -60,7 +59,7 @@ def test_neptune_online(neptune_mock):
logger = NeptuneLogger(api_key="test", project="project")
created_run_mock = logger.run
- assert logger._run_instance == created_run_mock
+ assert logger.logger_impl._run_instance == created_run_mock
created_run_mock.exists.assert_called_once_with("sys/id")
assert logger.name == "Run test name"
assert logger.version == "TEST-1"
@@ -79,8 +78,8 @@ def test_neptune_offline(neptune_mock):
logger.experiment["foo"] = "bar"
created_run_mock.exists.assert_called_once_with("sys/id")
- assert logger._run_short_id == "OFFLINE"
- assert logger._run_name == "offline-name"
+ assert logger.logger_impl._run_short_id == "OFFLINE"
+ assert logger.logger_impl._run_name == "offline-name"
def test_online_with_custom_run(neptune_mock):
@@ -91,8 +90,8 @@ def test_online_with_custom_run(neptune_mock):
neptune_mock.init_run.reset_mock()
logger = NeptuneLogger(run=created_run)
- assert logger._run_instance == created_run
- assert logger._run_instance == created_run
+ assert logger.logger_impl._run_instance == created_run
+ assert logger.logger_impl._run_instance == created_run
assert logger.version == "TEST-1"
assert neptune_mock.init_run.call_count == 0
@@ -275,38 +274,6 @@ def test_save_dir(neptune_mock):
assert logger.save_dir == os.path.join(os.getcwd(), ".neptune")
-def test_get_full_model_name():
- SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
- test_input_data = [
- ("key", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
- (
- "key/in/parts",
- os.path.join("foo", "bar", "key/in/parts.ext"),
- SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
- ),
- ("key", os.path.join("../foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("../foo", "bar"))),
- ("key", os.path.join("foo", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("./foo", "bar/../"))),
- ]
-
- for expected_model_name, model_path, checkpoint in test_input_data:
- assert NeptuneLogger._get_full_model_name(model_path, checkpoint) == expected_model_name
-
-
-def test_get_full_model_names_from_exp_structure():
- input_dict = {
- "foo": {
- "bar": {
- "lvl1_1": {"lvl2": {"lvl3_1": "some non important value", "lvl3_2": "some non important value"}},
- "lvl1_2": "some non important value",
- },
- "other_non_important": {"val100": 100},
- },
- "other_non_important": {"val42": 42},
- }
- expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"}
- assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys
-
-
def test_inactive_run(neptune_mock, tmp_path, monkeypatch):
from neptune.exceptions import InactiveRunException
diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py
index e9b9e9a8090b0..0409cbd7b8f2c 100644
--- a/tests/tests_pytorch/loggers/test_wandb.py
+++ b/tests/tests_pytorch/loggers/test_wandb.py
@@ -25,7 +25,6 @@
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
def test_wandb_project_name(wandb_mock):
@@ -101,7 +100,7 @@ def test_wandb_logger_init(wandb_mock):
_ = logger.experiment
# verify default resume value
- assert logger._wandb_init["resume"] == "allow"
+ assert logger.logger_impl._wandb_init["resume"] == "allow"
logger.log_metrics({"acc": 1.0}, step=3)
wandb_mock.init.assert_called_once()
@@ -145,9 +144,9 @@ def test_wandb_logger_sync_tensorboard_log_metrics(wandb_mock):
def test_wandb_logger_init_before_spawn(wandb_mock):
logger = WandbLogger()
- assert logger._experiment is None
- logger.__getstate__()
- assert logger._experiment is not None
+ assert logger.logger_impl._experiment is None
+ logger.logger_impl.__getstate__()
+ assert logger.logger_impl._experiment is not None
def test_wandb_logger_experiment_called_first(wandb_mock, tmp_path):
@@ -626,7 +625,7 @@ def test_wandb_log_media(wandb_mock, tmp_path):
def test_wandb_logger_offline_log_model(wandb_mock, tmp_path):
"""Test that log_model=True raises an error in offline mode."""
- with pytest.raises(MisconfigurationException, match="checkpoints cannot be uploaded in offline mode"):
+ with pytest.raises(ValueError, match="checkpoints cannot be uploaded in offline mode"):
_ = WandbLogger(save_dir=tmp_path, offline=True, log_model=True)
diff --git a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
index da4ce6b89aaab..86eb93c4542a0 100644
--- a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
+++ b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
@@ -19,7 +19,7 @@
def test_invalid_precision_with_deepspeed_precision():
- with pytest.raises(ValueError, match="is not supported. `precision` must be one of"):
+ with pytest.raises(ValueError, match=r"is not supported in DeepSpeed. `precision` must be one of"):
DeepSpeedPrecision(precision="64-true")
diff --git a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
index a9967280e3f23..6055f9ecd6e60 100644
--- a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
+++ b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
@@ -16,16 +16,16 @@
from unittest.mock import ANY, Mock
import pytest
+import pytorch_lightning_enterprise.utils.imports
import torch
-import lightning.fabric
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.plugins import TransformerEnginePrecision
from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector
def test_transformer_engine_precision_plugin(monkeypatch):
- module = lightning.fabric.plugins.precision.transformer_engine
+ module = pytorch_lightning_enterprise.plugins.precision.transformer_engine
if module._TRANSFORMER_ENGINE_AVAILABLE:
pytest.skip("Assumes transformer_engine is unavailable")
monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
@@ -44,7 +44,7 @@ def test_transformer_engine_precision_plugin(monkeypatch):
def test_configure_model(monkeypatch):
- module = lightning.fabric.plugins.precision.transformer_engine
+ module = pytorch_lightning_enterprise.plugins.precision.transformer_engine
if module._TRANSFORMER_ENGINE_AVAILABLE:
pytest.skip("Assumes transformer_engine is unavailable")
monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py
index 047e2d24f77a3..4f22a602b28aa 100644
--- a/tests/tests_pytorch/strategies/test_deepspeed.py
+++ b/tests/tests_pytorch/strategies/test_deepspeed.py
@@ -153,7 +153,7 @@ def test_deepspeed_precision_choice(cuda_count_1, tmp_path):
def test_deepspeed_with_invalid_config_path():
"""Test to ensure if we pass an invalid config path we throw an exception."""
with pytest.raises(
- MisconfigurationException, match="You passed in a path to a DeepSpeed config but the path does not exist"
+ FileNotFoundError, match="You passed in a path to a DeepSpeed config but the path does not exist"
):
DeepSpeedStrategy(config="invalid_path.json")
@@ -1004,7 +1004,7 @@ def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmp_path,
strategy = trainer.strategy
assert isinstance(strategy, DeepSpeedStrategy)
with mock.patch("platform.system", return_value=platform) as mock_platform:
- strategy._init_deepspeed_distributed()
+ strategy.deepspeed_strategy_impl._init_deepspeed_distributed()
mock_deepspeed_distributed.assert_called()
mock_platform.assert_called()
if platform == "Windows":
@@ -1267,6 +1267,7 @@ def test_validate_parallel_devices_indices(device_indices):
"""
accelerator = Mock(spec=CUDAAccelerator)
+ accelerator.name.return_value = "cuda"
strategy = DeepSpeedStrategy(
accelerator=accelerator, parallel_devices=[torch.device("cuda", i) for i in device_indices]
)
diff --git a/tests/tests_pytorch/strategies/test_xla.py b/tests/tests_pytorch/strategies/test_xla.py
index 3fde2600c9483..0fd2afa98690b 100644
--- a/tests/tests_pytorch/strategies/test_xla.py
+++ b/tests/tests_pytorch/strategies/test_xla.py
@@ -50,14 +50,14 @@ def test_rank_properties_access(xla_available):
strategy.cluster_environment = Mock()
# we're in the main process, no processes have been launched yet
- assert not strategy._launched
+ assert not strategy.xla_strategy_impl._launched
assert strategy.global_rank == 0
assert strategy.local_rank == 0
assert strategy.node_rank == 0
assert strategy.world_size == 1
# simulate we're in a worker process
- strategy._launched = True
+ strategy.xla_strategy_impl._launched = True
assert strategy.global_rank == strategy.cluster_environment.global_rank()
assert strategy.local_rank == strategy.cluster_environment.local_rank()
assert strategy.node_rank == strategy.cluster_environment.node_rank()
diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py
index 41ca1a779f8a5..8bb696a659469 100644
--- a/tests/tests_pytorch/utilities/migration/test_utils.py
+++ b/tests/tests_pytorch/utilities/migration/test_utils.py
@@ -101,7 +101,10 @@ def test_patch_legacy_imports_unified(pl_version):
assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), (
f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}"
)
- assert not any(key.startswith("pytorch_lightning") for key in sys.modules), (
+ assert not any(
+ key.startswith("pytorch_lightning") and not key.startswith("pytorch_lightning_enterprise")
+ for key in sys.modules
+ ), (
"Should not import standalone package, all imports should be redirected to the unified package;\n"
f" environment: {_list_sys_modules('pytorch_lightning')}"
)