diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml
index db2e4517b79e9..31af5cc30176f 100644
--- a/.github/workflows/docs-build.yml
+++ b/.github/workflows/docs-build.yml
@@ -125,10 +125,7 @@ jobs:
working-directory: ./docs/source-${{ matrix.pkg-name }}
# allow failing link check and doctest if you run with dispatch
continue-on-error: ${{ (matrix.target == 'doctest' || matrix.target == 'linkcheck') && github.event_name == 'workflow_dispatch' }}
- run: |
- # temp fix: https://github.com/Lightning-AI/pytorch-lightning/actions/runs/19440502586/job/55622388642?pr=21354#step:11:4596
- uv pip install -U fastapi
- make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"
+ run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"
- name: Keep artifact
if: github.event_name == 'pull_request'
diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py
index 67b1c8bf7bfd6..db3f5334ed162 100644
--- a/docs/source-pytorch/conf.py
+++ b/docs/source-pytorch/conf.py
@@ -604,7 +604,10 @@ def package_list_from_file(file):
from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
-from lightning.fabric.utilities.imports import _COMET_AVAILABLE, _MLFLOW_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE
+from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
+from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
+from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE
+from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
"""
coverage_skip_undoc_in_source = True
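
The `_*_AVAILABLE` flags imported above (and the `raise ModuleNotFoundError(str(...))` pattern used throughout this diff) rely on `RequirementCache` from lightning-utilities. A minimal sketch of its behavior for orientation; the `wandb` requirement here is just an example:

    # RequirementCache checks the requirement once, caches the result, and is
    # truthy/falsy accordingly; str() of a failed check describes what is missing.
    from lightning_utilities.core.imports import RequirementCache

    _WANDB_AVAILABLE = RequirementCache("wandb")
    if not _WANDB_AVAILABLE:
        raise ModuleNotFoundError(str(_WANDB_AVAILABLE))
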
diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt
index d187769117be6..fdc814c7f7fa1 100644
--- a/requirements/fabric/base.txt
+++ b/requirements/fabric/base.txt
@@ -6,4 +6,3 @@ fsspec[http] >=2022.5.0, <2025.11.0
packaging >=20.0, <=25.0
typing-extensions >4.5.0, <4.16.0
lightning-utilities >=0.10.0, <0.16.0
-pytorch-lightning-enterprise >=2.6.0
diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt
index 65a3f7484fb7d..014b223b1f012 100644
--- a/requirements/pytorch/base.txt
+++ b/requirements/pytorch/base.txt
@@ -9,4 +9,3 @@ torchmetrics >0.7.0, <1.9.0
packaging >=20.0, <=25.0
typing-extensions >4.5.0, <4.16.0
lightning-utilities >=0.10.0, <0.16.0
-pytorch-lightning-enterprise >=2.6.0
diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py
index 0cf42821424ee..db2cf2586e1ba 100644
--- a/src/lightning/fabric/accelerators/xla.py
+++ b/src/lightning/fabric/accelerators/xla.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
-import warnings
from typing import Any, Union
import torch
@@ -21,11 +20,7 @@
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
-
-_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
-_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
-_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
+from lightning.fabric.utilities.device_parser import _check_data_type
class XLAAccelerator(Accelerator):
@@ -36,38 +31,38 @@ class XLAAccelerator(Accelerator):
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
- _raise_enterprise_not_available()
+ if not _XLA_AVAILABLE:
+ raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+ if not _using_pjrt():
+ raise RuntimeError("The XLA XRT runtime is not supported anymore.")
super().__init__(*args, **kwargs)
- from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
-
- self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs)
-
@override
def setup_device(self, device: torch.device) -> None:
- return self.accelerator_impl.setup_device(device)
+ pass
@override
def teardown(self) -> None:
- return self.accelerator_impl.teardown()
+ pass
@staticmethod
@override
def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
"""Accelerator device parsing logic."""
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
-
- return EnterpriseXLAAccelerator.parse_devices(devices)
+ return _parse_tpu_devices(devices)
@staticmethod
@override
def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
"""Gets parallel devices for the Accelerator."""
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
-
- return EnterpriseXLAAccelerator.get_parallel_devices(devices)
+ devices = _parse_tpu_devices(devices)
+ if isinstance(devices, int):
+ return [torch.device("xla", i) for i in range(devices)]
+ # a list of multiple devices is not supported, only a single specific index, so it's fine to access [0]
+ return [torch.device("xla", devices[0])]
+ # we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
+ # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
+ # it will be replaced with `xla_device` (also a `torch.device`, but with extra logic) in the strategy
@staticmethod
@override
@@ -76,10 +71,16 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
@functools.lru_cache(maxsize=1)
def auto_device_count() -> int:
"""Get the devices when set to auto."""
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+ if not _XLA_AVAILABLE:
+ return 0
+ if _XLA_GREATER_EQUAL_2_1:
+ from torch_xla._internal import tpu
+
+ return tpu.num_available_devices()
+ from torch_xla.experimental import tpu
- return EnterpriseXLAAccelerator.auto_device_count()
+ device_count_on_version = {2: 8, 3: 8, 4: 4}
+ return device_count_on_version.get(tpu.version(), 8)
@staticmethod
@override
@@ -91,9 +92,6 @@ def is_available() -> bool:
# XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
# when `torch_xla` is imported but not used
return False
- except ModuleNotFoundError as e:
- warnings.warn(str(e))
- return False
@staticmethod
@override
@@ -108,3 +106,74 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
cls,
description=cls.__name__,
)
+
+
+# PJRT support requires this minimum version
+_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
+_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
+_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
+
+
+def _using_pjrt() -> bool:
+ # `using_pjrt` is removed in torch_xla 2.5
+ if _XLA_GREATER_EQUAL_2_5:
+ from torch_xla import runtime as xr
+
+ return xr.device_type() is not None
+ # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
+ if _XLA_GREATER_EQUAL_2_1:
+ from torch_xla import runtime as xr
+
+ return xr.using_pjrt()
+
+ from torch_xla.experimental import pjrt
+
+ return pjrt.using_pjrt()
+
+
+def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
+ """Parses the TPU devices given in the format as accepted by the
+ :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.
+
+ Args:
+ devices: An int of 1 or the string '1' indicates that 1 core with multi-processing should be used.
+ An int of 8 or the string '8' indicates that all 8 cores with multi-processing should be used.
+ A single-element list of int or string can be used to indicate a specific TPU core to use.
+
+ Returns:
+ The number of TPU cores to be used, or a single-element list identifying a specific core.
+
+ """
+ _check_data_type(devices)
+ if isinstance(devices, str):
+ devices = _parse_tpu_devices_str(devices)
+ _check_tpu_devices_valid(devices)
+ return devices
+
+
+def _check_tpu_devices_valid(devices: object) -> None:
+ device_count = XLAAccelerator.auto_device_count()
+ if (
+ # support number of devices
+ isinstance(devices, int)
+ and devices in {1, device_count}
+ # support picking a specific device
+ or isinstance(devices, (list, tuple))
+ and len(devices) == 1
+ and 0 <= devices[0] <= device_count - 1
+ ):
+ return
+ raise ValueError(
+ f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
+ )
+
+
+def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
+ devices = devices.strip()
+ try:
+ return int(devices)
+ except ValueError:
+ try:
+ return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
+ except ValueError:
+ raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")
diff --git a/src/lightning/fabric/plugins/environments/kubeflow.py b/src/lightning/fabric/plugins/environments/kubeflow.py
index 8013ec2a65720..23a1c0d1753af 100644
--- a/src/lightning/fabric/plugins/environments/kubeflow.py
+++ b/src/lightning/fabric/plugins/environments/kubeflow.py
@@ -13,11 +13,11 @@
# limitations under the License.
import logging
+import os
from typing_extensions import override
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
log = logging.getLogger(__name__)
@@ -33,28 +33,20 @@ class KubeflowEnvironment(ClusterEnvironment):
"""
- def __init__(self) -> None:
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.environments.kubeflow import (
- KubeflowEnvironment as EnterpriseKubeflowEnvironment,
- )
-
- self.kubeflow_impl = EnterpriseKubeflowEnvironment()
-
@property
@override
def creates_processes_externally(self) -> bool:
- return self.kubeflow_impl.creates_processes_externally
+ return True
@property
@override
def main_address(self) -> str:
- return self.kubeflow_impl.main_address
+ return os.environ["MASTER_ADDR"]
@property
@override
def main_port(self) -> int:
- return self.kubeflow_impl.main_port
+ return int(os.environ["MASTER_PORT"])
@staticmethod
@override
@@ -63,24 +55,24 @@ def detect() -> bool:
@override
def world_size(self) -> int:
- return self.kubeflow_impl.world_size()
+ return int(os.environ["WORLD_SIZE"])
@override
def set_world_size(self, size: int) -> None:
- return self.kubeflow_impl.set_world_size(size)
+ log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
@override
def global_rank(self) -> int:
- return self.kubeflow_impl.global_rank()
+ return int(os.environ["RANK"])
@override
def set_global_rank(self, rank: int) -> None:
- return self.kubeflow_impl.set_global_rank(rank)
+ log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
@override
def local_rank(self) -> int:
- return self.kubeflow_impl.local_rank()
+ return 0
@override
def node_rank(self) -> int:
- return self.kubeflow_impl.node_rank()
+ return self.global_rank()
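
With the enterprise indirection removed, the environment is driven entirely by the variables a Kubeflow PyTorchJob injects into each worker pod. A minimal sketch with illustrative values (the service name is hypothetical):

    import os

    os.environ.update({
        "MASTER_ADDR": "pytorchjob-demo-master-0",  # hypothetical service name
        "MASTER_PORT": "23456",
        "WORLD_SIZE": "4",
        "RANK": "2",
    })

    from lightning.fabric.plugins.environments import KubeflowEnvironment

    env = KubeflowEnvironment()
    assert env.world_size() == 4 and env.global_rank() == 2
    assert env.local_rank() == 0  # always 0: one process per pod
    assert env.node_rank() == 2   # mirrors the global rank
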
diff --git a/src/lightning/fabric/plugins/environments/lsf.py b/src/lightning/fabric/plugins/environments/lsf.py
index 0b62bf502303a..f0a07d61d9f03 100644
--- a/src/lightning/fabric/plugins/environments/lsf.py
+++ b/src/lightning/fabric/plugins/environments/lsf.py
@@ -13,11 +13,12 @@
# limitations under the License.
import logging
import os
+import socket
from typing_extensions import override
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.cloud_io import get_filesystem
log = logging.getLogger(__name__)
@@ -49,32 +50,36 @@ class LSFEnvironment(ClusterEnvironment):
def __init__(self) -> None:
super().__init__()
-
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.environments.lsf import (
- LSFEnvironment as EnterpriseLSFEnvironment,
- )
-
- self.lsf_impl = EnterpriseLSFEnvironment()
+ self._main_address = self._get_main_address()
+ self._main_port = self._get_main_port()
+ self._node_rank = self._get_node_rank()
+ self._set_init_progress_group_env_vars()
+
+ def _set_init_progress_group_env_vars(self) -> None:
+ # set environment variables needed for initializing torch distributed process group
+ os.environ["MASTER_ADDR"] = str(self._main_address)
+ log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
+ os.environ["MASTER_PORT"] = str(self._main_port)
+ log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
@property
@override
def creates_processes_externally(self) -> bool:
"""LSF creates subprocesses, i.e., PyTorch Lightning does not need to spawn them."""
- return self.lsf_impl.creates_processes_externally
+ return True
@property
@override
def main_address(self) -> str:
"""The main address is read from an OpenMPI host rank file in the environment variable
``LSB_DJOB_RANKFILE``."""
- return self.lsf_impl.main_address
+ return self._main_address
@property
@override
def main_port(self) -> int:
"""The main port is calculated from the LSF job ID."""
- return self.lsf_impl.main_port
+ return self._main_port
@staticmethod
@override
@@ -86,28 +91,110 @@ def detect() -> bool:
@override
def world_size(self) -> int:
"""The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``."""
- return self.lsf_impl.world_size()
+ world_size = os.environ.get("JSM_NAMESPACE_SIZE")
+ if world_size is None:
+ raise ValueError(
+ "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
+ " Make sure you run your executable with `jsrun`."
+ )
+ return int(world_size)
@override
def set_world_size(self, size: int) -> None:
- return self.lsf_impl.set_world_size(size)
+ log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
@override
def global_rank(self) -> int:
"""The world size is read from the environment variable ``JSM_NAMESPACE_RANK``."""
- return self.lsf_impl.global_rank()
+ global_rank = os.environ.get("JSM_NAMESPACE_RANK")
+ if global_rank is None:
+ raise ValueError(
+ "Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
+ " Make sure you run your executable with `jsrun`."
+ )
+ return int(global_rank)
@override
def set_global_rank(self, rank: int) -> None:
- return self.lsf_impl.set_global_rank(rank)
+ log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
@override
def local_rank(self) -> int:
"""The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
- return self.lsf_impl.local_rank()
+ local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK")
+ if local_rank is None:
+ raise ValueError(
+ "Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
+ " Make sure you run your executable with `jsrun`."
+ )
+ return int(local_rank)
@override
def node_rank(self) -> int:
"""The node rank is determined by the position of the current hostname in the OpenMPI host rank file stored in
``LSB_DJOB_RANKFILE``."""
- return self.lsf_impl.node_rank()
+ return self._node_rank
+
+ def _get_node_rank(self) -> int:
+ """A helper method for getting the node rank.
+
+ The node rank is determined by the position of the current node in the list of hosts used in the job. This is
+ calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list.
+
+ """
+ hosts = self._read_hosts()
+ count: dict[str, int] = {}
+ for host in hosts:
+ if host not in count:
+ count[host] = len(count)
+ return count[socket.gethostname()]
+
+ @staticmethod
+ def _read_hosts() -> list[str]:
+ """Read compute hosts that are a part of the compute job.
+
+ LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes.
+ Each job is assigned a launch node. This launch node will be the first node in the list contained in
+ ``LSB_DJOB_RANKFILE``.
+
+ """
+ var = "LSB_DJOB_RANKFILE"
+ rankfile = os.environ.get(var)
+ if rankfile is None:
+ raise ValueError("Did not find the environment variable `LSB_DJOB_RANKFILE`")
+ if not rankfile:
+ raise ValueError("The environment variable `LSB_DJOB_RANKFILE` is empty")
+
+ fs = get_filesystem(rankfile)
+ with fs.open(rankfile, "r") as f:
+ ret = [line.strip() for line in f]
+ # remove the launch node (i.e. the first node in LSB_DJOB_RANKFILE) from the list
+ return ret[1:]
+
+ def _get_main_address(self) -> str:
+ """A helper for getting the main address.
+
+ The main address is assigned to the first node in the list of nodes used for the job.
+
+ """
+ hosts = self._read_hosts()
+ return hosts[0]
+
+ @staticmethod
+ def _get_main_port() -> int:
+ """A helper function for accessing the main port.
+
+ Uses the LSF job ID so all ranks can compute the main port.
+
+ """
+ # check for user-specified main port
+ if "MASTER_PORT" in os.environ:
+ log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
+ return int(os.environ["MASTER_PORT"])
+ if "LSB_JOBID" in os.environ:
+ port = int(os.environ["LSB_JOBID"])
+ # all ports should be in the 10k+ range
+ port = port % 1000 + 10000
+ log.debug(f"calculated LSF main port: {port}")
+ return port
+ raise ValueError("Could not find job id in environment variable LSB_JOBID")
diff --git a/src/lightning/fabric/plugins/environments/slurm.py b/src/lightning/fabric/plugins/environments/slurm.py
index 4ac79ce0014e9..4d98b7ed6a8eb 100644
--- a/src/lightning/fabric/plugins/environments/slurm.py
+++ b/src/lightning/fabric/plugins/environments/slurm.py
@@ -14,6 +14,7 @@
import logging
import os
+import re
import shutil
import signal
import sys
@@ -22,7 +23,7 @@
from typing_extensions import override
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.imports import _IS_WINDOWS, _raise_enterprise_not_available
+from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.fabric.utilities.rank_zero import rank_zero_warn
from lightning.fabric.utilities.warnings import PossibleUserWarning
@@ -45,35 +46,57 @@ class SLURMEnvironment(ClusterEnvironment):
def __init__(self, auto_requeue: bool = True, requeue_signal: Optional[signal.Signals] = None) -> None:
super().__init__()
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.environments.slurm import (
- SLURMEnvironment as EnterpriseSLURMEnvironment,
- )
-
- self.slurm_impl = EnterpriseSLURMEnvironment(auto_requeue=auto_requeue, requeue_signal=requeue_signal)
-
- @property
- def auto_requeue(self) -> bool:
- return self.slurm_impl.auto_requeue
-
- @property
- def requeue_signal(self) -> Optional[signal.Signals]:
- return self.slurm_impl.requeue_signal
+ self.auto_requeue = auto_requeue
+ if requeue_signal is None and not _IS_WINDOWS:
+ requeue_signal = signal.SIGUSR1
+ self.requeue_signal = requeue_signal
+ self._validate_srun_used()
+ self._validate_srun_variables()
@property
@override
def creates_processes_externally(self) -> bool:
- return self.slurm_impl.creates_processes_externally
+ return True
@property
@override
def main_address(self) -> str:
- return self.slurm_impl.main_address
+ root_node = os.environ.get("MASTER_ADDR")
+ if root_node is None:
+ nodelist = os.environ.get("SLURM_NODELIST", "127.0.0.1")
+ root_node = self.resolve_root_node_address(nodelist)
+ os.environ["MASTER_ADDR"] = root_node
+
+ log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
+ return root_node
@property
@override
def main_port(self) -> int:
- return self.slurm_impl.main_port
+ # -----------------------
+ # SLURM JOB = PORT number
+ # -----------------------
+ # this way every process knows what port to use
+ job_id = os.environ.get("SLURM_JOB_ID")
+ if job_id is not None:
+ # use the last 4 numbers in the job id as the id
+ default_port = job_id[-4:]
+ # all ports should be in the 10k+ range
+ default_port = int(default_port) + 15000
+ else:
+ default_port = 12910
+
+ # -----------------------
+ # PORT NUMBER = MASTER_PORT
+ # -----------------------
+ # in case the user passed it in
+ if "MASTER_PORT" in os.environ:
+ default_port = int(os.environ["MASTER_PORT"])
+ else:
+ os.environ["MASTER_PORT"] = str(default_port)
+
+ log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
+ return default_port
@staticmethod
@override
@@ -95,40 +118,58 @@ def job_name() -> Optional[str]:
@staticmethod
def job_id() -> Optional[int]:
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.environments.slurm import (
- SLURMEnvironment as EnterpriseSLURMEnvironment,
- )
-
- return EnterpriseSLURMEnvironment.job_id()
+ # in interactive mode, don't make logs use the same job id
+ if _is_slurm_interactive_mode():
+ return None
+
+ job_id = os.environ.get("SLURM_JOB_ID")
+ if job_id is None:
+ return None
+ try:
+ return int(job_id)
+ except ValueError:
+ return None
@override
def world_size(self) -> int:
- return self.slurm_impl.world_size()
+ return int(os.environ["SLURM_NTASKS"])
@override
def set_world_size(self, size: int) -> None:
- return self.slurm_impl.set_world_size(size)
+ log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
@override
def global_rank(self) -> int:
- return self.slurm_impl.global_rank()
+ return int(os.environ["SLURM_PROCID"])
@override
def set_global_rank(self, rank: int) -> None:
- return self.slurm_impl.set_global_rank(rank)
+ log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
@override
def local_rank(self) -> int:
- return self.slurm_impl.local_rank()
+ return int(os.environ["SLURM_LOCALID"])
@override
def node_rank(self) -> int:
- return self.slurm_impl.node_rank()
+ return int(os.environ["SLURM_NODEID"])
@override
def validate_settings(self, num_devices: int, num_nodes: int) -> None:
- return self.slurm_impl.validate_settings(num_devices, num_nodes)
+ if _is_slurm_interactive_mode():
+ return
+ ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
+ if ntasks_per_node is not None and int(ntasks_per_node) != num_devices:
+ raise ValueError(
+ f"You set `devices={num_devices}` in Lightning, but the number of tasks per node configured in SLURM"
+ f" `--ntasks-per-node={ntasks_per_node}` does not match. HINT: Set `devices={ntasks_per_node}`."
+ )
+ nnodes = os.environ.get("SLURM_NNODES")
+ if nnodes is not None and int(nnodes) != num_nodes:
+ raise ValueError(
+ f"You set `num_nodes={num_nodes}` in Lightning, but the number of nodes configured in SLURM"
+ f" `--nodes={nnodes}` does not match. HINT: Set `num_nodes={nnodes}`."
+ )
@staticmethod
def resolve_root_node_address(nodes: str) -> str:
@@ -141,12 +182,9 @@ def resolve_root_node_address(nodes: str) -> str:
- the range notation with brackets, e.g., 'host[5-9]' yields 'host5' as the root
"""
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.environments.slurm import (
- SLURMEnvironment as EnterpriseSLURMEnvironment,
- )
-
- return EnterpriseSLURMEnvironment.resolve_root_node_address(nodes)
+ nodes = re.sub(r"\[(.*?)[,-].*\]", "\\1", nodes) # Take the first node of every node range
+ nodes = re.sub(r"\[(.*?)\]", "\\1", nodes) # handle special case where node range is single number
+ return nodes.split(" ")[0].split(",")[0]
@staticmethod
def _validate_srun_used() -> None:
@@ -178,12 +216,12 @@ def _validate_srun_variables() -> None:
for a complete list of supported srun variables.
"""
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.environments.slurm import (
- SLURMEnvironment as EnterpriseSLURMEnvironment,
- )
-
- return EnterpriseSLURMEnvironment._validate_srun_variables()
+ ntasks = int(os.environ.get("SLURM_NTASKS", "1"))
+ if ntasks > 1 and "SLURM_NTASKS_PER_NODE" not in os.environ:
+ raise RuntimeError(
+ f"You set `--ntasks={ntasks}` in your SLURM bash script, but this variable is not supported."
+ f" HINT: Use `--ntasks-per-node={ntasks}` instead."
+ )
def _is_srun_used() -> bool:
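
The two regex passes in `resolve_root_node_address` are the subtle part of this file; a standalone sketch of what they do to typical `SLURM_NODELIST` values:

    import re

    def resolve(nodes: str) -> str:
        nodes = re.sub(r"\[(.*?)[,-].*\]", "\\1", nodes)  # first node of every range
        nodes = re.sub(r"\[(.*?)\]", "\\1", nodes)        # bracketed single number
        return nodes.split(" ")[0].split(",")[0]

    assert resolve("host[5-9]") == "host5"
    assert resolve("node[18]") == "node18"
    assert resolve("alpha,beta,gamma") == "alpha"
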
diff --git a/src/lightning/fabric/plugins/environments/torchelastic.py b/src/lightning/fabric/plugins/environments/torchelastic.py
index 62ddd442df7b2..4003dd30dec09 100644
--- a/src/lightning/fabric/plugins/environments/torchelastic.py
+++ b/src/lightning/fabric/plugins/environments/torchelastic.py
@@ -13,12 +13,13 @@
# limitations under the License.
import logging
+import os
import torch.distributed
from typing_extensions import override
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.rank_zero import rank_zero_warn
log = logging.getLogger(__name__)
@@ -26,29 +27,29 @@
class TorchElasticEnvironment(ClusterEnvironment):
"""Environment for fault-tolerant and elastic training with `torchelastic `_"""
- def __init__(self) -> None:
- super().__init__()
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.environments.torchelastic import (
- TorchElasticEnvironment as EnterpriseTorchElasticEnvironment,
- )
-
- self.torchelastic_impl = EnterpriseTorchElasticEnvironment()
-
@property
@override
def creates_processes_externally(self) -> bool:
- return self.torchelastic_impl.creates_processes_externally
+ return True
@property
@override
def main_address(self) -> str:
- return self.torchelastic_impl.main_address
+ if "MASTER_ADDR" not in os.environ:
+ rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
+ log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
+ return os.environ["MASTER_ADDR"]
@property
@override
def main_port(self) -> int:
- return self.torchelastic_impl.main_port
+ if "MASTER_PORT" not in os.environ:
+ rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
+ os.environ["MASTER_PORT"] = "12910"
+ log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
+
+ return int(os.environ["MASTER_PORT"])
@staticmethod
@override
@@ -59,28 +60,34 @@ def detect() -> bool:
@override
def world_size(self) -> int:
- return self.torchelastic_impl.world_size()
+ return int(os.environ["WORLD_SIZE"])
@override
def set_world_size(self, size: int) -> None:
- return self.torchelastic_impl.set_world_size(size)
+ log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
@override
def global_rank(self) -> int:
- return self.torchelastic_impl.global_rank()
+ return int(os.environ["RANK"])
@override
def set_global_rank(self, rank: int) -> None:
- return self.torchelastic_impl.set_global_rank(rank)
+ log.debug(
+ "TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
+ )
@override
def local_rank(self) -> int:
- return self.torchelastic_impl.local_rank()
+ return int(os.environ["LOCAL_RANK"])
@override
def node_rank(self) -> int:
- return self.torchelastic_impl.node_rank()
+ return int(os.environ.get("GROUP_RANK", 0))
@override
def validate_settings(self, num_devices: int, num_nodes: int) -> None:
- return self.torchelastic_impl.validate_settings(num_devices, num_nodes)
+ if num_devices * num_nodes != self.world_size():
+ raise ValueError(
+ f"You set `devices={num_devices}` and `num_nodes={num_nodes}` in Lightning, but the product"
+ f" ({num_devices} * {num_nodes}) does not match the world size ({self.world_size()})."
+ )
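
A quick sketch of the restored `validate_settings` consistency check, using illustrative torchelastic variables:

    import os

    os.environ.update({"MASTER_ADDR": "127.0.0.1", "MASTER_PORT": "29400",
                       "WORLD_SIZE": "8", "RANK": "0", "LOCAL_RANK": "0"})

    from lightning.fabric.plugins.environments import TorchElasticEnvironment

    env = TorchElasticEnvironment()
    env.validate_settings(num_devices=4, num_nodes=2)  # 4 * 2 == 8: passes
    try:
        env.validate_settings(num_devices=8, num_nodes=2)  # 16 != 8
    except ValueError as err:
        print(err)
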
diff --git a/src/lightning/fabric/plugins/environments/xla.py b/src/lightning/fabric/plugins/environments/xla.py
index 515668fa2b154..b8350872f22d9 100644
--- a/src/lightning/fabric/plugins/environments/xla.py
+++ b/src/lightning/fabric/plugins/environments/xla.py
@@ -17,9 +17,8 @@
from typing_extensions import override
-from lightning.fabric.accelerators.xla import XLAAccelerator
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1, XLAAccelerator
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
log = logging.getLogger(__name__)
@@ -33,28 +32,26 @@ class XLAEnvironment(ClusterEnvironment):
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
- super().__init__()
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.environments.xla import (
- XLAEnvironment as EnterpriseXLAEnvironment,
- )
-
- self.xla_impl = EnterpriseXLAEnvironment(*args, **kwargs)
+ if not _XLA_AVAILABLE:
+ raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+ super().__init__(*args, **kwargs)
@property
@override
def creates_processes_externally(self) -> bool:
- return self.xla_impl.creates_processes_externally
+ return False
@property
@override
def main_address(self) -> str:
- return self.xla_impl.main_address
+ # unused by lightning
+ raise NotImplementedError
@property
@override
def main_port(self) -> int:
- return self.xla_impl.main_port
+ # unused by lightning
+ raise NotImplementedError
@staticmethod
@override
@@ -69,11 +66,18 @@ def world_size(self) -> int:
The output is cached for performance.
"""
- return self.xla_impl.world_size()
+ if _XLA_GREATER_EQUAL_2_1:
+ from torch_xla import runtime as xr
+
+ return xr.world_size()
+
+ import torch_xla.core.xla_model as xm
+
+ return xm.xrt_world_size()
@override
def set_world_size(self, size: int) -> None:
- return self.xla_impl.set_world_size(size)
+ log.debug("XLAEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
@override
@functools.lru_cache(maxsize=1)
@@ -83,11 +87,18 @@ def global_rank(self) -> int:
The output is cached for performance.
"""
- return self.xla_impl.global_rank()
+ if _XLA_GREATER_EQUAL_2_1:
+ from torch_xla import runtime as xr
+
+ return xr.global_ordinal()
+
+ import torch_xla.core.xla_model as xm
+
+ return xm.get_ordinal()
@override
def set_global_rank(self, rank: int) -> None:
- return self.xla_impl.set_global_rank(rank)
+ log.debug("XLAEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
@override
@functools.lru_cache(maxsize=1)
@@ -97,7 +108,14 @@ def local_rank(self) -> int:
The output is cached for performance.
"""
- return self.xla_impl.local_rank()
+ if _XLA_GREATER_EQUAL_2_1:
+ from torch_xla import runtime as xr
+
+ return xr.local_ordinal()
+
+ import torch_xla.core.xla_model as xm
+
+ return xm.get_local_ordinal()
@override
@functools.lru_cache(maxsize=1)
@@ -107,4 +125,11 @@ def node_rank(self) -> int:
The output is cached for performance.
"""
- return self.xla_impl.node_rank()
+ if _XLA_GREATER_EQUAL_2_1:
+ from torch_xla import runtime as xr
+
+ return xr.host_index()
+ import torch_xla.core.xla_env_vars as xenv
+ from torch_xla.utils.utils import getenv_as
+
+ return getenv_as(xenv.HOST_ORDINAL, int, 0)
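
All of the rank accessors above are wrapped in `functools.lru_cache(maxsize=1)`, so the comparatively expensive torch_xla query runs once per process. A toy illustration of that caching pattern (no torch_xla required):

    import functools

    class Env:
        calls = 0

        @functools.lru_cache(maxsize=1)
        def global_rank(self) -> int:
            Env.calls += 1  # stand-in for a torch_xla runtime query
            return 3

    env = Env()
    assert env.global_rank() == env.global_rank() == 3
    assert Env.calls == 1  # the underlying query ran only once
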
diff --git a/src/lightning/fabric/plugins/io/xla.py b/src/lightning/fabric/plugins/io/xla.py
index da6ea414cac97..146fa2f33b510 100644
--- a/src/lightning/fabric/plugins/io/xla.py
+++ b/src/lightning/fabric/plugins/io/xla.py
@@ -12,12 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import os
from typing import Any, Optional
+import torch
+from lightning_utilities.core.apply_func import apply_to_collection
+from lightning_utilities.core.imports import RequirementCache
from typing_extensions import override
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _PATH
log = logging.getLogger(__name__)
@@ -31,11 +36,9 @@ class XLACheckpointIO(TorchCheckpointIO):
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
+ if not _XLA_AVAILABLE:
+ raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(*args, **kwargs)
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.io.xla import XLACheckpointIO as EnterpriseXLACheckpointIO
-
- self.xla_impl = EnterpriseXLACheckpointIO(*args, **kwargs)
@override
def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
@@ -51,4 +54,21 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio
If ``storage_options`` arg is passed in
"""
- return self.xla_impl.save_checkpoint(checkpoint, path, storage_options)
+ if storage_options is not None:
+ raise TypeError(
+ "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
+ f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
+ " to define how you'd like to use `storage_options`."
+ )
+ fs = get_filesystem(path)
+ fs.makedirs(os.path.dirname(path), exist_ok=True)
+ if RequirementCache("omegaconf"):
+ # workaround for https://github.com/pytorch/xla/issues/2773
+ from omegaconf import DictConfig, ListConfig, OmegaConf
+
+ checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
+ import torch_xla.core.xla_model as xm
+
+ cpu_data = xm._maybe_convert_to_cpu(checkpoint, convert=True)
+ log.debug(f"Saving checkpoint: {path}")
+ torch.save(cpu_data, path)
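
The omegaconf workaround above can be exercised standalone: `DictConfig`/`ListConfig` values in the checkpoint are converted to plain containers before `torch.save` (assumes `omegaconf` is installed; keys are illustrative):

    from lightning_utilities.core.apply_func import apply_to_collection
    from omegaconf import DictConfig, ListConfig, OmegaConf

    checkpoint = {"state_dict": {}, "hparams": DictConfig({"lr": 1e-3})}
    checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
    assert isinstance(checkpoint["hparams"], dict)  # plain dict, safe to serialize
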
diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py
index 4ff20ee9fab1e..4c648f2b97181 100644
--- a/src/lightning/fabric/plugins/precision/bitsandbytes.py
+++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -11,15 +11,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from contextlib import AbstractContextManager
-from typing import Any, Literal, Optional
+import functools
+import logging
+import math
+import os
+import warnings
+from collections import OrderedDict
+from contextlib import AbstractContextManager, ExitStack
+from functools import partial
+from types import ModuleType
+from typing import Any, Callable, Literal, Optional, cast
import torch
+from lightning_utilities import apply_to_collection
from lightning_utilities.core.imports import RequirementCache
-from typing_extensions import override
+from torch import Tensor
+from torch.nn import init
+from torch.nn.modules.module import _IncompatibleKeys
+from typing_extensions import Self, override
from lightning.fabric.plugins.precision.precision import Precision
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.plugins.precision.utils import (
+ _ClassReplacementContextManager,
+ _convert_fp_tensor,
+ _DtypeContextManager,
+)
+from lightning.fabric.utilities.types import _DEVICE
+
+log = logging.getLogger(__name__)
_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes")
@@ -54,34 +73,376 @@ def __init__(
dtype: Optional[torch.dtype] = None,
ignore_modules: Optional[set[str]] = None,
) -> None:
- super().__init__()
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.precision.bitsandbytes import (
- BitsandbytesPrecision as EnterpriseBitsandbytesPrecision,
- )
+ _import_bitsandbytes()
+
+ if dtype is None:
+ # try to be smart about the default selection
+ if mode.startswith("int8"):
+ dtype = torch.float16
+ else:
+ dtype = (
+ torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+ )
+ if mode.startswith("int8") and dtype is not torch.float16:
+ # this limitation is mentioned in https://huggingface.co/blog/hf-bitsandbytes-integration#usage
+ raise ValueError(f"{mode!r} only works with `dtype=torch.float16`, but you chose `{dtype}`")
- self.bitsandbytes_impl = EnterpriseBitsandbytesPrecision(mode=mode, dtype=dtype, ignore_modules=ignore_modules)
+ globals_ = globals()
+ mode_to_cls = {
+ "nf4": globals_["_NF4Linear"],
+ "nf4-dq": globals_["_NF4DQLinear"],
+ "fp4": globals_["_FP4Linear"],
+ "fp4-dq": globals_["_FP4DQLinear"],
+ "int8-training": globals_["_Linear8bitLt"],
+ "int8": globals_["_Int8LinearInference"],
+ }
+ self._linear_cls = mode_to_cls[mode]
+ self.dtype = dtype
+ self.ignore_modules = ignore_modules or set()
@override
def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
- return self.bitsandbytes_impl.convert_module(module)
+ # avoid naive users thinking they quantized their model
+ if not any(isinstance(m, torch.nn.Linear) for m in module.modules()):
+ raise TypeError(
+ "You are using the bitsandbytes precision plugin, but your model has no Linear layers. This plugin"
+ " won't work for your model."
+ )
+
+ # convert modules if they haven't been converted already
+ bnb = _import_bitsandbytes()
+ if not any(isinstance(m, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)) for m in module.modules()):
+ # this will not quantize the model but only replace the layer classes
+ _convert_layers(module, self._linear_cls, self.ignore_modules)
+
+ # set the compute dtype if necessary
+ for m in module.modules():
+ if isinstance(m, bnb.nn.Linear4bit):
+ m.compute_dtype = self.dtype
+ m.compute_type_is_set = False
+ return module
@override
def tensor_init_context(self) -> AbstractContextManager:
- return self.bitsandbytes_impl.tensor_init_context()
+ return _DtypeContextManager(self.dtype)
@override
def module_init_context(self) -> AbstractContextManager:
- return self.bitsandbytes_impl.module_init_context()
+ if self.ignore_modules:
+ # cannot patch the Linear class if the user wants to skip some submodules
+ raise RuntimeError(
+ "Instantiating your model under the `init_module` context manager is not supported when used with"
+ f" `BitsandbytesPrecision(..., ignore_modules={self.ignore_modules})` as this"
+ " may initialize the layers on-device, defeating the purpose of quantization. You can remove"
+ " `ignore_modules` or remove the `init_module` context manager."
+ )
+ dtype_ctx = self.tensor_init_context()
+ # TODO: this could also support replacing `Embedding` and `Conv1D`
+ context_manager = _ClassReplacementContextManager({"torch.nn.Linear": self._linear_cls})
+ stack = ExitStack()
+ stack.enter_context(dtype_ctx)
+ stack.enter_context(context_manager)
+ return stack
@override
def forward_context(self) -> AbstractContextManager:
- return self.bitsandbytes_impl.forward_context()
+ return _DtypeContextManager(self.dtype)
@override
def convert_input(self, data: Any) -> Any:
- return self.bitsandbytes_impl.convert_input(data)
+ return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.dtype)
@override
def convert_output(self, data: Any) -> Any:
- return self.bitsandbytes_impl.convert_output(data)
+ return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
+
+
+def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_dict: OrderedDict, *_: Any) -> None:
+ # There is only one key that ends with `*.weight`, the other one is the bias
+ weight_key = next((name for name in state_dict if name.endswith("weight")), None)
+ if weight_key is None:
+ return
+ # Load the weight from the state dict and re-quantize it
+ weight = state_dict.pop(weight_key)
+ quantize_fn(weight)
+
+
+def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _IncompatibleKeys) -> None:
+ # since we manually loaded the weight in the `_quantize_on_load_hook` hook, we need to avoid this missing key false
+ # positive
+ for key in reversed(incompatible_keys.missing_keys):
+ if key.endswith("weight"):
+ incompatible_keys.missing_keys.remove(key)
+
+
+def _replace_param(
+ param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[tuple] = None
+) -> torch.nn.Parameter:
+ bnb = _import_bitsandbytes()
+
+ # doing `param.data = weight` raises a RuntimeError if param.data was on meta-device, so
+ # we need to re-create the parameters instead of overwriting the data
+ if param.device.type == "meta":
+ if isinstance(param, bnb.nn.Params4bit):
+ return bnb.nn.Params4bit(
+ data=data,
+ requires_grad=data.requires_grad,
+ quant_state=quant_state,
+ blocksize=param.blocksize,
+ compress_statistics=param.compress_statistics,
+ quant_type=param.quant_type,
+ quant_storage=param.quant_storage,
+ module=param.module,
+ bnb_quantized=param.bnb_quantized,
+ )
+ return torch.nn.Parameter(data, requires_grad=data.requires_grad)
+ param.data = data
+ if isinstance(param, bnb.nn.Params4bit):
+ param.quant_state = quant_state
+ return param
+
+
+@functools.lru_cache(maxsize=1)
+def _import_bitsandbytes() -> ModuleType:
+ if not _BITSANDBYTES_AVAILABLE:
+ raise ModuleNotFoundError(str(_BITSANDBYTES_AVAILABLE))
+ # configuration for bitsandbytes before import
+ nowelcome_set = "BITSANDBYTES_NOWELCOME" in os.environ
+ if not nowelcome_set:
+ os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+ warnings.filterwarnings("ignore", message=r".*bitsandbytes was compiled without GPU support.*")
+ warnings.filterwarnings(
+ "ignore", message=r"MatMul8bitLt: inputs will be cast from .* to float16 during quantization"
+ )
+ import bitsandbytes as bnb
+
+ if not nowelcome_set:
+ del os.environ["BITSANDBYTES_NOWELCOME"]
+
+ class _Linear8bitLt(bnb.nn.Linear8bitLt):
+ """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and re-quantizaton when loading
+ the state dict."""
+
+ def __init__(self, *args: Any, device: Optional[_DEVICE] = None, threshold: float = 6.0, **kwargs: Any) -> None:
+ super().__init__(*args, device=device, threshold=threshold, **kwargs)
+ self.weight = cast(bnb.nn.Int8Params, self.weight) # type: ignore[has-type]
+ self.bias: Optional[torch.nn.Parameter] = self.bias
+ # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
+ # filling the device memory with float32 weights which could lead to OOM
+ if torch.tensor(0, device=device).device.type == "cuda":
+ self.quantize_()
+ self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
+ self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
+
+ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
+ """Inplace quantize."""
+ if weight is None:
+ weight = self.weight.data
+ if weight.data.dtype == torch.int8:
+ # already quantized
+ return
+ assert isinstance(self.weight, bnb.nn.Int8Params)
+ self.weight = self.quantize(self.weight, weight, device)
+
+ @staticmethod
+ def quantize(
+ int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: Optional[torch.device]
+ ) -> bnb.nn.Int8Params:
+ device = device or torch.device("cuda")
+ if device.type != "cuda":
+ raise RuntimeError(f"Unexpected device type: {device.type}")
+ # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302
+ B = weight.contiguous().to(device=device, dtype=torch.float16)
+ if int8params.has_fp16_weights:
+ int8params.data = B
+ else:
+ # bitsandbytes >= 0.45 supports an improved API
+ if hasattr(bnb.functional, "int8_vectorwise_quant"):
+ CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
+ else: # old method is deprecated in 0.45, removed in 0.46+.
+ CB, _, SCB, _, _ = bnb.functional.double_quant(B)
+
+ int8params.data = CB
+ setattr(int8params, "CB", CB)
+ setattr(int8params, "SCB", SCB)
+ return int8params
+
+ def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
+ if self.weight.device.type == "meta":
+ # need custom logic if int8params is on meta device
+ raise NotImplementedError
+ if self.weight.dtype == torch.uint8: # was quantized
+ # need the original shape here
+ raise NotImplementedError
+ device = torch.device(device)
+ weight = torch.empty_like(self.weight.data, device=device)
+ if device.type == "cuda": # re-quantize
+ self.quantize_(weight, device)
+ else:
+ self.weight = _replace_param(self.weight, weight)
+ if self.bias is not None:
+ self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
+ return self
+
+ def reset_parameters(self) -> None:
+ # from `torch.nn.Linear.reset_parameters`
+ if self.bias is not None:
+ fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ init.uniform_(self.bias, -bound, bound)
+
+ linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
+ if linear_init_finished and self.weight.dtype == torch.uint8: # was quantized
+ # need the original shape here
+ raise NotImplementedError
+ weight = self.weight.data
+ torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+ if linear_init_finished:
+ if self.weight.device.type == "meta":
+ # need custom logic if int8params is on meta device
+ raise NotImplementedError
+ if self.weight.device.type == "cuda": # re-quantize
+ self.quantize_(weight)
+ else:
+ self.weight = _replace_param(self.weight, weight)
+
+ class _Linear4bit(bnb.nn.Linear4bit):
+ """Wraps `bnb.nn.Linear4bit` to enable: instantiation directly on the device, re-quantizaton when loading the
+ state dict, meta-device initialization, and materialization."""
+
+ def __init__(self, *args: Any, device: Optional[_DEVICE] = None, **kwargs: Any) -> None:
+ super().__init__(*args, device=device, **kwargs)
+ self.weight = cast(bnb.nn.Params4bit, self.weight) # type: ignore[has-type]
+ self.bias: Optional[torch.nn.Parameter] = self.bias
+ # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
+ # filling the device memory with float32 weights which could lead to OOM
+ if torch.tensor(0, device=device).device.type == "cuda":
+ self.quantize_()
+ self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
+ self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
+
+ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
+ """Inplace quantize."""
+ if weight is None:
+ weight = self.weight.data
+ if weight.data.dtype == torch.uint8:
+ # already quantized
+ return
+ assert isinstance(self.weight, bnb.nn.Params4bit)
+ self.weight = self.quantize(self.weight, weight, device)
+ self.weight.bnb_quantized = True
+
+ @staticmethod
+ def quantize(
+ params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: Optional[torch.device]
+ ) -> bnb.nn.Params4bit:
+ device = device or torch.device("cuda")
+ if device.type != "cuda":
+ raise RuntimeError(f"Unexpected device type: {device.type}")
+ # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L156-L159
+ w = weight.contiguous().to(device=device, dtype=torch.half)
+ w_4bit, quant_state = bnb.functional.quantize_4bit(
+ w,
+ blocksize=params4bit.blocksize,
+ compress_statistics=params4bit.compress_statistics,
+ quant_type=params4bit.quant_type,
+ quant_storage=params4bit.quant_storage,
+ )
+ return _replace_param(params4bit, w_4bit, quant_state)
+
+ def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
+ if self.weight.dtype == torch.uint8: # was quantized
+ # cannot init the quantized params directly
+ weight = torch.empty(self.weight.quant_state.shape, device=device, dtype=torch.half)
+ else:
+ weight = torch.empty_like(self.weight.data, device=device)
+ device = torch.device(device)
+ if device.type == "cuda": # re-quantize
+ self.quantize_(weight, device)
+ else:
+ self.weight = _replace_param(self.weight, weight)
+ if self.bias is not None:
+ self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
+ return self
+
+ def reset_parameters(self) -> None:
+ # from `torch.nn.Linear.reset_parameters`
+ if self.bias is not None:
+ fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ init.uniform_(self.bias, -bound, bound)
+
+ linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
+ if linear_init_finished and self.weight.dtype == torch.uint8: # was quantized
+ # cannot init the quantized params directly
+ weight = torch.empty(self.weight.quant_state.shape, device=self.weight.device, dtype=torch.half)
+ else:
+ weight = self.weight.data
+ torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+ if linear_init_finished:
+ if self.weight.device.type == "cuda": # re-quantize
+ self.quantize_(weight)
+ else:
+ self.weight = _replace_param(self.weight, weight)
+
+ # Use a class instead `functools.partial` to respect `isinstance` checks and attribute accesses
+ class _Int8LinearInference(_Linear8bitLt):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, has_fp16_weights=False, **kwargs)
+
+ class _FP4Linear(_Linear4bit):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, quant_type="fp4", compress_statistics=False, **kwargs)
+
+ class _FP4DQLinear(_Linear4bit):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, quant_type="fp4", compress_statistics=True, **kwargs)
+
+ class _NF4Linear(_Linear4bit):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, quant_type="nf4", compress_statistics=False, **kwargs)
+
+ class _NF4DQLinear(_Linear4bit):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, quant_type="nf4", compress_statistics=True, **kwargs)
+
+ # these classes are defined programmatically like this to avoid importing bitsandbytes in environments that have
+ # it available but will not use it
+ classes = {
+ "_Linear8bitLt": _Linear8bitLt,
+ "_Linear4bit": _Linear4bit,
+ "_Int8LinearInference": _Int8LinearInference,
+ "_FP4Linear": _FP4Linear,
+ "_FP4DQLinear": _FP4DQLinear,
+ "_NF4Linear": _NF4Linear,
+ "_NF4DQLinear": _NF4DQLinear,
+ }
+ globals().update(classes)
+
+ return bnb
+
+
+def _convert_layers(module: torch.nn.Module, linear_cls: type, ignore_modules: set[str], prefix: str = "") -> None:
+ for name, child in module.named_children():
+ fullname = f"{prefix}.{name}" if prefix else name
+ if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
+ log.debug(f"Replacing layer {fullname!r} with bitsandbytes equivalent")
+ has_bias = child.bias is not None
+ # since we are going to copy over the child's data, the device doesn't matter. I chose CPU
+ # to avoid spiking CUDA memory even though initialization is slower
+ # 4bit layers support quantizing from meta-device params so this is only relevant for 8-bit
+ _Linear4bit = globals()["_Linear4bit"]
+ device = torch.device("meta" if issubclass(linear_cls, _Linear4bit) else "cpu")
+ replacement = linear_cls(
+ child.in_features,
+ child.out_features,
+ bias=has_bias,
+ device=device,
+ )
+ if has_bias:
+ replacement.bias = _replace_param(replacement.bias, child.bias.data.clone())
+ state = {"quant_state": replacement.weight.quant_state if issubclass(linear_cls, _Linear4bit) else None}
+ replacement.weight = _replace_param(replacement.weight, child.weight.data.clone(), **state)
+ module.__setattr__(name, replacement)
+ else:
+ _convert_layers(child, linear_cls, ignore_modules, prefix=fullname)
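
A minimal usage sketch of the restored plugin with Fabric, assuming `bitsandbytes` is installed and a CUDA device is available (the model is illustrative):

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.plugins import BitsandbytesPrecision

    precision = BitsandbytesPrecision(mode="nf4-dq", dtype=torch.bfloat16)
    fabric = Fabric(accelerator="cuda", devices=1, plugins=precision)

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))
    model = fabric.setup_module(model)  # Linear layers become 4-bit NF4 equivalents
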
diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py
index b2f369dbf956d..526095008f376 100644
--- a/src/lightning/fabric/plugins/precision/deepspeed.py
+++ b/src/lightning/fabric/plugins/precision/deepspeed.py
@@ -11,16 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any, Literal
import torch
+from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module
-from typing_extensions import override
+from typing_extensions import get_args, override
-from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, Precision
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.plugins.precision.precision import Precision
+from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
from lightning.fabric.utilities.types import Steppable
if TYPE_CHECKING:
@@ -43,37 +44,51 @@ class DeepSpeedPrecision(Precision):
"""
def __init__(self, precision: _PRECISION_INPUT) -> None:
- super().__init__()
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.precision.deepspeed import (
- DeepSpeedPrecisionFabric as EnterpriseDeepSpeedPrecision,
- )
-
- self.deepspeed_impl = EnterpriseDeepSpeedPrecision(precision=precision)
+ supported_precision = get_args(_PRECISION_INPUT)
+ if precision not in supported_precision:
+ raise ValueError(
+ f"`precision={precision!r})` is not supported in DeepSpeed."
+ f" `precision` must be one of: {supported_precision}."
+ )
+ self.precision = precision
+
+ precision_to_type = {
+ "bf16-mixed": torch.bfloat16,
+ "16-mixed": torch.float16,
+ "bf16-true": torch.bfloat16,
+ "16-true": torch.float16,
+ "32-true": torch.float32,
+ }
+ self._desired_dtype = precision_to_type[self.precision]
@override
def convert_module(self, module: Module) -> Module:
- return self.deepspeed_impl.convert_module(module)
+ if "true" in self.precision:
+ return module.to(dtype=self._desired_dtype)
+ return module
@override
def tensor_init_context(self) -> AbstractContextManager:
- return self.deepspeed_impl.tensor_init_context()
+ if "true" not in self.precision:
+ return nullcontext()
+ return _DtypeContextManager(self._desired_dtype)
@override
def module_init_context(self) -> AbstractContextManager:
- return self.deepspeed_impl.module_init_context()
+ return self.tensor_init_context()
@override
def convert_input(self, data: Any) -> Any:
- return self.deepspeed_impl.convert_input(data)
+ return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
@override
def convert_output(self, data: Any) -> Any:
- return self.deepspeed_impl.convert_output(data)
+ return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
@override
def backward(self, tensor: Tensor, model: "DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
- return self.deepspeed_impl.backward(tensor, model, *args, **kwargs)
+ """Performs back-propagation using DeepSpeed's engine."""
+ model.backward(tensor, *args, **kwargs)
@override
def optimizer_step(
@@ -81,20 +96,5 @@ def optimizer_step(
optimizer: Steppable,
**kwargs: Any,
) -> Any:
- return self.deepspeed_impl.optimizer_step(optimizer, **kwargs)
-
- @property
- def precision(self) -> _PRECISION_INPUT_STR:
- return self.deepspeed_impl.precision
-
- @precision.setter
- def precision(self, precision: _PRECISION_INPUT_STR) -> None:
- self.deepspeed_impl.precision = precision
-
- @property
- def _desired_dtype(self) -> torch.dtype:
- return self.deepspeed_impl._desired_dtype
-
- @_desired_dtype.setter
- def _desired_dtype(self, dtype: torch.dtype) -> None:
- self.deepspeed_impl._desired_dtype = dtype
+ # DeepSpeed handles the optimizer step internally
+ return optimizer.step(**kwargs)
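
A behavior sketch for the precision table above: the "true" variants convert module parameters, while the "mixed" variants leave them in float32 (deepspeed itself is not required to run this):

    import torch
    from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision

    m1, m2 = torch.nn.Linear(2, 2), torch.nn.Linear(2, 2)
    assert DeepSpeedPrecision("bf16-true").convert_module(m1).weight.dtype == torch.bfloat16
    assert DeepSpeedPrecision("16-mixed").convert_module(m2).weight.dtype == torch.float32
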
diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py
index 3d3d2491763f4..bf1e51ea6b2b0 100644
--- a/src/lightning/fabric/plugins/precision/transformer_engine.py
+++ b/src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -11,19 +11,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from collections.abc import Mapping
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, ExitStack
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import torch
+from lightning_utilities import apply_to_collection
+from lightning_utilities.core.imports import RequirementCache
+from torch import Tensor
from typing_extensions import override
from lightning.fabric.plugins.precision.precision import Precision
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.plugins.precision.utils import (
+ _ClassReplacementContextManager,
+ _convert_fp_tensor,
+ _DtypeContextManager,
+)
+from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
if TYPE_CHECKING:
from transformer_engine.common.recipe import DelayedScaling
+_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0")
+log = logging.getLogger(__name__)
+
class TransformerEnginePrecision(Precision):
"""Plugin for training with fp8 precision via nvidia's
@@ -60,79 +72,111 @@ def __init__(
replace_layers: Optional[bool] = None,
fallback_compute_dtype: Optional[torch.dtype] = None,
) -> None:
- super().__init__()
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.precision.transformer_engine import (
- TransformerEnginePrecision as EnterpriseTransformerEnginePrecision,
- )
-
- self.transformer_engine_impl = EnterpriseTransformerEnginePrecision(
- weights_dtype=weights_dtype,
- recipe=recipe,
- replace_layers=replace_layers,
- fallback_compute_dtype=fallback_compute_dtype,
- )
-
- @property
- def weights_dtype(self) -> torch.dtype:
- return self.transformer_engine_impl.weights_dtype
-
- @weights_dtype.setter
- def weights_dtype(self, value: torch.dtype) -> None:
- self.transformer_engine_impl.weights_dtype = value
-
- @property
- def recipe(self) -> Union[Mapping[str, Any], "DelayedScaling"]:
- return self.transformer_engine_impl.recipe
-
- @recipe.setter
- def recipe(self, value: Union[Mapping[str, Any], "DelayedScaling"]) -> None:
- self.transformer_engine_impl.recipe = value
-
- @property
- def replace_layers(self) -> bool:
- return self.transformer_engine_impl.replace_layers
-
- @replace_layers.setter
- def replace_layers(self, value: bool) -> None:
- self.transformer_engine_impl.replace_layers = value
-
- @property
- def fallback_compute_dtype(self) -> torch.dtype:
- return self.transformer_engine_impl.fallback_compute_dtype
-
- @fallback_compute_dtype.setter
- def fallback_compute_dtype(self, value: torch.dtype) -> None:
- self.transformer_engine_impl.fallback_compute_dtype = value
+ if not _TRANSFORMER_ENGINE_AVAILABLE:
+ raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
+ from transformer_engine.common.recipe import DelayedScaling
+
+ if recipe is None:
+ recipe = DelayedScaling()
+ elif isinstance(recipe, Mapping):
+ recipe = dict(recipe) # copy
+ if "fp8_format" in recipe:
+ from transformer_engine.common.recipe import Format
+
+ recipe["fp8_format"] = getattr(Format, recipe["fp8_format"])
+ recipe = DelayedScaling(**recipe)
+
+ self.weights_dtype = weights_dtype
+ self.recipe = recipe
+ self.replace_layers = replace_layers
+ self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype
@override
def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
- return self.transformer_engine_impl.convert_module(module)
+        # avoid converting if a TransformerEngine layer is already present; assume the user took care of it
+ if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
+ if self.replace_layers is True:
+ # info level because this is expected with `init_module`
+ rank_zero_info(
+ "`TransformerEnginePrecision(replace_layers=True)` is set but the model already contains"
+ " TransformerEngine layers. Skipping"
+ )
+ elif self.replace_layers in (None, True):
+ _convert_layers(module)
+ module = module.to(dtype=self.weights_dtype)
+ return module
@override
def tensor_init_context(self) -> AbstractContextManager:
- return self.transformer_engine_impl.tensor_init_context()
+ return _DtypeContextManager(self.weights_dtype)
@override
def module_init_context(self) -> AbstractContextManager:
- return self.transformer_engine_impl.module_init_context()
+ dtype_ctx = self.tensor_init_context()
+ stack = ExitStack()
+ if self.replace_layers:
+ import transformer_engine.pytorch as te
+
+ context_manager = _ClassReplacementContextManager({
+ "torch.nn.Linear": te.Linear,
+ "torch.nn.LayerNorm": te.LayerNorm,
+ })
+ stack.enter_context(context_manager)
+ stack.enter_context(dtype_ctx)
+ return stack
@override
def forward_context(self) -> AbstractContextManager:
- return self.transformer_engine_impl.forward_context()
+ dtype_ctx = _DtypeContextManager(self.weights_dtype)
+ fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
+ import transformer_engine.pytorch as te
+
+ autocast_ctx = te.fp8_autocast(enabled=True, fp8_recipe=self.recipe)
+ stack = ExitStack()
+ stack.enter_context(dtype_ctx)
+ # enable an outer fallback autocast for operations that do not support fp8
+ stack.enter_context(fallback_autocast_ctx)
+ stack.enter_context(autocast_ctx)
+ return stack
@override
def convert_input(self, data: Any) -> Any:
- return self.transformer_engine_impl.convert_input(data)
+ return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)
@override
def convert_output(self, data: Any) -> Any:
- return self.transformer_engine_impl.convert_output(data)
-
- @property
- def _desired_dtype(self) -> torch.dtype:
- return self.transformer_engine_impl._desired_dtype
-
- @_desired_dtype.setter
- def _desired_dtype(self, dtype: torch.dtype) -> None:
- self.transformer_engine_impl._desired_dtype = dtype
+ return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
+
+
+def _convert_layers(module: torch.nn.Module) -> None:
+ import transformer_engine.pytorch as te
+
+ for name, child in module.named_children():
+ if isinstance(child, torch.nn.Linear):
+ if child.in_features % 8 != 0 or child.out_features % 16 != 0:
+ # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting
+ rank_zero_warn(
+ "Support for FP8 in the linear layers with this plugin is currently limited to"
+ " tensors with shapes where the dimensions are divisible by 8 and 16 respectively."
+ f" The layer {name!r} does not fit this criteria. You might want to add padding to your inputs."
+ )
+ continue
+ has_bias = child.bias is not None
+ replacement = te.Linear(child.in_features, child.out_features, bias=has_bias)
+ replacement.weight.data = child.weight.data.clone()
+ if has_bias:
+ replacement.bias.data = child.bias.data.clone()
+ log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
+ module.__setattr__(name, replacement)
+ elif isinstance(child, torch.nn.LayerNorm):
+ replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
+ replacement.weight.data = child.weight.data.clone()
+ # Check if bias exists before attempting to clone its data
+ if child.bias is not None and replacement.bias is not None:
+ replacement.bias.data = child.bias.data.clone()
+ log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
+ module.__setattr__(name, replacement)
+ else:
+            # there are other Transformer Engine layers we could convert, but they require fusion; full list at:
+            # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html
+ _convert_layers(child)
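
The divisibility guard in `_convert_layers` is easiest to see on a toy model. A sketch (the model below is made up; the actual replacement requires `transformer_engine` and runs inside `convert_module`):

    import torch

    model = torch.nn.Sequential(
        torch.nn.Linear(16, 32),  # 16 % 8 == 0 and 32 % 16 == 0 -> swapped for te.Linear
        torch.nn.Linear(10, 10),  # 10 % 8 != 0 -> kept as torch.nn.Linear, warning emitted
        torch.nn.LayerNorm(32),   # swapped for te.LayerNorm, weight/bias cloned over
    )
    # _convert_layers(model) recurses through named_children() and replaces
    # eligible layers in place via module.__setattr__
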
diff --git a/src/lightning/fabric/plugins/precision/xla.py b/src/lightning/fabric/plugins/precision/xla.py
index 0946f394ca954..fdb30032b3cdd 100644
--- a/src/lightning/fabric/plugins/precision/xla.py
+++ b/src/lightning/fabric/plugins/precision/xla.py
@@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
from typing import Any, Literal
import torch
-from typing_extensions import override
+from typing_extensions import get_args, override
-from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, Precision
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
+from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.utilities.types import Optimizable
_PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true"]
@@ -36,11 +37,24 @@ class XLAPrecision(Precision):
"""
def __init__(self, precision: _PRECISION_INPUT) -> None:
- super().__init__()
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.precision.xla import XLAPrecision as EnterpriseXLAPrecision
-
- self.xla_impl = EnterpriseXLAPrecision(precision=precision)
+ if not _XLA_AVAILABLE:
+ raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+ supported_precision = get_args(_PRECISION_INPUT)
+ if precision not in supported_precision:
+ raise ValueError(
+ f"`precision={precision!r})` is not supported in XLA."
+ f" `precision` must be one of: {supported_precision}."
+ )
+ self.precision = precision
+
+ if precision == "16-true":
+ os.environ["XLA_USE_F16"] = "1"
+ self._desired_dtype = torch.float16
+ elif precision == "bf16-true":
+ os.environ["XLA_USE_BF16"] = "1"
+ self._desired_dtype = torch.bfloat16
+ else:
+ self._desired_dtype = torch.float32
@override
def optimizer_step(
@@ -48,24 +62,12 @@ def optimizer_step(
optimizer: Optimizable,
**kwargs: Any,
) -> Any:
- return self.xla_impl.optimizer_step(optimizer, **kwargs)
+ import torch_xla.core.xla_model as xm
+
+ # you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True`
+ return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True)
@override
def teardown(self) -> None:
- return self.xla_impl.teardown()
-
- @property
- def _desired_dtype(self) -> torch.dtype:
- return self.xla_impl._desired_dtype
-
- @_desired_dtype.setter
- def _desired_dtype(self, dtype: torch.dtype) -> None:
- self.xla_impl._desired_dtype = dtype
-
- @property
- def precision(self) -> _PRECISION_INPUT_STR:
- return self.xla_impl.precision
-
- @precision.setter
- def precision(self, precision: _PRECISION_INPUT_STR) -> None:
- self.xla_impl.precision = precision
+ os.environ.pop("XLA_USE_BF16", None)
+ os.environ.pop("XLA_USE_F16", None)
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 93d05016868dd..883546fea1f2d 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -11,30 +11,45 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import argparse
+import json
import logging
-from contextlib import AbstractContextManager
+import os
+import platform
+from collections.abc import Mapping
+from contextlib import AbstractContextManager, ExitStack
from datetime import timedelta
+from itertools import chain
+from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
+from lightning_utilities.core.imports import RequirementCache
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import override
-from lightning.fabric.accelerators import Accelerator
+from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies.ddp import DDPStrategy
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.strategies.strategy import _Sharded
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.distributed import log
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
+from lightning.fabric.utilities.load import _move_state_into
+from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
+from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH
if TYPE_CHECKING:
from deepspeed import DeepSpeedEngine
from torch.optim.lr_scheduler import _LRScheduler
+_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
+_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")
+
# TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced.
class DeepSpeedStrategy(DDPStrategy, _Sharded):
@@ -220,11 +235,24 @@ def __init__(
exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints.
"""
+ if not _DEEPSPEED_AVAILABLE:
+ raise ImportError(
+ "To use the `DeepSpeedStrategy`, you must have DeepSpeed installed."
+ " Install it by running `pip install -U deepspeed`."
+ )
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.strategies.deepspeed import (
- DeepSpeedStrategyFabric as EnterpriseDeepSpeedStrategy,
- )
+ if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
+ # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
+ # DeepSpeed added support for this behavior in version 0.16.0.
+ import deepspeed
+
+ deepspeed_version = deepspeed.__version__
+
+ raise ImportError(
+ f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
+ f"Detected DeepSpeed version: {deepspeed_version}. "
+ "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
+ )
super().__init__(
accelerator=accelerator,
@@ -233,68 +261,77 @@ def __init__(
precision=precision,
process_group_backend=process_group_backend,
)
- self.deepspeed_impl = EnterpriseDeepSpeedStrategy(
- outer_object=self,
- accelerator=accelerator,
- zero_optimization=zero_optimization,
- stage=stage,
- remote_device=remote_device,
- offload_optimizer=offload_optimizer,
- offload_parameters=offload_parameters,
- offload_params_device=offload_params_device,
- nvme_path=nvme_path,
- params_buffer_count=params_buffer_count,
- params_buffer_size=params_buffer_size,
- max_in_cpu=max_in_cpu,
- offload_optimizer_device=offload_optimizer_device,
- optimizer_buffer_count=optimizer_buffer_count,
- block_size=block_size,
- queue_depth=queue_depth,
- single_submit=single_submit,
- overlap_events=overlap_events,
- thread_count=thread_count,
- pin_memory=pin_memory,
- sub_group_size=sub_group_size,
- contiguous_gradients=contiguous_gradients,
- overlap_comm=overlap_comm,
- allgather_partitions=allgather_partitions,
- reduce_scatter=reduce_scatter,
- allgather_bucket_size=allgather_bucket_size,
- reduce_bucket_size=reduce_bucket_size,
- zero_allow_untested_optimizer=zero_allow_untested_optimizer,
- logging_batch_size_per_gpu=logging_batch_size_per_gpu,
- config=config,
- logging_level=logging_level,
- parallel_devices=parallel_devices,
- cluster_environment=cluster_environment,
- loss_scale=loss_scale,
- initial_scale_power=initial_scale_power,
- loss_scale_window=loss_scale_window,
- hysteresis=hysteresis,
- min_loss_scale=min_loss_scale,
- partition_activations=partition_activations,
- cpu_checkpointing=cpu_checkpointing,
- contiguous_memory_optimization=contiguous_memory_optimization,
- synchronize_checkpoint_boundary=synchronize_checkpoint_boundary,
- load_full_weights=load_full_weights,
- precision=precision,
- process_group_backend=process_group_backend,
- timeout=timeout,
- exclude_frozen_parameters=exclude_frozen_parameters,
- )
+ self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally
+ self._timeout: Optional[timedelta] = timeout
+
+ self.config = self._load_config(config)
+ if self.config is None:
+ # User has not overridden config, set defaults
+ self.config = self._create_default_config(
+ zero_optimization,
+ zero_allow_untested_optimizer,
+ logging_batch_size_per_gpu,
+ offload_optimizer=offload_optimizer,
+ offload_parameters=offload_parameters,
+ nvme_path=nvme_path,
+ offload_params_device=offload_params_device,
+ params_buffer_count=params_buffer_count,
+ params_buffer_size=params_buffer_size,
+ max_in_cpu=max_in_cpu,
+ pin_memory=pin_memory,
+ offload_optimizer_device=offload_optimizer_device,
+ optimizer_buffer_count=optimizer_buffer_count,
+ block_size=block_size,
+ queue_depth=queue_depth,
+ single_submit=single_submit,
+ overlap_events=overlap_events,
+ thread_count=thread_count,
+ partition_activations=partition_activations,
+ cpu_checkpointing=cpu_checkpointing,
+ contiguous_memory_optimization=contiguous_memory_optimization,
+ synchronize_checkpoint_boundary=synchronize_checkpoint_boundary,
+ stage=stage,
+ contiguous_gradients=contiguous_gradients,
+ overlap_comm=overlap_comm,
+ allgather_partitions=allgather_partitions,
+ reduce_scatter=reduce_scatter,
+ allgather_bucket_size=allgather_bucket_size,
+ reduce_bucket_size=reduce_bucket_size,
+ sub_group_size=sub_group_size,
+ )
+
+ import deepspeed
+
+ self._config_initialized = False
+ deepspeed.utils.logging.logger.setLevel(logging_level)
+
+ self.remote_device = remote_device
+ self.load_full_weights = load_full_weights
+ self.exclude_frozen_parameters = exclude_frozen_parameters
+
+ # default FP16 parameters.
+ self.loss_scale = loss_scale
+ self.initial_scale_power = initial_scale_power
+ self.loss_scale_window = loss_scale_window
+ self.hysteresis = hysteresis
+ self.min_loss_scale = min_loss_scale
+
+ self._deepspeed_engine: Optional[DeepSpeedEngine] = None
@property
def zero_stage_3(self) -> bool:
- return self.deepspeed_impl.zero_stage_3
+ assert isinstance(self.config, dict)
+ zero_optimization = self.config.get("zero_optimization")
+ return zero_optimization is not None and zero_optimization.get("stage") == 3
@property
@override
def distributed_sampler_kwargs(self) -> dict[str, int]:
- return self.deepspeed_impl.distributed_sampler_kwargs
+ return {"num_replicas": self.world_size, "rank": self.global_rank}
@property
def model(self) -> "DeepSpeedEngine":
- return self.deepspeed_impl._deepspeed_engine
+ return self._deepspeed_engine
@override
def setup_module_and_optimizers(
@@ -308,9 +345,14 @@ def setup_module_and_optimizers(
deepspeed optimizer, and an optional learning rate scheduler.
"""
- return self.deepspeed_impl.setup_module_and_optimizers(
- module=module, optimizers=optimizers, scheduler=scheduler
- )
+ if len(optimizers) != 1:
+ raise ValueError(
+ f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead."
+ )
+
+ self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler)
+ self._set_deepspeed_activation_checkpointing()
+ return self._deepspeed_engine, [optimizer], scheduler
@override
def setup_module(self, module: Module) -> "DeepSpeedEngine":
@@ -319,7 +361,8 @@ def setup_module(self, module: Module) -> "DeepSpeedEngine":
For training, see :meth:`setup_module_and_optimizers`.
"""
- return self.deepspeed_impl.setup_module(module=module)
+ self._deepspeed_engine, _, _ = self._initialize_engine(module)
+ return self._deepspeed_engine
@override
def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
@@ -328,15 +371,34 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer together.
"""
- return self.deepspeed_impl.setup_optimizer(optimizer=optimizer)
+ raise NotImplementedError(self._err_msg_joint_setup_required())
@override
def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
- return self.deepspeed_impl.module_init_context(empty_init=empty_init)
+ if self.zero_stage_3 and empty_init is False:
+ raise NotImplementedError(
+ f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled."
+ )
+ module_sharded_ctx = self.module_sharded_context()
+ stack = ExitStack()
+ if not self.zero_stage_3:
+ stack.enter_context(super().module_init_context(empty_init=empty_init))
+ stack.enter_context(module_sharded_ctx)
+ return stack
@override
def module_sharded_context(self) -> AbstractContextManager:
- return self.deepspeed_impl.module_sharded_context()
+ # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context
+ # manager. Later modifications through e.g. `Fabric.setup()` won't have an effect here.
+
+ import deepspeed
+
+ assert self._config_initialized
+ return deepspeed.zero.Init(
+ enabled=self.zero_stage_3,
+ remote_device=self.remote_device,
+ config_dict_or_path=self.config,
+ )
@override
def save_checkpoint(
@@ -363,8 +425,46 @@ def save_checkpoint(
:class:`deepspeed.DeepSpeedEngine` objects were found.
"""
- return self.deepspeed_impl.save_checkpoint(
- path=path, state=state, storage_options=storage_options, filter=filter
+ if storage_options is not None:
+ raise TypeError(
+ "`DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not supported because"
+ " `DeepSpeedStrategy` does not use the `CheckpointIO`."
+ )
+ if filter is not None:
+ raise TypeError(
+ "`DeepSpeedStrategy.save_checkpoint(..., filter=...)` is not supported because"
+ " `DeepSpeedStrategy` manages the state serialization internally."
+ )
+
+ engines = _get_deepspeed_engines_from_state(state)
+ if len(engines) == 0:
+ raise ValueError(
+ "Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
+ " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
+ " you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
+ )
+ if len(engines) > 1:
+ raise ValueError(
+ "Found multiple DeepSpeed engine modules in the given state. Saving checkpoints with DeepSpeed is"
+ " currently limited to a single model per checkpoint. To save multiple models, call the"
+ " save method for each model separately with a different path."
+ )
+ engine = engines[0]
+
+ # broadcast the path from rank 0 to ensure all the states are saved in a common path
+ path = self.broadcast(path)
+
+ # split the checkpoint into two parts:
+ # 1) the deepspeed engine encapsulating both the model and optionally the optimizer(s)
+ # 2) the rest of the user's state, which in deepspeed is called `client state`
+ excluded_objects = (engine, engine.optimizer) if engine.optimizer is not None else (engine,)
+ state = {k: v for k, v in state.items() if v not in excluded_objects}
+ _validate_state_keys(state)
+ # there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
+ state = self._convert_stateful_objects_in_state(state, filter={})
+ # use deepspeed's internal checkpointing function to handle partitioned weights across processes
+ engine.save_checkpoint(
+ path, client_state=state, tag="checkpoint", exclude_frozen_parameters=self.exclude_frozen_parameters
)
@override
@@ -395,7 +495,59 @@ def load_checkpoint(
not in the expected DeepSpeed format.
"""
- return self.deepspeed_impl.load_checkpoint(path=path, state=state, strict=strict)
+        if isinstance(state, (Module, Optimizer)) or (self.load_full_weights and self.zero_stage_3):
+            # This code path enables loading a checkpoint from a non-DeepSpeed checkpoint or from
+            # a consolidated checkpoint
+ path = self.broadcast(path)
+ return super().load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only)
+
+ if not state:
+ raise ValueError(
+ f"Got DeepSpeedStrategy.load_checkpoint(..., state={state!r}) but a state with at least "
+ f" a model instance to reload is required. Pass it in like so:"
+ " DeepSpeedStrategy.load_checkpoint(..., state={'model': model, ...})"
+ )
+ _validate_checkpoint_directory(path)
+
+ engines = _get_deepspeed_engines_from_state(state)
+ if len(engines) == 0:
+ raise ValueError(
+ "Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
+ " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
+ " you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
+ )
+ if len(engines) > 1:
+ raise ValueError(
+ "Found multiple DeepSpeed engine modules in the given state. Saving and loading checkpoints"
+ " with DeepSpeed is currently limited to a single model per checkpoint. To load multiple model"
+ " states, call the load method for each model checkpoint separately."
+ )
+ engine = engines[0]
+
+ from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
+
+        optimizer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())
+
+ torch.cuda.empty_cache()
+ _, client_state = engine.load_checkpoint(
+ path,
+ tag="checkpoint",
+            load_optimizer_states=optimizer_state_requested,
+ load_lr_scheduler_states=False,
+ load_module_strict=strict,
+ )
+
+ if client_state is None:
+ raise RuntimeError(
+ "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
+ " or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`."
+ )
+
+ # `Engine.load_checkpoint` adds useless keys 'optimizer' and 'lr_scheduler' to the client state; remove
+ # them to avoid name collision with user state
+        keys = (set(client_state) & set(state)) - {"optimizer", "lr_scheduler"}
+ _move_state_into(source=client_state, destination=state, keys=keys)
+ return client_state
@override
def clip_gradients_norm(
@@ -406,19 +558,19 @@ def clip_gradients_norm(
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = True,
) -> torch.Tensor:
- return self.deepspeed_impl.clip_gradients_norm(
- module=module,
- optimizer=optimizer,
- max_norm=max_norm,
- norm_type=norm_type,
- error_if_nonfinite=error_if_nonfinite,
+ raise NotImplementedError(
+ "DeepSpeed handles gradient clipping automatically within the optimizer. "
+ "Make sure to set the `gradient_clipping` value in your Config."
)
@override
def clip_gradients_value(
self, module: "DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int]
) -> None:
- return self.deepspeed_impl.clip_gradients_value(module=module, optimizer=optimizer, clip_val=clip_val)
+ raise NotImplementedError(
+ "DeepSpeed handles gradient clipping automatically within the optimizer. "
+ "Make sure to set the `gradient_clipping` value in your Config."
+ )
@classmethod
@override
@@ -461,22 +613,338 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
offload_optimizer_device="nvme",
)
+ def _initialize_engine(
+ self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional["_LRScheduler"] = None
+ ) -> tuple["DeepSpeedEngine", Optimizer, Any]:
+ """Initialize one model and one optimizer with an optional learning rate scheduler.
+
+ This calls ``deepspeed.initialize`` internally.
+
+ """
+ import deepspeed
+
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+ deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
+ args=argparse.Namespace(device_rank=self.root_device.index),
+ config=self.config,
+ model=model,
+ model_parameters=model_parameters,
+ optimizer=optimizer,
+ lr_scheduler=scheduler,
+ dist_init_required=False,
+ )
+ return deepspeed_engine, deepspeed_optimizer, deepspeed_scheduler
+
@override
def setup_environment(self) -> None:
- return self.deepspeed_impl.setup_environment()
+ if not isinstance(self.accelerator, CUDAAccelerator):
+ raise RuntimeError(
+ f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
+ " is used."
+ )
+ super().setup_environment()
@override
def _setup_distributed(self) -> None:
- return self.deepspeed_impl._setup_distributed()
+ assert self.parallel_devices is not None
+ _validate_device_index_selection(self.parallel_devices)
+ reset_seed()
+ self._set_world_ranks()
+ self._init_deepspeed_distributed()
+ if not self._config_initialized:
+ self._format_config()
+ self._config_initialized = True
+
+ def _init_deepspeed_distributed(self) -> None:
+ import deepspeed
+
+ assert self.cluster_environment is not None
+ if platform.system() != "Windows":
+ # do not set env variables on windows, allow deepspeed to control setup
+ self._set_node_environment_variables()
+ log.info(
+ "initializing deepspeed distributed: "
+ f"GLOBAL_RANK: {self.global_rank}, "
+ f"MEMBER: {self.global_rank + 1}/{self.world_size}"
+ )
+ self._process_group_backend = self._get_process_group_backend()
+ deepspeed.init_distributed(
+ self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+ )
- @property
- def config(self) -> dict[str, Any]:
- return self.deepspeed_impl.config
+ def _set_node_environment_variables(self) -> None:
+ assert self.cluster_environment is not None
+ os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
+ os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
+ os.environ["RANK"] = str(self.global_rank)
+ os.environ["WORLD_SIZE"] = str(self.world_size)
+ os.environ["LOCAL_RANK"] = str(self.local_rank)
+
+ def _set_deepspeed_activation_checkpointing(self) -> None:
+ import deepspeed
+
+ assert isinstance(self.config, dict)
+ if self.config.get("activation_checkpointing"):
+ checkpoint_config = self.config["activation_checkpointing"]
+ deepspeed.checkpointing.configure(
+ mpu_=None,
+ partition_activations=checkpoint_config.get("partition_activations"),
+ contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
+ checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
+ profile=checkpoint_config.get("profile"),
+ )
+
+ def _format_config(self) -> None:
+ if self.config is None:
+ raise ValueError(
+ "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
+ " See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed"
+ )
+
+ self.config.setdefault("train_micro_batch_size_per_gpu", 1)
+ _format_precision_config(
+ config=self.config,
+ precision=self.precision.precision,
+ loss_scale=self.loss_scale,
+ loss_scale_window=self.loss_scale_window,
+ min_loss_scale=self.min_loss_scale,
+ initial_scale_power=self.initial_scale_power,
+ hysteresis=self.hysteresis,
+ )
- @config.setter
- def config(self, config: dict[str, Any]) -> None:
- self.deepspeed_impl.config = config
+ def _create_default_config(
+ self,
+ zero_optimization: bool,
+ zero_allow_untested_optimizer: bool,
+ logging_batch_size_per_gpu: Optional[int],
+ partition_activations: bool,
+ cpu_checkpointing: bool,
+ contiguous_memory_optimization: bool,
+ synchronize_checkpoint_boundary: bool,
+ offload_optimizer: bool,
+ offload_parameters: bool,
+ nvme_path: str,
+ offload_params_device: str,
+ params_buffer_count: int,
+ params_buffer_size: int,
+ max_in_cpu: int,
+ offload_optimizer_device: str,
+ optimizer_buffer_count: int,
+ pin_memory: bool,
+ block_size: int,
+ queue_depth: int,
+ single_submit: bool,
+ overlap_events: bool,
+ thread_count: int,
+ **zero_kwargs: Any,
+ ) -> dict:
+ cfg = {
+ "activation_checkpointing": {
+ "partition_activations": partition_activations,
+ "cpu_checkpointing": cpu_checkpointing,
+ "contiguous_memory_optimization": contiguous_memory_optimization,
+ "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
+ },
+ "aio": {
+ "block_size": block_size,
+ "queue_depth": queue_depth,
+ "single_submit": single_submit,
+ "overlap_events": overlap_events,
+ "thread_count": thread_count,
+ },
+ }
+ if zero_optimization:
+ zero_config = zero_kwargs
+
+ if offload_optimizer:
+ zero_config["offload_optimizer"] = {
+ "device": offload_optimizer_device,
+ "nvme_path": nvme_path,
+ "buffer_count": optimizer_buffer_count,
+ "pin_memory": pin_memory,
+ }
+ if offload_parameters:
+ zero_config["offload_param"] = {
+ "device": offload_params_device,
+ "nvme_path": nvme_path,
+ "buffer_count": params_buffer_count,
+ "buffer_size": params_buffer_size,
+ "max_in_cpu": max_in_cpu,
+ "pin_memory": pin_memory,
+ }
+ cfg.update({
+ "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
+ "zero_optimization": zero_config,
+ })
+ if logging_batch_size_per_gpu:
+ cfg["train_micro_batch_size_per_gpu"] = logging_batch_size_per_gpu
+ return cfg
+
+ def _restore_zero_state(self, module: Module, ckpt: Mapping[str, Any]) -> None:
+ """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be sharded
+ across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced
+ across processes.
- @property
- def load_full_weights(self) -> bool:
- return self.deepspeed_impl.load_full_weights
+        Args:
+            module: The module into which the checkpoint state is loaded.
+            ckpt: The loaded checkpoint containing a ``state_dict``.
+
+ """
+ import deepspeed
+
+ def load(module: torch.nn.Module, prefix: str = "") -> None:
+ missing_keys: list[str] = []
+ unexpected_keys: list[str] = []
+ error_msgs: list[str] = []
+ state_dict = ckpt["state_dict"]
+
+ # copy state_dict so _load_from_state_dict can modify it
+ metadata = getattr(state_dict, "_metadata", None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ # because zero3 puts placeholders in model params, this context
+ # manager gathers (unpartitions) the params of the current layer, then loads from
+ # the state dict and then re-partitions them again
+ with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
+ if self.is_global_zero:
+ module._load_from_state_dict(
+ state_dict=state_dict,
+ prefix=prefix,
+ local_metadata=local_metadata,
+ strict=True,
+ missing_keys=missing_keys,
+ unexpected_keys=unexpected_keys,
+ error_msgs=error_msgs,
+ )
+
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ load(module, prefix="")
+
+ def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]:
+ if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
+ rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
+ config = os.environ[self.DEEPSPEED_ENV_VAR]
+ if isinstance(config, (str, Path)):
+ if not os.path.isfile(config):
+ raise FileNotFoundError(
+ f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
+ )
+ with open(config) as f:
+ config = json.load(f)
+ assert isinstance(config, dict) or config is None
+ return config
+
+
+def _get_deepspeed_engines_from_state(state: dict[str, Any]) -> list["DeepSpeedEngine"]:
+ from deepspeed import DeepSpeedEngine
+
+ modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module)))
+ return [engine for engine in modules if isinstance(engine, DeepSpeedEngine)]
+
+
+def _validate_state_keys(state: dict[str, Any]) -> None:
+ # DeepSpeed merges the client state into its internal engine state when saving, but it does not check for
+ # colliding keys from the user. We explicitly check it here:
+ deepspeed_internal_keys = {
+ "module",
+ "buffer_names",
+ "optimizer",
+ "param_shapes",
+ "lr_scheduler",
+ "sparse_tensor_module_names",
+ "skipped_steps",
+ "global_steps",
+ "global_samples",
+ "dp_world_size",
+ "mp_world_size",
+ "ds_config",
+ "ds_version",
+ }
+ colliding_keys = deepspeed_internal_keys.intersection(state.keys())
+ if colliding_keys:
+ rank_zero_warn(
+ "Your state has keys that collide with DeepSpeed's internal engine state. This could result in your"
+ " values being overwritten by DeepSpeed. Consider changing the name of these keys to something else: "
+ + ", ".join(colliding_keys)
+ )
+
+
+def _validate_device_index_selection(parallel_devices: list[torch.device]) -> None:
+ selected_device_indices = [device.index for device in parallel_devices]
+ expected_device_indices = list(range(len(parallel_devices)))
+ if selected_device_indices != expected_device_indices:
+ raise RuntimeError(
+ f"The selected device indices {selected_device_indices!r} don't match the local rank values of processes."
+ " If you need to select GPUs at a specific index, set the `CUDA_VISIBLE_DEVICES` environment variable"
+ f" instead. For example: `CUDA_VISIBLE_DEVICES={','.join(str(i) for i in selected_device_indices)}`."
+ )
+
+
+def _is_deepspeed_checkpoint(path: Path) -> bool:
+ """Heuristic check whether the path points to a top-level DeepSpeed checkpoint directory."""
+ return path.is_dir() and (path / "checkpoint").is_dir()
+
+
+def _validate_checkpoint_directory(path: _PATH) -> None:
+ """Validates that the path points to a DeepSpeed checkpoint directory and suggests fixes for user error."""
+ # Example DeepSpeed checkpoint directory:
+ #
+ # epoch=5-step=10999.ckpt
+ # ├── checkpoint
+ # │ ├── zero_pp_rank_0_mp_rank_00_model_states.pt
+ # │ ├── zero_pp_rank_0_mp_rank_00_optim_states.pt
+ # │ ├── zero_pp_rank_1_mp_rank_00_model_states.pt
+ # │ └── zero_pp_rank_1_mp_rank_00_optim_states.pt
+ # ├── latest
+ # └── zero_to_fp32.py
+
+ path = Path(path)
+ path_is_ds_checkpoint = _is_deepspeed_checkpoint(path)
+ default_message = f"The provided path is not a valid DeepSpeed checkpoint: {path}"
+
+ if not path_is_ds_checkpoint:
+ # Case 1: User may have accidentally passed the subfolder "checkpoint"
+ parent_is_ds_checkpoint = _is_deepspeed_checkpoint(path.parent)
+ if parent_is_ds_checkpoint:
+ raise FileNotFoundError(
+ f"{default_message}. It looks like you passed the path to a subfolder."
+ f" Try to load using this parent directory instead: {path.parent}"
+ )
+ # Case 2: User may have accidentally passed the path to a file inside the "checkpoint" subfolder
+ parent_parent_is_ds_checkpoint = path.is_file() and _is_deepspeed_checkpoint(path.parent.parent)
+ if parent_parent_is_ds_checkpoint:
+ raise FileNotFoundError(
+ f"{default_message}. It looks like you passed the path to a file inside a DeepSpeed checkpoint folder."
+ f" Try to load using this parent directory instead: {path.parent.parent}"
+ )
+ raise FileNotFoundError(default_message)
+
+
+def _format_precision_config(
+ config: dict[str, Any],
+ precision: str,
+ loss_scale: float,
+ loss_scale_window: int,
+ min_loss_scale: int,
+ initial_scale_power: int,
+ hysteresis: int,
+) -> None:
+ if "fp16" not in config and precision in ("16-mixed", "16-true"):
+ # FP16 is a DeepSpeed standalone AMP implementation
+ rank_zero_info("Enabling DeepSpeed FP16. Model parameters and inputs will be cast to `float16`.")
+ config["fp16"] = {
+ "enabled": True,
+ "loss_scale": loss_scale,
+ "initial_scale_power": initial_scale_power,
+ "loss_scale_window": loss_scale_window,
+ "hysteresis": hysteresis,
+ "min_loss_scale": min_loss_scale,
+ }
+ elif "bf16" not in config and precision in ("bf16-mixed", "bf16-true"):
+ rank_zero_info("Enabling DeepSpeed BF16. Model parameters and inputs will be cast to `bfloat16`.")
+ config["bf16"] = {"enabled": True}
diff --git a/src/lightning/fabric/strategies/launchers/xla.py b/src/lightning/fabric/strategies/launchers/xla.py
index 566312754339b..639de55805646 100644
--- a/src/lightning/fabric/strategies/launchers/xla.py
+++ b/src/lightning/fabric/strategies/launchers/xla.py
@@ -11,12 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Callable, Union
+import queue
+import time
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+import torch.multiprocessing as mp
from typing_extensions import override
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.strategies.launchers.launcher import _Launcher
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot
+from lightning.fabric.utilities.apply_func import move_data_to_device
if TYPE_CHECKING:
from lightning.fabric.strategies import XLAFSDPStrategy, XLAStrategy
@@ -40,16 +45,15 @@ class _XLALauncher(_Launcher):
"""
def __init__(self, strategy: Union["XLAStrategy", "XLAFSDPStrategy"]) -> None:
- super().__init__()
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.strategies.xla.launcher import XLALauncherFabric as EnterpriseXLALauncher
-
- self.xla_impl = EnterpriseXLALauncher(strategy=strategy)
+ if not _XLA_AVAILABLE:
+ raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+ self._strategy = strategy
+ self._start_method = "fork"
@property
@override
def is_interactive_compatible(self) -> bool:
- return self.xla_impl.is_interactive_compatible
+ return True
@override
def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
@@ -64,12 +68,61 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
**kwargs: Optional keyword arguments to be passed to the given function.
"""
- return self.xla_impl.launch(function=function, *args, **kwargs)
-
- @property
- def _start_method(self) -> str:
- return self.xla_impl._start_method
-
- @_start_method.setter
- def _start_method(self, start_method: str) -> None:
- self.xla_impl._start_method = start_method
+ return_queue: Union[queue.Queue, mp.SimpleQueue]
+ return_queue = mp.Manager().Queue()
+
+ import torch_xla.distributed.xla_multiprocessing as xmp
+
+ spawn_kwargs = {}
+ nprocs = self._strategy.num_processes
+ if nprocs == 1:
+ # avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly.
+ # otherwise it will use all devices
+ spawn_kwargs["nprocs"] = nprocs
+
+ xmp.spawn(
+ self._wrapping_function,
+ args=(function, args, kwargs, return_queue),
+ start_method=self._start_method,
+ **spawn_kwargs,
+ )
+ return return_queue.get()
+
+ def _wrapping_function(
+ self,
+        # XLA's multiprocessing returns the global index, unlike torch's multiprocessing which returns the local index
+ # https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/distributed/xla_multiprocessing.py#L321
+ process_idx: int,
+ function: Callable,
+ args: Any,
+ kwargs: Any,
+ return_queue: Union[mp.SimpleQueue, queue.Queue],
+ global_states: Optional[_GlobalStateSnapshot] = None,
+ ) -> None:
+ import torch_xla.core.xla_model as xm
+
+ if len(xm.get_xla_supported_devices()) > 1:
+ # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for v4)
+ # so when there's more than one (multithreading), objects need to be deep-copied
+ import copy
+
+ function, args, kwargs = copy.deepcopy((function, args, kwargs))
+
+ results = function(*args, **kwargs)
+
+ if self._strategy.local_rank == 0:
+ return_queue.put(move_data_to_device(results, "cpu"))
+
+ _rank_teardown(self._strategy.local_rank)
+
+
+def _rank_teardown(rank: int) -> None:
+ import torch_xla.core.xla_model as xm
+
+ # Make all processes wait for each other before joining
+ # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
+ xm.rendezvous("end-process")
+ # Ensure that the rank 0 process is the one exiting last
+ # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
+ if rank == 0:
+ time.sleep(1)
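
The launcher's result plumbing (spawn workers, have rank 0 push the return value through a managed queue) is independent of XLA and can be reproduced with plain `torch.multiprocessing`. A simplified analogue, not the launcher itself:

    import torch.multiprocessing as mp

    def _work(rank: int, return_queue) -> None:
        # only rank 0 reports back, mirroring _wrapping_function above
        if rank == 0:
            return_queue.put(rank * 2)

    if __name__ == "__main__":
        return_queue = mp.Manager().Queue()
        mp.spawn(_work, args=(return_queue,), nprocs=2)
        print(return_queue.get())  # 0
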
diff --git a/src/lightning/fabric/strategies/single_xla.py b/src/lightning/fabric/strategies/single_xla.py
index 59b38ff6157f6..ba2fce91f1146 100644
--- a/src/lightning/fabric/strategies/single_xla.py
+++ b/src/lightning/fabric/strategies/single_xla.py
@@ -13,14 +13,15 @@
# limitations under the License.
from typing import Optional
+import torch
from typing_extensions import override
from lightning.fabric.accelerators import Accelerator
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.single_device import SingleDeviceStrategy
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
from lightning.fabric.utilities.types import _DEVICE
@@ -34,16 +35,20 @@ def __init__(
checkpoint_io: Optional[XLACheckpointIO] = None,
precision: Optional[XLAPrecision] = None,
):
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.strategies.xla.single import validate_xla_strategy
+ if not _XLA_AVAILABLE:
+ raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+ if isinstance(device, torch.device):
+ # unwrap the `torch.device` in favor of `xla_device`
+ device = device.index
+
+ import torch_xla.core.xla_model as xm
super().__init__(
accelerator=accelerator,
- device=device,
+ device=xm.xla_device(device),
checkpoint_io=checkpoint_io,
precision=precision,
)
- validate_xla_strategy(strategy=self, device=device)
@property
@override
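
Both spellings of the device argument end up on the same `xm.xla_device`. A usage sketch (requires `torch_xla` at runtime; the class name follows the file above):

    import torch
    from lightning.fabric.strategies.single_xla import SingleDeviceXLAStrategy

    # the torch.device's index is unwrapped, so these are equivalent
    strategy_a = SingleDeviceXLAStrategy(device=torch.device("xla", 1))
    strategy_b = SingleDeviceXLAStrategy(device=1)
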
diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py
index 5a877a5c0dcc8..3a571fef37f00 100644
--- a/src/lightning/fabric/strategies/xla.py
+++ b/src/lightning/fabric/strategies/xla.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import io
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
@@ -21,13 +22,14 @@
from typing_extensions import override
from lightning.fabric.accelerators import Accelerator
+from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision
from lightning.fabric.plugins.environments import XLAEnvironment
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
from lightning.fabric.strategies.launchers.xla import _XLALauncher
from lightning.fabric.strategies.strategy import TBroadcast
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.rank_zero import rank_zero_only
from lightning.fabric.utilities.types import _PATH, ReduceOp
if TYPE_CHECKING:
@@ -53,19 +55,22 @@ def __init__(
checkpoint_io=checkpoint_io,
precision=precision,
)
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.strategies.xla.ddp import XLAStrategyFabric as EnterpriseXLAStrategy
-
- self.xla_strategy_impl = EnterpriseXLAStrategy(outer_object=self, sync_module_states=sync_module_states)
+ self._backward_sync_control = None # XLA synchronizes gradients in the optimizer.step() call
+ self._launched = False
+ self._sync_module_states = sync_module_states
@property
@override
def root_device(self) -> torch.device:
- return self.xla_strategy_impl.root_device
+ if not self._launched:
+ raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
+ import torch_xla.core.xla_model as xm
+
+ return xm.xla_device()
@property
def num_processes(self) -> int:
- return self.xla_strategy_impl.num_processes
+ return len(self.parallel_devices) if self.parallel_devices is not None else 0
@property
@override
@@ -102,22 +107,22 @@ def precision(self, precision: Optional[Precision]) -> None:
@property
@override
def global_rank(self) -> int:
- return self.xla_strategy_impl.global_rank
+ return super().global_rank if self._launched else 0
@property
@override
def local_rank(self) -> int:
- return self.xla_strategy_impl.local_rank
+ return super().local_rank if self._launched else 0
@property
@override
def node_rank(self) -> int:
- return self.xla_strategy_impl.node_rank
+ return super().node_rank if self._launched else 0
@property
@override
def world_size(self) -> int:
- return self.xla_strategy_impl.world_size
+ return super().world_size if self._launched else 1
@override
def _configure_launcher(self) -> None:
@@ -125,19 +130,48 @@ def _configure_launcher(self) -> None:
@override
def setup_environment(self) -> None:
- return self.xla_strategy_impl.setup_environment()
+ assert self.parallel_devices is not None
+ if len(self.parallel_devices) == 1:
+ # spawning only 1 device with PjRT is not supported:
+ # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732
+ raise NotImplementedError(
+ f"The {type(self).__name__} does not support running on a single device with the PjRT runtime."
+ " Try using all devices or the `SingleDeviceXLAStrategy` strategy"
+ )
+
+ self._launched = True
+ rank_zero_only.rank = self.global_rank
+ super().setup_environment()
@override
def setup_module(self, module: Module) -> Module:
- return self.xla_strategy_impl.setup_module(module=module)
+ if self._sync_module_states:
+ if _XLA_GREATER_EQUAL_2_1:
+ from torch_xla.core.xla_model import broadcast_master_param
+ else:
+ from torch_xla.experimental.pjrt import broadcast_master_param
+
+ broadcast_master_param(module)
+
+ return module
@override
def module_to_device(self, module: Module) -> None:
- return self.xla_strategy_impl.module_to_device(module=module)
+ module.to(self.root_device)
@override
def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
- return self.xla_strategy_impl.process_dataloader(dataloader=dataloader)
+ from torch_xla.distributed.parallel_loader import MpDeviceLoader
+
+ if isinstance(dataloader, MpDeviceLoader):
+ # dataloader is already wrapped by MpDeviceLoader
+ return dataloader
+
+ dataloader = MpDeviceLoader(dataloader, self.root_device)
+ # Mimic interface to torch.utils.data.DataLoader
+ dataloader.dataset = dataloader._loader.dataset
+ dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
+ return dataloader
@override
def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -151,21 +185,92 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
A tensor of shape (world_size, ...)
"""
- return self.xla_strategy_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads)
+ if not self._launched:
+ return tensor
+ if not isinstance(tensor, Tensor):
+ raise NotImplementedError(
+ f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
+ )
+ if tensor.dim() == 0:
+ tensor = tensor.unsqueeze(0)
+ original_device = tensor.device
+ tensor = tensor.to(self.root_device)
+
+ import torch_xla.core.functions as xf
+ import torch_xla.core.xla_model as xm
+
+ tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+ tensor = tensor.to(original_device)
+ return tensor
@override
def all_reduce(
self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> Tensor:
- return self.xla_strategy_impl.all_reduce(output=output, group=group, reduce_op=reduce_op)
+ if not isinstance(output, Tensor):
+ output = torch.tensor(output, device=self.root_device)
+
+ invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
+ invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
+ if invalid_reduce_op or invalid_reduce_op_str:
+ raise ValueError(
+ "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
+ f" {reduce_op}"
+ )
+ import torch_xla.core.xla_model as xm
+
+ output = xm.mesh_reduce("reduce", output, sum)
+
+ if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
+ output = output / self.world_size
+
+ return output
@override
def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
- return self.xla_strategy_impl.barrier(name=name, *args, **kwargs)
+ if not self._launched:
+ return
+ import torch_xla.core.xla_model as xm
+
+ if name is None:
+ # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
+ name = ""
+ xm.rendezvous(name)
@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
- return self.xla_strategy_impl.broadcast(obj=obj, src=src)
+ if not self._launched:
+ return obj
+
+ import torch_xla.core.xla_model as xm
+
+ is_tensor = isinstance(obj, Tensor)
+ if is_tensor:
+ if obj.dim() == 0:
+ obj = obj.unsqueeze(0)
+ original_device = obj.device
+ # XLA distributed requires that the data is on the XLA device
+ obj = obj.to(self.root_device)
+ else:
+ # support for arbitrary pickle-ables
+ buffer = io.BytesIO()
+ torch.save(obj, buffer)
+ obj = torch.tensor( # type: ignore[assignment]
+ bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
+ )
+
+ obj = [obj]
+ xm.collective_broadcast(obj, root_ordinal=src)
+ obj = obj[0]
+
+ if not is_tensor:
+ # this will preserve the dtype and device of any tensors
+ buffer = io.BytesIO(obj.cpu().byte().numpy())
+ obj = torch.load(buffer)
+ else:
+ obj = obj.to(original_device)
+
+ return obj
@override
def save_checkpoint(
@@ -186,9 +291,12 @@ def save_checkpoint(
boolean indicating whether the given parameter should be saved (``True``) or filtered out (``False``).
"""
- return self.xla_strategy_impl.save_checkpoint(
- path=path, state=state, storage_options=storage_options, filter=filter
- )
+ import torch_xla.core.xla_model as xm
+
+ # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
+ xm.mark_step()
+ # save on global rank zero only
+ super().save_checkpoint(path, state, storage_options=storage_options, filter=filter)
@classmethod
@override
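
The pickle-to-tensor round trip that `broadcast` uses for non-tensor objects works anywhere; only the collective in the middle needs XLA. A self-contained sketch of the same trick:

    import io
    import torch

    obj = {"epoch": 5, "best_score": 0.87}

    # serialize an arbitrary picklable into a float tensor, as broadcast() does
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    payload = torch.tensor(bytearray(buffer.getbuffer()), dtype=torch.float)

    # ... on a real run, `payload` would go through xm.collective_broadcast here ...

    # the receiving side reverses the encoding; inner tensors keep dtype and device
    restored = torch.load(io.BytesIO(payload.cpu().byte().numpy()))
    assert restored == obj
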
diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index f54267ca59a72..51b528eff26ff 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from contextlib import AbstractContextManager
+import io
+from contextlib import AbstractContextManager, ExitStack, nullcontext
+from functools import partial
+from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
import torch
@@ -22,16 +25,22 @@
from typing_extensions import override
from lightning.fabric.accelerators import Accelerator
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision
from lightning.fabric.plugins.environments import XLAEnvironment
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
+from lightning.fabric.strategies.fsdp import _apply_filter
from lightning.fabric.strategies.launchers.xla import _XLALauncher
from lightning.fabric.strategies.strategy import (
TBroadcast,
+ _BackwardSyncControl,
_Sharded,
+ _validate_keys_for_strict_loading,
)
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.cloud_io import get_filesystem
+from lightning.fabric.utilities.init import _EmptyInit
+from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.types import _PATH, Optimizable, ReduceOp
if TYPE_CHECKING:
@@ -84,6 +93,8 @@ def __init__(
sequential_save: bool = False,
**kwargs: Any,
) -> None:
+ if not _XLA_AVAILABLE:
+ raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
@@ -91,28 +102,27 @@ def __init__(
checkpoint_io=checkpoint_io,
precision=precision,
)
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.strategies.xla.fsdp import (
- XLAFSDPStrategyFabric as EnterpriseXLAFSDPStrategy,
- )
+ self._backward_sync_control = _XLAFSDPBackwardSyncControl()
- self.xla_fsdp_impl = EnterpriseXLAFSDPStrategy(
- outer_object=self,
- auto_wrap_policy=auto_wrap_policy,
- activation_checkpointing_policy=activation_checkpointing_policy,
- state_dict_type=state_dict_type,
- sequential_save=sequential_save,
- **kwargs,
- )
+ self._auto_wrap_policy = auto_wrap_policy
+ self._activation_checkpointing_policy = activation_checkpointing_policy
+ self._fsdp_kwargs = kwargs
+ self._state_dict_type = state_dict_type
+ self._sequential_save = sequential_save
+ self._launched = False
@property
@override
def root_device(self) -> torch.device:
- return self.xla_fsdp_impl.root_device
+ if not self._launched:
+ raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
+ import torch_xla.core.xla_model as xm
+
+ return xm.xla_device()
@property
def num_processes(self) -> int:
- return self.xla_fsdp_impl.num_processes
+ return len(self.parallel_devices) if self.parallel_devices is not None else 0
@property
@override
@@ -149,22 +159,22 @@ def precision(self, precision: Optional[Precision]) -> None:
@property
@override
def global_rank(self) -> int:
- return self.xla_fsdp_impl.global_rank
+ return super().global_rank if self._launched else 0
@property
@override
def local_rank(self) -> int:
- return self.xla_fsdp_impl.local_rank
+ return super().local_rank if self._launched else 0
@property
@override
def node_rank(self) -> int:
- return self.xla_fsdp_impl.node_rank
+ return super().node_rank if self._launched else 0
@property
@override
def world_size(self) -> int:
- return self.xla_fsdp_impl.world_size
+ return super().world_size if self._launched else 1
@override
def _configure_launcher(self) -> None:
@@ -172,40 +182,108 @@ def _configure_launcher(self) -> None:
@override
def setup_environment(self) -> None:
- return self.xla_fsdp_impl.setup_environment()
+ assert self.parallel_devices is not None
+ if len(self.parallel_devices) == 1:
+ # spawning only 1 device with PjRT is not supported:
+ # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732
+ raise NotImplementedError(
+ f"The {type(self).__name__} does not support running on a single device with the PjRT runtime."
+ " Try using all devices or the `SingleDeviceXLAStrategy` strategy"
+ )
+
+ self._launched = True
+ rank_zero_only.rank = self.global_rank
+ super().setup_environment()
@override
def setup_module_and_optimizers(
self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
) -> tuple[Module, list[Optimizer], Optional["_LRScheduler"]]:
- return self.xla_fsdp_impl.setup_module_and_optimizers(module=module, optimizers=optimizers, scheduler=scheduler)
+ """Returns NotImplementedError since for XLAFSDP optimizer setup must happen after module setup."""
+ raise NotImplementedError(
+ f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."
+ " Please do it in this order: Create the model, call `setup_module`, create the optimizer,"
+ " call `setup_optimizer`."
+ )
@override
def setup_module(self, module: Module) -> Module:
- return self.xla_fsdp_impl.setup_module(module=module)
+ from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
+
+ kwargs = self._parse_fsdp_kwargs()
+ if any(isinstance(mod, XLAFSDP) for mod in module.modules()) and "auto_wrap_policy" in kwargs:
+ rank_zero_warn(
+ "A XLAFSDP `auto_wrap_policy` is set, but at least one submodule is already wrapped."
+ " The policy will be ignored."
+ )
+ del kwargs["auto_wrap_policy"]
+ # XLA FSDP requires that the root is wrapped, even if submodules are already wrapped
+ if not isinstance(module, XLAFSDP):
+ module = XLAFSDP(module=module, **kwargs)
+ return module
@override
def module_to_device(self, module: Module) -> None:
- return self.xla_fsdp_impl.module_to_device(module=module)
+ pass
def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
- return self.xla_fsdp_impl.module_init_context(empty_init=empty_init)
+ precision_init_ctx = self.precision.module_init_context()
+ module_sharded_ctx = self.module_sharded_context()
+ stack = ExitStack()
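+        # combine empty-weight materialization, precision, and sharding contexts into one context manager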
+ stack.enter_context(_EmptyInit(enabled=bool(empty_init)))
+ stack.enter_context(precision_init_ctx)
+ stack.enter_context(module_sharded_ctx)
+ return stack
@override
def module_sharded_context(self) -> AbstractContextManager:
- return self.xla_fsdp_impl.module_sharded_context()
+ return nullcontext()
@override
def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
- return self.xla_fsdp_impl.process_dataloader(dataloader=dataloader)
+ from torch_xla.distributed.parallel_loader import MpDeviceLoader
+
+ if isinstance(dataloader, MpDeviceLoader):
+ # dataloader is already wrapped by MpDeviceLoader
+ return dataloader
+
+ dataloader = MpDeviceLoader(dataloader, self.root_device)
+ # Mimic interface to torch.utils.data.DataLoader
+ dataloader.dataset = dataloader._loader.dataset
+ dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
+ return dataloader
@override
def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
- return self.xla_fsdp_impl.setup_optimizer(optimizer=optimizer)
+ """Set up an optimizer for a model wrapped with XLAFSDP.
+
+        This setup method doesn't modify or wrap the optimizer. It currently only verifies that the optimizer was
+        created after the model was wrapped with :meth:`setup_module` and that it references the flattened
+        parameters.
+
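+        Example (illustrative sketch; ``fabric`` is a ``Fabric`` instance configured with this strategy and
+        ``MyModel`` is a placeholder for your own module)::
+
+            model = fabric.setup_module(MyModel())            # wraps the model with XLAFSDP
+            optimizer = torch.optim.Adam(model.parameters())  # now references the flattened parameters
+            optimizer = fabric.setup_optimizers(optimizer)    # passes the verification in this method
+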
+ """
+ if any(getattr(p, "_is_sharded", False) for group in optimizer.param_groups for p in group["params"]):
+ return optimizer
+ raise ValueError(
+ "The optimizer does not seem to reference any XLAFSDP parameters. HINT: Make sure to create the optimizer"
+ " after setting up the model."
+ )
@override
def optimizer_step(self, optimizer: Optimizable, **kwargs: Any) -> Any:
- return self.xla_fsdp_impl.optimizer_step(optimizer=optimizer, **kwargs)
+ """Overrides default tpu optimizer_step since FSDP should not call `torch_xla.core.xla_model.optimizer_step`.
+ Performs the actual optimizer step.
+
+ Args:
+ optimizer: the optimizer performing the step
+ **kwargs: Any extra arguments to ``optimizer.step``
+
+ """
+ loss = optimizer.step(**kwargs)
+ import torch_xla.core.xla_model as xm
+
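+        # `mark_step` flushes XLA's lazy execution graph so the parameter update actually runs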
+ xm.mark_step()
+ return loss
@override
def clip_gradients_norm(
@@ -217,18 +295,17 @@ def clip_gradients_norm(
error_if_nonfinite: bool = True,
) -> Tensor:
"""Clip gradients by norm."""
- return self.xla_fsdp_impl.clip_gradients_norm(
- module=module,
- optimizer=optimizer,
- max_norm=max_norm,
- norm_type=norm_type,
- error_if_nonfinite=error_if_nonfinite,
- )
+ self.precision.unscale_gradients(optimizer)
+ assert callable(module.clip_grad_norm_)
+ return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type)
@override
def clip_gradients_value(self, module: Module, optimizer: Optimizer, clip_val: Union[float, int]) -> None:
"""Clip gradients by value."""
- return self.xla_fsdp_impl.clip_gradients_value(module=module, optimizer=optimizer, clip_val=clip_val)
+ raise NotImplementedError(
+ "XLA's FSDP strategy does not support to clip gradients by value."
+ " Consider clipping by norm instead or choose another strategy!"
+ )
@override
def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -242,21 +319,92 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
A tensor of shape (world_size, ...)
"""
- return self.xla_fsdp_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads)
+ if not self._launched:
+ return tensor
+ if not isinstance(tensor, Tensor):
+ raise NotImplementedError(
+ f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
+ )
+ if tensor.dim() == 0:
+ tensor = tensor.unsqueeze(0)
+ original_device = tensor.device
+ tensor = tensor.to(self.root_device)
+
+ import torch_xla.core.functions as xf
+ import torch_xla.core.xla_model as xm
+
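+        # `xf.all_gather` is differentiable (gradients can flow back through it); `xm.all_gather` is not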
+ tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+ tensor = tensor.to(original_device)
+ return tensor
@override
def all_reduce(
self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> Tensor:
- return self.xla_fsdp_impl.all_reduce(output=output, group=group, reduce_op=reduce_op)
+ if not isinstance(output, Tensor):
+ output = torch.tensor(output, device=self.root_device)
+
+ invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
+ invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
+ if invalid_reduce_op or invalid_reduce_op_str:
+ raise ValueError(
+ "Currently, the XLAFSDPStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
+ f" {reduce_op}"
+ )
+ import torch_xla.core.xla_model as xm
+
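+        # `mesh_reduce` gathers the value from all replicas on the host and applies the Python reduction (`sum`)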
+ output = xm.mesh_reduce("reduce", output, sum)
+
+ if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
+ output = output / self.world_size
+
+ return output
@override
def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
- return self.xla_fsdp_impl.barrier(name=name, *args, **kwargs)
+ if not self._launched:
+ return
+ import torch_xla.core.xla_model as xm
+
+ if name is None:
+ # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
+ name = ""
+ xm.rendezvous(name)
@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
- return self.xla_fsdp_impl.broadcast(obj=obj, src=src)
+ if not self._launched:
+ return obj
+
+ import torch_xla.core.xla_model as xm
+
+ is_tensor = isinstance(obj, Tensor)
+ if is_tensor:
+ if obj.dim() == 0:
+ obj = obj.unsqueeze(0)
+ original_device = obj.device
+ # XLA distributed requires that the data is on the XLA device
+ obj = obj.to(self.root_device)
+ else:
+ # support for arbitrary pickle-ables
+ buffer = io.BytesIO()
+ torch.save(obj, buffer)
+ obj = torch.tensor( # type: ignore[assignment]
+ bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
+ )
+
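+        # `collective_broadcast` operates in place on a list of tensors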
+ obj = [obj]
+ xm.collective_broadcast(obj, root_ordinal=src)
+ obj = obj[0]
+
+ if not is_tensor:
+ # this will preserve the dtype and device of any tensors
+ buffer = io.BytesIO(obj.cpu().byte().numpy())
+ obj = torch.load(buffer)
+ else:
+ obj = obj.to(original_device)
+
+ return obj
@override
def save_checkpoint(
@@ -273,8 +421,93 @@ def save_checkpoint(
consolidated checkpoint combining all of the sharded checkpoints.
"""
- return self.xla_fsdp_impl.save_checkpoint(
- path=path, state=state, storage_options=storage_options, filter=filter
+ # broadcast the path from rank 0 to ensure all the states are saved in a common path
+ path = Path(self.broadcast(path))
+ if path.is_dir() and any(path.iterdir()):
+ raise FileExistsError(f"The checkpoint directory already exists and is not empty: {path}")
+ from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
+
+ modules = [module for module in state.values() if isinstance(module, XLAFSDP)]
+ if len(modules) == 0:
+ raise ValueError(
+ "Could not find a XLAFSDP model in the provided checkpoint state. Please provide the model as"
+ " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
+ " you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
+ )
+ if len(modules) > 1:
+ raise ValueError(
+ "Found multiple XLAFSDP modules in the given state. Saving checkpoints with FSDP is"
+ " currently limited to a single model per checkpoint. To save multiple models, call the"
+ " save method for each model separately with a different path."
+ )
+ import torch_xla.core.xla_model as xm
+
+ # ensure model parameters are updated
+ xm.mark_step()
+
+ parallel_devices = self.parallel_devices
+ assert parallel_devices is not None
+ if self._sequential_save:
+ # each host runs this in parallel, but the ranks in the host run it sequentially
+ for rank in range(len(parallel_devices)):
+ if rank == self.local_rank:
+ self._save_checkpoint_shard(path, state, storage_options, filter)
+ self.barrier(f"wait-for-{rank}-save")
+ else:
+ self._save_checkpoint_shard(path, state, storage_options, filter)
+
+ if self._state_dict_type == "full":
+ ckpt_prefix = str(path / "checkpoint")
+ ckpt_suffix = "_rank-*-of-*.pth"
+ if len(parallel_devices) != self.world_size: # multihost
+ raise OSError(
+ "Multihost setups do not have a shared filesystem, so the checkpoint shards cannot be consolidated"
+ " into a single checkpoint after saving them. Please switch to"
+ " `XLAFSDPStrategy(state_dict_type='sharded')`. TIP: You can consolidate them manually by getting"
+ " them together into a single directory and running `python -m"
+ f" torch_xla.distributed.fsdp.consolidate_sharded_ckpts --ckpt_prefix {ckpt_prefix!r} --ckpt_suffix"
+ f" {ckpt_suffix!r} --save_path 'path/to/consolidated.ckpt'`."
+ )
+
+ from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+
+ self.barrier("before_ckpt_consolidation")
+ if self.is_global_zero:
+ save_path = path.parent / "consolidated.ckpt"
+ # save consolidated checkpoint separate to the shards
+ consolidate_sharded_model_checkpoints(ckpt_prefix, ckpt_suffix, str(save_path))
+ # remove the shards directory
+ self.checkpoint_io.remove_checkpoint(path)
+ # mv the consolidated checkpoint where the user would expect it
+ get_filesystem(save_path).mv(str(save_path), str(path))
+ self.barrier("after_ckpt_consolidation")
+
+ def _save_checkpoint_shard(
+ self,
+ path: Path,
+ state: dict[str, Union[Module, Optimizer, Any]],
+ storage_options: Optional[Any],
+ filter: Optional[dict[str, Callable[[str, Any], bool]]],
+ ) -> None:
+ from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
+
+ converted_state: dict[str, Any] = {}
+ for key, obj in state.items():
+ # convert the state
+ if isinstance(obj, Module) and isinstance(obj, XLAFSDP):
+ converted = obj.state_dict()
+ # add shard_metadata to state
+ converted_state["shard_metadata"] = obj.get_shard_metadata()
+ elif isinstance(obj, Optimizer):
+ converted = obj.state_dict()
+ else:
+ converted = obj
+ _apply_filter(key, filter or {}, converted, converted_state)
+
+ self.checkpoint_io.save_checkpoint(
+ converted_state,
+ path / f"checkpoint_rank-{self.global_rank:08d}-of-{self.world_size:08d}.pth",
+ storage_options=storage_options,
)
@override
@@ -291,9 +524,164 @@ def load_checkpoint(
directory of multiple files rather than a single file.
"""
- return self.xla_fsdp_impl.load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only)
+ if not state:
+ raise ValueError(
+ f"Got `XLAFSDPStrategy.load_checkpoint(..., state={state!r})` but a state with at least "
+ " a model instance to reload is required. Pass it in like so:"
+ " `FSDPStrategy.load_checkpoint(..., state={'model': model, ...})`"
+ )
+
+ # broadcast the path from rank 0 to ensure all the states are loaded from a common path
+ path = Path(self.broadcast(path))
+
+ if isinstance(state, (Module, Optimizer)):
+ raise NotImplementedError(
+ "Loading a single module or optimizer object from a checkpoint"
+ " is not supported yet with the XLAFSDP strategy."
+ )
+
+ from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
+
+ modules = {key: module for key, module in state.items() if isinstance(module, XLAFSDP)}
+ optimizers = {key: optim for key, optim in state.items() if isinstance(optim, Optimizer)}
+ if self._state_dict_type == "sharded":
+ file = path / f"checkpoint_rank-{self.global_rank:08d}-of-{self.world_size:08d}.pth"
+ if not file.is_file():
+ raise ValueError(
+ f"The path {str(file)!r} does not point to valid sharded checkpoints. Make sure the path points to"
+ " a directory with XLAFSDP checkpoint shards."
+ )
+ if len(modules) == 0:
+ raise ValueError(
+ "Could not find a XLAFSDP model in the provided checkpoint state. Please provide the model as"
+ " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
+ " you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
+ )
+ if len(modules) > 1:
+ raise ValueError(
+ "Found multiple XLAFSDP modules in the given state. Loading checkpoints with FSDP is"
+ " currently limited to a single model per checkpoint. To load multiple models, call the"
+ " load method for each model separately with a different path."
+ )
+
+ _, module = list(modules.items())[0]
+ sharded_ckpt = torch.load(file)
+
+ module.load_state_dict(sharded_ckpt["model"], strict=strict)
+ for opt_key, opt in optimizers.items():
+ opt.load_state_dict(sharded_ckpt[opt_key])
+
+ # Load anything leftover from sharded_ckpt
+ loaded_metadata_keys = sharded_ckpt.keys() - modules.keys() - optimizers.keys()
+ requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
+ _validate_keys_for_strict_loading(requested_metadata_keys, loaded_metadata_keys, strict=strict)
+ for key in requested_metadata_keys:
+ if key in loaded_metadata_keys:
+ state[key] = sharded_ckpt[key]
+ loaded_metadata_keys.remove(key)
+
+            metadata = {key: sharded_ckpt[key] for key in loaded_metadata_keys}
+
+            # drop the "shard_metadata" entry that is stored alongside the state in every shard
+            metadata.pop("shard_metadata", None)
+
+ return metadata
+
+ if self._state_dict_type == "full":
+ if not path.is_file():
+ raise ValueError(
+ f"The path {str(path)!r} does not point to a valid full checkpoint. Make sure the path points to a"
+ " directory with a full XLAFSDP checkpoint."
+ )
+ if len(optimizers) > 0 or len(state.keys() - modules.keys() - optimizers.keys()) > 0:
+ rank_zero_warn(
+ "Loading a full checkpoint will only load the full model."
+ " The optimizer and any additional metadata are not included."
+ )
+ if len(modules) > 0:
+ raise ValueError(
+ "Found a XLAFSDP model in the provided checkpoint state."
+ " Please provide the model without any XLAFSDP wrapper."
+ )
+ if "model" not in state or not isinstance(model := state["model"], torch.nn.Module):
+ raise NotImplementedError("XLAFSDP only supports a single model instance with 'model' as the key.")
+ full_ckpt = torch.load(path, weights_only=weights_only)
+ model.load_state_dict(full_ckpt.pop("model"), strict=strict)
+ return full_ckpt
+
+ raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")
@classmethod
@override
def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
strategy_registry.register("xla_fsdp", cls, description=cls.__name__)
+
+ def _parse_fsdp_kwargs(self) -> dict:
+ # this needs to be delayed because `self.precision` isn't available at init
+ kwargs = self._fsdp_kwargs.copy()
+ precision = self.precision
+ if isinstance(precision, XLAPrecision):
+            # the `compute_dtype` is passed to the `auto_wrapper_callable` automatically, so we don't need to
+            # pass it ourselves when creating the wrapper
+ kwargs.setdefault("compute_dtype", precision._desired_dtype)
+ kwargs = _auto_wrap_policy_kwargs(self._auto_wrap_policy, kwargs)
+ return _activation_checkpointing_kwargs(self._activation_checkpointing_policy, kwargs)
+
+
+def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: dict) -> dict:
+ if policy is None:
+ return kwargs
+ if isinstance(policy, set):
+ from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
+ # this is not transformer specific despite the name
+ policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=policy)
+ kwargs["auto_wrap_policy"] = policy
+ return kwargs
+
+
+def _activation_checkpointing_auto_wrapper(policy: _POLICY_SET, module: Module, *args: Any, **kwargs: Any) -> Module:
+ from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
+ from torch_xla.distributed.fsdp import checkpoint_module
+
+ module = checkpoint_module(module) if isinstance(module, tuple(policy)) else module
+ return XLAFSDP(module, *args, **kwargs)
+
+
+def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: dict) -> dict:
+ if not policy:
+ return kwargs
+ if "auto_wrapper_callable" in kwargs:
+ raise ValueError(
+ "You cannot set both `auto_wrapper_callable` and `activation_checkpointing_policy`. Choose one"
+ )
+ if not isinstance(policy, set):
+ raise TypeError(
+ f"`activation_checkpointing_policy` must be a set, found {policy}. You can try defining and"
+ " passing `auto_wrapper_callable` instead."
+ )
+ auto_wrapper_callable = partial(_activation_checkpointing_auto_wrapper, policy)
+ kwargs["auto_wrapper_callable"] = auto_wrapper_callable
+ return kwargs
+
+
+class _XLAFSDPBackwardSyncControl(_BackwardSyncControl):
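+    """Temporarily disables gradient synchronization (``no_sync``) on an XLAFSDP-wrapped module, e.g. while
+    accumulating gradients."""
+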
+ @override
+ def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager:
+ """Blocks gradient synchronization inside the :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel`
+ wrapper."""
+ if not enabled:
+ return nullcontext()
+ from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
+
+ if not isinstance(module, XLAFSDP):
+ raise TypeError(
+ "Blocking backward sync is only possible if the module passed to"
+ f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `XlaFullyShardedDataParallel`."
+ f" Got: {module.__class__.__name__}."
+ )
+ return module.no_sync()
diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py
index bac63dd9419c5..ac706cc403a5d 100644
--- a/src/lightning/fabric/utilities/imports.py
+++ b/src/lightning/fabric/utilities/imports.py
@@ -40,22 +40,3 @@
_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
-
-_WANDB_AVAILABLE = RequirementCache("wandb>=0.12.10")
-_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4")
-_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0")
-_MLFLOW_SYNCHRONOUS_AVAILABLE = RequirementCache("mlflow>=2.8.0")
-_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0")
-
-_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
-_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")
-_ENTERPRISE_AVAILABLE = RequirementCache("pytorch_lightning_enterprise")
-_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0")
-
-
-def _raise_enterprise_not_available() -> None:
- if not _ENTERPRISE_AVAILABLE:
- raise ModuleNotFoundError(
- "pytorch_lightning_enterprise is required to use this feature. "
- "Install it with `pip install pytorch-lightning-enterprise`"
- )
diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py
index 92521362a2885..d085e4138d742 100644
--- a/src/lightning/fabric/utilities/testing/_runif.py
+++ b/src/lightning/fabric/utilities/testing/_runif.py
@@ -23,7 +23,8 @@
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.accelerators.cuda import num_cuda_devices
from lightning.fabric.accelerators.mps import MPSAccelerator
-from lightning.fabric.utilities.imports import _DEEPSPEED_AVAILABLE, _TORCH_GREATER_EQUAL_2_4
+from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
def _runif_reasons(
diff --git a/src/lightning/pytorch/accelerators/xla.py b/src/lightning/pytorch/accelerators/xla.py
index c8b322f477d4d..700a29c2e3e7b 100644
--- a/src/lightning/pytorch/accelerators/xla.py
+++ b/src/lightning/pytorch/accelerators/xla.py
@@ -39,7 +39,15 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
A dictionary mapping the metrics (free memory and peak memory) to their values.
"""
- return self.accelerator_impl.get_device_stats(device)
+ import torch_xla.core.xla_model as xm
+
+ memory_info = xm.get_memory_info(device)
+        free_memory = memory_info["kb_free"] // 1024  # `get_memory_info` reports kB; convert to MB
+        peak_memory = memory_info["kb_total"] // 1024 - free_memory
+ return {
+ "avg. free memory (MB)": free_memory,
+ "avg. peak memory (MB)": peak_memory,
+ }
@staticmethod
@override
diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py
index 5e57e38dc4072..b544212e755e2 100644
--- a/src/lightning/pytorch/loggers/comet.py
+++ b/src/lightning/pytorch/loggers/comet.py
@@ -17,15 +17,18 @@
"""
import logging
+import os
from argparse import Namespace
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from torch.nn import Module
from typing_extensions import override
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.logger import _convert_params
+from lightning.fabric.utilities.rank_zero import _get_rank
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities.rank_zero import rank_zero_only
@@ -33,6 +36,7 @@
from comet_ml import ExistingExperiment, Experiment, OfflineExperiment
log = logging.getLogger(__name__)
+_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4", module="comet_ml")
FRAMEWORK_NAME = "pytorch-lightning"
comet_experiment = Union["Experiment", "ExistingExperiment", "OfflineExperiment"]
@@ -204,23 +208,99 @@ def __init__(
prefix: Optional[str] = None,
**kwargs: Any,
):
- _raise_enterprise_not_available()
+ if not _COMET_AVAILABLE:
+ raise ModuleNotFoundError(str(_COMET_AVAILABLE))
super().__init__()
- from pytorch_lightning_enterprise.loggers.comet import CometLogger as EnterpriseCometLogger
-
- self.logger_impl = EnterpriseCometLogger(
- api_key=api_key,
- workspace=workspace,
- project=project,
- experiment_key=experiment_key,
- mode=mode,
- online=online,
- prefix=prefix,
- **kwargs,
+ ##################################################
+ # HANDLE PASSED OLD TYPE PARAMS
+
+ # handle old "experiment_name" param
+ if "experiment_name" in kwargs:
+ log.warning("The parameter `experiment_name` is deprecated, please use `name` instead.")
+ experiment_name = kwargs.pop("experiment_name")
+
+ if "name" not in kwargs:
+ kwargs["name"] = experiment_name
+ else:
+ log.warning("You specified both `experiment_name` and `name` parameters, please use `name` only")
+
+ # handle old "project_name" param
+ if "project_name" in kwargs:
+ log.warning("The parameter `project_name` is deprecated, please use `project` instead.")
+ if project is None:
+ project = kwargs.pop("project_name")
+ else:
+ log.warning("You specified both `project_name` and `project` parameters, please use `project` only")
+
+ # handle old "offline" experiment flag
+ if "offline" in kwargs:
+ log.warning("The parameter `offline is deprecated, please use `online` instead.")
+ if online is None:
+ online = kwargs.pop("offline")
+ else:
+ log.warning("You specified both `offline` and `online` parameters, please use `online` only")
+
+ # handle old "save_dir" param
+ if "save_dir" in kwargs:
+ log.warning("The parameter `save_dir` is deprecated, please use `offline_directory` instead.")
+ if "offline_directory" not in kwargs:
+ kwargs["offline_directory"] = kwargs.pop("save_dir")
+ else:
+ log.warning(
+ "You specified both `save_dir` and `offline_directory` parameters, "
+ "please use `offline_directory` only"
+ )
+ ##################################################
+
+ self._api_key: Optional[str] = api_key
+ self._experiment: Optional[comet_experiment] = None
+ self._workspace: Optional[str] = workspace
+ self._mode: Optional[Literal["get_or_create", "get", "create"]] = mode
+ self._online: Optional[bool] = online
+ self._project_name: Optional[str] = project
+ self._experiment_key: Optional[str] = experiment_key
+ self._prefix: Optional[str] = prefix
+ self._kwargs: dict[str, Any] = kwargs
+
+        # needs to be set before the first `comet_ml` import:
+        # `comet_ml` is typically imported after other machine learning libraries (e.g. Torch),
+        # so its auto-logging must be disabled up front
+ os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
+
+ import comet_ml
+
+ config_kwargs = self._kwargs.copy()
+ if online is False:
+ config_kwargs["disabled"] = True
+ self._comet_config = comet_ml.ExperimentConfig(**config_kwargs)
+
+ # create real experiment only on main node/process (when strategy=auto/ddp)
+ if _get_rank() is not None and _get_rank() != 0:
+ return
+
+ self._create_experiment()
+
+ def _create_experiment(self) -> None:
+ import comet_ml
+
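+        # depending on `mode` and `experiment_key`, `comet_ml.start` creates a new experiment or resumes an existing one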
+ self._experiment = comet_ml.start(
+ api_key=self._api_key,
+ workspace=self._workspace,
+ project=self._project_name,
+ experiment_key=self._experiment_key,
+ mode=self._mode,
+ online=self._online,
+ experiment_config=self._comet_config,
)
+ if self._experiment is None:
+ raise comet_ml.exceptions.ExperimentNotFound("Failed to create Comet experiment.")
+
+ self._experiment_key = self._experiment.get_key()
+ self._project_name = self._experiment.project_name
+ self._experiment.log_other("Created from", FRAMEWORK_NAME)
+
@property
@rank_zero_experiment
def experiment(self) -> comet_experiment:
@@ -233,24 +313,55 @@ def experiment(self) -> comet_experiment:
"""
- return self.logger_impl.experiment
+ # if by some chance there is no experiment created yet (for example, when strategy=ddp_spawn)
+ # then we will create a new one
+ if not self._experiment:
+ self._create_experiment()
+
+ return self._experiment
@override
@rank_zero_only
def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
- return self.logger_impl.log_hyperparams(params)
+ params = _convert_params(params)
+ self.experiment.__internal_api__log_parameters__(
+ parameters=params,
+ framework=FRAMEWORK_NAME,
+ flatten_nested=True,
+ source="manual",
+ )
@override
@rank_zero_only
def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
- return self.logger_impl.log_metrics(metrics, step)
+ assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
+ # Comet.com expects metrics to be a dictionary of detached tensors on CPU
+ metrics_without_epoch = metrics.copy()
+ for key, val in metrics_without_epoch.items():
+ if isinstance(val, Tensor):
+ metrics_without_epoch[key] = val.cpu().detach()
+
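+        # "epoch" is passed to Comet as the dedicated epoch field rather than as a regular metric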
+ epoch = metrics_without_epoch.pop("epoch", None)
+ self.experiment.__internal_api__log_metrics__(
+ metrics_without_epoch,
+ step=step,
+ epoch=epoch,
+ prefix=self._prefix,
+ framework=FRAMEWORK_NAME,
+ )
@override
@rank_zero_only
def finalize(self, status: str) -> None:
"""We will not end experiment (will not call self._experiment.end()) here to have an ability to continue using
it after training is complete but instead of ending we will upload/save all the data."""
- return self.logger_impl.finalize(status)
+ if self._experiment is None:
+ # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been
+ # initialized there
+ return
+
+ # just save the data
+ self.experiment.flush()
@property
@override
@@ -261,7 +372,7 @@ def save_dir(self) -> Optional[str]:
The path to the save directory.
"""
- return self.logger_impl.save_dir
+ return self._comet_config.offline_directory
@property
@override
@@ -272,7 +383,7 @@ def name(self) -> Optional[str]:
The project name if it is specified.
"""
- return self.logger_impl.name
+ return self._project_name
@property
@override
@@ -284,8 +395,27 @@ def version(self) -> Optional[str]:
"""
# Don't create an experiment if we don't have one
- return self.logger_impl.version
+ if self._experiment is not None:
+ return self._experiment.get_key()
+
+ def __getstate__(self) -> dict[str, Any]:
+ state = self.__dict__.copy()
+
+ # Save the experiment id in case an experiment object already exists,
+ # this way we could create an ExistingExperiment pointing to the same
+ # experiment
+ state["_experiment_key"] = self._experiment.get_key() if self._experiment is not None else None
+
+ # Remove the experiment object as it contains hard to pickle objects
+ # (like network connections), the experiment object will be recreated if
+ # needed later
+ state["_experiment"] = None
+ return state
@override
def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
- return self.logger_impl.log_graph(model, input_array)
+ if self._experiment is not None:
+ self._experiment.__internal_api__set_model_graph__(
+ graph=model,
+ framework=FRAMEWORK_NAME,
+ )
diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index 725fd2cb8bfbb..ff9b2b0d7e542 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -16,21 +16,35 @@
-------------
"""
+import logging
import os
+import re
+import tempfile
from argparse import Namespace
from collections.abc import Mapping
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from pathlib import Path
+from time import time
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
+import yaml
+from lightning_utilities.core.imports import RequirementCache
+from torch import Tensor
from typing_extensions import override
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
-from lightning.pytorch.utilities.rank_zero import rank_zero_only
+from lightning.pytorch.loggers.utilities import _scan_checkpoints
+from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
if TYPE_CHECKING:
from mlflow.tracking import MlflowClient
+log = logging.getLogger(__name__)
+LOCAL_FILE_URI_PREFIX = "file:"
+_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow")
+_MLFLOW_SYNCHRONOUS_AVAILABLE = RequirementCache("mlflow>=2.8.0", "mlflow")
+
class MLFlowLogger(Logger):
"""Log using `MLflow `_.
@@ -112,22 +126,31 @@ def __init__(
run_id: Optional[str] = None,
synchronous: Optional[bool] = None,
):
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.loggers.mlflow import MLFlowLogger as EnterpriseMLFlowLogger
-
+ if not _MLFLOW_AVAILABLE:
+ raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE))
+ if synchronous is not None and not _MLFLOW_SYNCHRONOUS_AVAILABLE:
+ raise ModuleNotFoundError("`synchronous` requires mlflow>=2.8.0")
super().__init__()
- self.logger_impl = EnterpriseMLFlowLogger(
- experiment_name=experiment_name,
- run_name=run_name,
- tracking_uri=tracking_uri,
- tags=tags,
- save_dir=save_dir,
- log_model=log_model,
- prefix=prefix,
- artifact_location=artifact_location,
- run_id=run_id,
- synchronous=synchronous,
- )
+ if not tracking_uri:
+ tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"
+
+ self._experiment_name = experiment_name
+ self._experiment_id: Optional[str] = None
+ self._tracking_uri = tracking_uri
+ self._run_name = run_name
+ self._run_id = run_id
+ self.tags = tags
+ self._log_model = log_model
+ self._logged_model_time: dict[str, float] = {}
+ self._checkpoint_callback: Optional[ModelCheckpoint] = None
+ self._prefix = prefix
+ self._artifact_location = artifact_location
+ self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
+ self._initialized = False
+
+ from mlflow.tracking import MlflowClient
+
+ self._mlflow_client = MlflowClient(tracking_uri)
@property
@rank_zero_experiment
@@ -140,7 +163,46 @@ def experiment(self) -> "MlflowClient":
self.logger.experiment.some_mlflow_function()
"""
- return self.logger_impl.experiment
+ import mlflow
+
+ if self._initialized:
+ return self._mlflow_client
+
+ mlflow.set_tracking_uri(self._tracking_uri)
+
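+        # when resuming a given run id, reuse its experiment instead of creating a new one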
+ if self._run_id is not None:
+ run = self._mlflow_client.get_run(self._run_id)
+ self._experiment_id = run.info.experiment_id
+ self._initialized = True
+ return self._mlflow_client
+
+ if self._experiment_id is None:
+ expt = self._mlflow_client.get_experiment_by_name(self._experiment_name)
+ if expt is not None and expt.lifecycle_stage != "deleted":
+ self._experiment_id = expt.experiment_id
+ else:
+ log.warning(f"Experiment with name {self._experiment_name} not found. Creating it.")
+ self._experiment_id = self._mlflow_client.create_experiment(
+ name=self._experiment_name, artifact_location=self._artifact_location
+ )
+
+ if self._run_id is None:
+ if self._run_name is not None:
+ self.tags = self.tags or {}
+
+ from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
+
+ if MLFLOW_RUN_NAME in self.tags:
+ log.warning(
+ f"The tag {MLFLOW_RUN_NAME} is found in tags. The value will be overridden by {self._run_name}."
+ )
+ self.tags[MLFLOW_RUN_NAME] = self._run_name
+
+ resolve_tags = _get_resolve_tags()
+ run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
+ self._run_id = run.info.run_id
+ self._initialized = True
+ return self._mlflow_client
@property
def run_id(self) -> Optional[str]:
@@ -150,7 +212,8 @@ def run_id(self) -> Optional[str]:
The run id.
"""
- return self.logger_impl.run_id
+ _ = self.experiment
+ return self._run_id
@property
def experiment_id(self) -> Optional[str]:
@@ -160,22 +223,71 @@ def experiment_id(self) -> Optional[str]:
The experiment id.
"""
- return self.logger_impl.experiment_id
+ _ = self.experiment
+ return self._experiment_id
@override
@rank_zero_only
def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
- return self.logger_impl.log_hyperparams(params)
+ params = _convert_params(params)
+ params = _flatten_dict(params)
+
+ from mlflow.entities import Param
+
+ # Truncate parameter values to 250 characters.
+ # TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
+ params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
+
+ # Log in chunks of 100 parameters (the maximum allowed by MLflow).
+ for idx in range(0, len(params_list), 100):
+ self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100], **self._log_batch_kwargs)
@override
@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
- return self.logger_impl.log_metrics(metrics, step)
+ assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
+
+ from mlflow.entities import Metric
+
+ metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
+ metrics_list: list[Metric] = []
+
+ timestamp_ms = int(time() * 1000)
+ for k, v in metrics.items():
+ if isinstance(v, str):
+ log.warning(f"Discarding metric with string value {k}={v}.")
+ continue
+
+ new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
+ if k != new_k:
+ rank_zero_warn(
+ "MLFlow only allows '_', '/', '.' and ' ' special characters in metric name."
+ f" Replacing {k} with {new_k}.",
+ category=RuntimeWarning,
+ )
+ k = new_k
+ metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))
+
+ self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list, **self._log_batch_kwargs)
@override
@rank_zero_only
def finalize(self, status: str = "success") -> None:
- return self.logger_impl.finalize(status)
+ if not self._initialized:
+ return
+ if status == "success":
+ status = "FINISHED"
+ elif status == "failed":
+ status = "FAILED"
+ elif status == "finished":
+ status = "FINISHED"
+
+ # log checkpoints as artifacts
+ if self._checkpoint_callback:
+ self._scan_and_log_checkpoints(self._checkpoint_callback)
+
+ if self.experiment.get_run(self.run_id):
+ self.experiment.set_terminated(self.run_id, status)
@property
@override
@@ -187,7 +299,9 @@ def save_dir(self) -> Optional[str]:
Otherwise returns `None`.
"""
- return self.logger_impl.save_dir
+ if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
+ return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :]
+ return None
@property
@override
@@ -198,7 +312,7 @@ def name(self) -> Optional[str]:
The experiment id.
"""
- return self.logger_impl.name
+ return self.experiment_id
@property
@override
@@ -209,8 +323,76 @@ def version(self) -> Optional[str]:
The run id.
"""
- return self.logger_impl.version
+ return self.run_id
@override
def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
- return self.logger_impl.after_save_checkpoint(checkpoint_callback)
+ # log checkpoints as artifacts
+ if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
+ self._scan_and_log_checkpoints(checkpoint_callback)
+ elif self._log_model is True:
+ self._checkpoint_callback = checkpoint_callback
+
+ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
+ # get checkpoints to be saved with associated score
+ checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)
+
+ # log iteratively all new checkpoints
+ for t, p, s, tag in checkpoints:
+ metadata = {
+ # Ensure .item() is called to store Tensor contents
+ "score": s.item() if isinstance(s, Tensor) else s,
+ "original_filename": Path(p).name,
+ "Checkpoint": {
+ k: getattr(checkpoint_callback, k)
+ for k in [
+ "monitor",
+ "mode",
+ "save_last",
+ "save_top_k",
+ "save_weights_only",
+ "_every_n_train_steps",
+ "_every_n_val_epochs",
+ ]
+ # ensure it does not break if `Checkpoint` args change
+ if hasattr(checkpoint_callback, k)
+ },
+ }
+ aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
+
+ # Artifact path on mlflow
+ artifact_path = Path(p).stem
+
+ # Log the checkpoint
+ self.experiment.log_artifact(self._run_id, p, artifact_path)
+
+ # Create a temporary directory to log on mlflow
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Log the metadata
+ with open(f"{tmp_dir}/metadata.yaml", "w") as tmp_file_metadata:
+ yaml.dump(metadata, tmp_file_metadata, default_flow_style=False)
+
+ # Log the aliases
+ with open(f"{tmp_dir}/aliases.txt", "w") as tmp_file_aliases:
+ tmp_file_aliases.write(str(aliases))
+
+ # Log the metadata and aliases
+ self.experiment.log_artifacts(self._run_id, tmp_dir, artifact_path)
+
+            # remember logged models - timestamp needed in case the filename didn't change (last.ckpt or a custom name)
+ self._logged_model_time[p] = t
+
+
+def _get_resolve_tags() -> Callable:
+ from mlflow.tracking import context
+
+ # before v1.1.0
+ if hasattr(context, "resolve_tags"):
+ from mlflow.tracking.context import resolve_tags
+ # since v1.1.0
+ elif hasattr(context, "registry"):
+ from mlflow.tracking.context.registry import resolve_tags
+ else:
+ resolve_tags = lambda tags: tags
+
+ return resolve_tags
diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py
index d6ec7d482fdc0..bf9669c824784 100644
--- a/src/lightning/pytorch/loggers/neptune.py
+++ b/src/lightning/pytorch/loggers/neptune.py
@@ -16,17 +16,23 @@
--------------
"""
+import contextlib
import logging
+import os
from argparse import Namespace
-from typing import TYPE_CHECKING, Any, Optional, Union
+from collections.abc import Generator
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import override
import lightning.pytorch as pl
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
from lightning.pytorch.callbacks import Checkpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
+from lightning.pytorch.utilities.model_summary import ModelSummary
from lightning.pytorch.utilities.rank_zero import rank_zero_only
if TYPE_CHECKING:
@@ -35,6 +41,27 @@
log = logging.getLogger(__name__)
+# Neptune is available under two names on PyPI: `neptune` and `neptune-client`.
+# `neptune` was introduced as the new name for `neptune-client`, with the long-term goal of
+# retiring the `neptune-client` package entirely. The rename was part of the breaking changes
+# released with neptune-client==1.0: neptune-client>=1.0 is just an alias of the `neptune` package
+# and has breaking changes compared to neptune-client<1.0.0.
+_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0")
+_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning"
+
+
+# Neptune client throws `InactiveRunException` when trying to log to an inactive run.
+# This may happen when the run was stopped through the UI and the logger is still trying to log to it.
+def _catch_inactive(func: Callable) -> Callable:
+ @wraps(func)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ from neptune.exceptions import InactiveRunException
+
+ with contextlib.suppress(InactiveRunException):
+ return func(*args, **kwargs)
+
+ return wrapper
+
class NeptuneLogger(Logger):
r"""Log using `Neptune `_.
@@ -215,19 +242,113 @@ def __init__(
prefix: str = "training",
**neptune_run_kwargs: Any,
):
- _raise_enterprise_not_available()
+ if not _NEPTUNE_AVAILABLE:
+ raise ModuleNotFoundError(str(_NEPTUNE_AVAILABLE))
+
+ # verify if user passed proper init arguments
+ self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs)
super().__init__()
- from pytorch_lightning_enterprise.loggers.neptune import NeptuneLogger as EnterpriseNeptuneLogger
-
- self.logger_impl = EnterpriseNeptuneLogger(
- api_key=api_key,
- project=project,
- name=name,
- run=run,
- log_model_checkpoints=log_model_checkpoints,
- prefix=prefix,
- **neptune_run_kwargs,
- )
+ self._log_model_checkpoints = log_model_checkpoints
+ self._prefix = prefix
+ self._run_name = name
+ self._project_name = project
+ self._api_key = api_key
+ self._run_instance = run
+ self._neptune_run_kwargs = neptune_run_kwargs
+ self._run_short_id: Optional[str] = None
+
+ if self._run_instance is not None:
+ self._retrieve_run_data()
+
+ from neptune.handler import Handler
+
+        # make sure that we log the integration version for externally provided `Run` instances
+ root_obj = self._run_instance
+ if isinstance(root_obj, Handler):
+ root_obj = root_obj.get_root_object()
+
+ root_obj[_INTEGRATION_VERSION_KEY] = pl.__version__
+
+ def _retrieve_run_data(self) -> None:
+ from neptune.handler import Handler
+
+ assert self._run_instance is not None
+ root_obj = self._run_instance
+ if isinstance(root_obj, Handler):
+ root_obj = root_obj.get_root_object()
+
+ root_obj.wait()
+
+ if root_obj.exists("sys/id"):
+ self._run_short_id = root_obj["sys/id"].fetch()
+ self._run_name = root_obj["sys/name"].fetch()
+ else:
+ self._run_short_id = "OFFLINE"
+ self._run_name = "offline-name"
+
+ @property
+ def _neptune_init_args(self) -> dict:
+ args: dict = {}
+        # Backward compatibility: a logger pickled by an older version may not have this attribute
+ with contextlib.suppress(AttributeError):
+ args = self._neptune_run_kwargs
+
+ if self._project_name is not None:
+ args["project"] = self._project_name
+
+ if self._api_key is not None:
+ args["api_token"] = self._api_key
+
+ if self._run_short_id is not None:
+ args["run"] = self._run_short_id
+
+        # Backward compatibility: a logger pickled by an older version may not have this attribute
+ with contextlib.suppress(AttributeError):
+ if self._run_name is not None:
+ args["name"] = self._run_name
+
+ return args
+
+ def _construct_path_with_prefix(self, *keys: str) -> str:
+ """Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
+ if self._prefix:
+ return self.LOGGER_JOIN_CHAR.join([self._prefix, *keys])
+ return self.LOGGER_JOIN_CHAR.join(keys)
+
+ @staticmethod
+ def _verify_input_arguments(
+ api_key: Optional[str],
+ project: Optional[str],
+ name: Optional[str],
+ run: Optional[Union["Run", "Handler"]],
+ neptune_run_kwargs: dict,
+ ) -> None:
+ from neptune import Run
+ from neptune.handler import Handler
+
+ # check if user passed the client `Run`/`Handler` object
+ if run is not None and not isinstance(run, (Run, Handler)):
+ raise ValueError("Run parameter expected to be of type `neptune.Run`, or `neptune.handler.Handler`.")
+
+ # check if user passed redundant neptune.init_run arguments when passed run
+ any_neptune_init_arg_passed = any(arg is not None for arg in [api_key, project, name]) or neptune_run_kwargs
+ if run is not None and any_neptune_init_arg_passed:
+ raise ValueError(
+ "When an already initialized run object is provided, you can't provide other `neptune.init_run()`"
+ " parameters."
+ )
+
+ def __getstate__(self) -> dict[str, Any]:
+ state = self.__dict__.copy()
+ # Run instance can't be pickled
+ state["_run_instance"] = None
+ return state
+
+ def __setstate__(self, state: dict[str, Any]) -> None:
+ import neptune
+
+ self.__dict__ = state
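+        # recreate the run from the saved init args (the Run object itself cannot be pickled)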
+ self._run_instance = neptune.init_run(**self._neptune_init_args)
@property
@rank_zero_experiment
@@ -257,23 +378,73 @@ def training_step(self, batch, batch_idx):
with NeptuneLogger.
"""
- return self.logger_impl.experiment
+ return self.run
@property
@rank_zero_experiment
def run(self) -> "Run":
- return self.logger_impl.run
+ import neptune
+
+ if not self._run_instance:
+ self._run_instance = neptune.init_run(**self._neptune_init_args)
+ self._retrieve_run_data()
+            # make sure that we log the integration version for newly created runs
+ self._run_instance[_INTEGRATION_VERSION_KEY] = pl.__version__
+
+ return self._run_instance
@override
@rank_zero_only
+ @_catch_inactive
def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
- return self.logger_impl.log_hyperparams(params)
+ r"""Log hyperparameters to the run.
+
+ Hyperparameters will be logged under the "/hyperparams" namespace.
+
+ Note:
+
+ You can also log parameters by directly using the logger instance:
+ ``neptune_logger.experiment["model/hyper-parameters"] = params_dict``.
+
+        This way you can keep the hierarchical structure of the parameters.
+
+ Args:
+ params: `dict`.
+ Python dictionary structure with parameters.
+
+ Example::
+
+ from lightning.pytorch.loggers import NeptuneLogger
+ import neptune
+
+ PARAMS = {
+ "batch_size": 64,
+ "lr": 0.07,
+ "decay_factor": 0.97,
+ }
+
+ neptune_logger = NeptuneLogger(
+ api_key=neptune.ANONYMOUS_API_TOKEN,
+ project="common/pytorch-lightning-integration"
+ )
+
+ neptune_logger.log_hyperparams(PARAMS)
+
+ """
+ from neptune.utils import stringify_unsupported
+
+ params = _convert_params(params)
+ params = _sanitize_callable_params(params)
+
+ parameters_key = self.PARAMETERS_KEY
+ parameters_key = self._construct_path_with_prefix(parameters_key)
+
+ self.run[parameters_key] = stringify_unsupported(params)
@override
@rank_zero_only
- def log_metrics( # type: ignore[override]
- self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None
- ) -> None:
+ @_catch_inactive
+ def log_metrics(self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
"""Log metrics (numeric values) in Neptune runs.
Args:
@@ -281,12 +452,26 @@ def log_metrics( # type: ignore[override]
step: Step number at which the metrics should be recorded
"""
- return self.logger_impl.log_metrics(metrics, step)
+ if rank_zero_only.rank != 0:
+ raise ValueError("run tried to log from global_rank != 0")
+
+ metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
+
+ for key, val in metrics.items():
+ self.run[key].append(val, step=step)
@override
@rank_zero_only
+ @_catch_inactive
def finalize(self, status: str) -> None:
- return self.logger_impl.finalize(status)
+ if not self._run_instance:
+ # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been
+ # initialized there
+ return
+ if status:
+ self.run[self._construct_path_with_prefix("status")] = status
+
+ super().finalize(status)
@property
@override
@@ -298,14 +483,21 @@ def save_dir(self) -> Optional[str]:
the root directory where experiment logs get saved
"""
- return self.logger_impl.save_dir
+ return os.path.join(os.getcwd(), ".neptune")
@rank_zero_only
+ @_catch_inactive
def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None:
- return self.logger_impl.log_model_summary(model, max_depth)
+ from neptune.types import File
+
+ model_str = str(ModelSummary(model=model, max_depth=max_depth))
+ self.run[self._construct_path_with_prefix("model/summary")] = File.from_content(
+ content=model_str, extension="txt"
+ )
@override
@rank_zero_only
+ @_catch_inactive
def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
@@ -313,13 +505,83 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
checkpoint_callback: the model checkpoint callback instance
"""
- return self.logger_impl.after_save_checkpoint(checkpoint_callback)
+ if not self._log_model_checkpoints:
+ return
+
+ file_names = set()
+ checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")
+
+ # save last model
+ if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
+ model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
+ file_names.add(model_last_name)
+ self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
+
+ # save best k models
+ if hasattr(checkpoint_callback, "best_k_models"):
+ for key in checkpoint_callback.best_k_models:
+ model_name = self._get_full_model_name(key, checkpoint_callback)
+ file_names.add(model_name)
+ self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)
+
+ # log best model path and checkpoint
+ if hasattr(checkpoint_callback, "best_model_path") and checkpoint_callback.best_model_path:
+ self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path
+
+ model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
+ file_names.add(model_name)
+ self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
+
+ # remove old models logged to experiment if they are not part of best k models at this point
+ if self.run.exists(checkpoints_namespace):
+ exp_structure = self.run.get_structure()
+ uploaded_model_names = self._get_full_model_names_from_exp_structure(exp_structure, checkpoints_namespace)
+
+ for file_to_drop in list(uploaded_model_names - file_names):
+ del self.run[f"{checkpoints_namespace}/{file_to_drop}"]
+
+ # log best model score
+ if hasattr(checkpoint_callback, "best_model_score") and checkpoint_callback.best_model_score:
+ self.run[self._construct_path_with_prefix("model/best_model_score")] = (
+ checkpoint_callback.best_model_score.cpu().detach().numpy()
+ )
+
+ @staticmethod
+ def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> str:
+ """Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`."""
+ if hasattr(checkpoint_callback, "dirpath"):
+ model_path = os.path.normpath(model_path)
+ expected_model_path = os.path.normpath(checkpoint_callback.dirpath)
+ if not model_path.startswith(expected_model_path):
+ raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
+ # Remove extension from filepath
+ filepath, _ = os.path.splitext(model_path[len(expected_model_path) + 1 :])
+ return filepath.replace(os.sep, "/")
+ return model_path.replace(os.sep, "/")
+
+ @classmethod
+ def _get_full_model_names_from_exp_structure(cls, exp_structure: dict[str, Any], namespace: str) -> set[str]:
+ """Returns all paths to properties which were already logged in `namespace`"""
+ structure_keys: list[str] = namespace.split(cls.LOGGER_JOIN_CHAR)
+ for key in structure_keys:
+ exp_structure = exp_structure[key]
+ uploaded_models_dict = exp_structure
+ return set(cls._dict_paths(uploaded_models_dict))
+
+ @classmethod
+ def _dict_paths(cls, d: dict[str, Any], path_in_build: Optional[str] = None) -> Generator:
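+        # e.g. {"a": {"b": 1, "c": {"d": 2}}} yields "a/b" and "a/c/d"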
+ for k, v in d.items():
+ path = f"{path_in_build}/{k}" if path_in_build is not None else k
+ if not isinstance(v, dict):
+ yield path
+ else:
+ yield from cls._dict_paths(v, path)
@property
@override
def name(self) -> Optional[str]:
"""Return the experiment name or 'offline-name' when exp is run in offline mode."""
- return self.logger_impl.name
+ return self._run_name
@property
@override
@@ -329,4 +591,4 @@ def version(self) -> Optional[str]:
It's Neptune Run's short_id
"""
- return self.logger_impl.version
+ return self._run_short_id
diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py
index 41d38b460d0f0..37ca362fa40c1 100644
--- a/src/lightning/pytorch/loggers/wandb.py
+++ b/src/lightning/pytorch/loggers/wandb.py
@@ -16,24 +16,37 @@
-------------------------
"""
+import os
from argparse import Namespace
from collections.abc import Mapping
+from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import torch.nn as nn
+from lightning_utilities.core.imports import RequirementCache
+from torch import Tensor
from typing_extensions import override
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.logger import (
+ _add_prefix,
+ _convert_json_serializable,
+ _convert_params,
+ _sanitize_callable_params,
+)
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
-from lightning.pytorch.utilities.rank_zero import rank_zero_only
+from lightning.pytorch.loggers.utilities import _scan_checkpoints
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
if TYPE_CHECKING:
from wandb import Artifact
from wandb.sdk.lib import RunDisabled
from wandb.wandb_run import Run
+_WANDB_AVAILABLE = RequirementCache("wandb>=0.12.10")
+
class WandbLogger(Logger):
r"""Log using `Weights and Biases `_.
@@ -295,27 +308,70 @@ def __init__(
add_file_policy: Literal["mutable", "immutable"] = "mutable",
**kwargs: Any,
) -> None:
- _raise_enterprise_not_available()
+ if not _WANDB_AVAILABLE:
+ raise ModuleNotFoundError(str(_WANDB_AVAILABLE))
+
+ if offline and log_model:
+ raise MisconfigurationException(
+ f"Providing log_model={log_model} and offline={offline} is an invalid configuration"
+ " since model checkpoints cannot be uploaded in offline mode.\n"
+ "Hint: Set `offline=False` to log your model."
+ )
super().__init__()
- from pytorch_lightning_enterprise.loggers.wandb import WandbLogger as EnterpriseWandbLogger
-
- self.logger_impl = EnterpriseWandbLogger(
- name=name,
- save_dir=save_dir,
- version=version,
- offline=offline,
- dir=dir,
- id=id,
- anonymous=anonymous,
- project=project,
- log_model=log_model,
- experiment=experiment,
- prefix=prefix,
- checkpoint_name=checkpoint_name,
- add_file_policy=add_file_policy,
- **kwargs,
- )
+ self._offline = offline
+ self._log_model = log_model
+ self._prefix = prefix
+ self._experiment = experiment
+ self._logged_model_time: dict[str, float] = {}
+ self._checkpoint_callbacks: dict[int, ModelCheckpoint] = {}
+ self.add_file_policy = add_file_policy
+
+ # paths are processed as strings
+ if save_dir is not None:
+ save_dir = os.fspath(save_dir)
+ elif dir is not None:
+ dir = os.fspath(dir)
+
+ project = project or os.environ.get("WANDB_PROJECT", "lightning_logs")
+
+ # set wandb init arguments
+ self._wandb_init: dict[str, Any] = {
+ "name": name,
+ "project": project,
+ "dir": save_dir or dir,
+ "id": version or id,
+ "resume": "allow",
+ "anonymous": ("allow" if anonymous else None),
+ }
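+ # explicit kwargs override the defaults assembled above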
+ self._wandb_init.update(**kwargs)
+ # extract parameters
+ self._project = self._wandb_init.get("project")
+ self._save_dir = self._wandb_init.get("dir")
+ self._name = self._wandb_init.get("name")
+ self._id = self._wandb_init.get("id")
+ self._checkpoint_name = checkpoint_name
+
+ def __getstate__(self) -> dict[str, Any]:
+ import wandb
+
+ # Hack: If the 'spawn' launch method is used, the logger will get pickled and this `__getstate__` gets called.
+ # We create an experiment here in the main process, and attach to it in the worker process.
+ # Using wandb-service, we persist the same experiment even if multiple `Trainer.fit/test/validate` calls
+ # are made.
+ wandb.require("service")
+ _ = self.experiment
+
+ state = self.__dict__.copy()
+ # args needed to reload correct experiment
+ if self._experiment is not None:
+ state["_id"] = getattr(self._experiment, "id", None)
+ state["_attach_id"] = getattr(self._experiment, "_attach_id", None)
+ state["_name"] = self._experiment.name
+
+ # cannot be pickled
+ state["_experiment"] = None
+ return state
@property
@rank_zero_experiment
@@ -330,7 +386,40 @@ def experiment(self) -> Union["Run", "RunDisabled"]:
self.logger.experiment.some_wandb_function()
"""
- return self.logger_impl.experiment
+ import wandb
+ from wandb.sdk.lib import RunDisabled
+ from wandb.wandb_run import Run
+
+ if self._experiment is None:
+ if self._offline:
+ os.environ["WANDB_MODE"] = "dryrun"
+
+ attach_id = getattr(self, "_attach_id", None)
+ if wandb.run is not None:
+ # wandb process already created in this instance
+ rank_zero_warn(
+ "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
+ " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
+ )
+ self._experiment = wandb.run
+ elif attach_id is not None and hasattr(wandb, "_attach"):
+ # attach to wandb process referenced
+ self._experiment = wandb._attach(attach_id)
+ else:
+ # create new wandb process
+ self._experiment = wandb.init(**self._wandb_init)
+
+ # define default x-axis
+ if isinstance(self._experiment, (Run, RunDisabled)) and getattr(
+ self._experiment, "define_metric", None
+ ):
+ if self._wandb_init.get("sync_tensorboard"):
+ self._experiment.define_metric("*", step_metric="global_step")
+ else:
+ self._experiment.define_metric("trainer/global_step")
+ self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
+
+ return self._experiment
def watch(
self, model: nn.Module, log: Optional[str] = "gradients", log_freq: int = 100, log_graph: bool = True
@@ -340,12 +429,21 @@ def watch(
@override
@rank_zero_only
def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
- return self.logger_impl.log_hyperparams(params)
+ params = _convert_params(params)
+ params = _sanitize_callable_params(params)
+ params = _convert_json_serializable(params)
+ self.experiment.config.update(params, allow_val_change=True)
@override
@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
- return self.logger_impl.log_metrics(metrics, step)
+ assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
+
+ metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
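+ # when tensorboard syncing is enabled, wandb takes the step from the tensorboard events,
+ # so don't attach an explicit "trainer/global_step" metric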
+ if step is not None and not self._wandb_init.get("sync_tensorboard"):
+ self.experiment.log(dict(metrics, **{"trainer/global_step": step}))
+ else:
+ self.experiment.log(metrics)
@rank_zero_only
def log_table(
@@ -361,7 +459,10 @@ def log_table(
Can be defined either with `columns` and `data` or with `dataframe`.
"""
- return self.logger_impl.log_table(key, columns, data, dataframe, step)
+ import wandb
+
+ metrics = {key: wandb.Table(columns=columns, data=data, dataframe=dataframe)}
+ self.log_metrics(metrics, step)
@rank_zero_only
def log_text(
@@ -377,7 +478,8 @@ def log_text(
Can be defined either with `columns` and `data` or with `dataframe`.
"""
- return self.logger_impl.log_text(key, columns, data, dataframe, step)
+
+ self.log_table(key, columns, data, dataframe, step)
@rank_zero_only
def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **kwargs: Any) -> None:
@@ -386,7 +488,18 @@ def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **k
Optional kwargs are lists passed to each image (ex: caption, masks, boxes).
"""
- return self.logger_impl.log_image(key, images, step, **kwargs)
+ if not isinstance(images, list):
+ raise TypeError(f'Expected a list as "images", found {type(images)}')
+ n = len(images)
+ for k, v in kwargs.items():
+ if len(v) != n:
+ raise ValueError(f"Expected {n} items but only found {len(v)} for {k}")
+ kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)]
+
+ import wandb
+
+ metrics = {key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]}
+ self.log_metrics(metrics, step) # type: ignore[arg-type]
@rank_zero_only
def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **kwargs: Any) -> None:
@@ -401,7 +514,18 @@ def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **k
Optional kwargs are lists passed to each audio (ex: caption, sample_rate).
"""
- return self.logger_impl.log_audio(key, audios, step, **kwargs)
+ if not isinstance(audios, list):
+ raise TypeError(f'Expected a list as "audios", found {type(audios)}')
+ n = len(audios)
+ for k, v in kwargs.items():
+ if len(v) != n:
+ raise ValueError(f"Expected {n} items but only found {len(v)} for {k}")
+ kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)]
+
+ import wandb
+
+ metrics = {key: [wandb.Audio(audio, **kwarg) for audio, kwarg in zip(audios, kwarg_list)]}
+ self.log_metrics(metrics, step) # type: ignore[arg-type]
@rank_zero_only
def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **kwargs: Any) -> None:
@@ -416,7 +540,18 @@ def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **k
Optional kwargs are lists passed to each video (ex: caption, fps, format).
"""
- return self.logger_impl.log_video(key, videos, step, **kwargs)
+ if not isinstance(videos, list):
+ raise TypeError(f'Expected a list as "videos", found {type(videos)}')
+ n = len(videos)
+ for k, v in kwargs.items():
+ if len(v) != n:
+ raise ValueError(f"Expected {n} items but only found {len(v)} for {k}")
+ kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)]
+
+ import wandb
+
+ metrics = {key: [wandb.Video(video, **kwarg) for video, kwarg in zip(videos, kwarg_list)]}
+ self.log_metrics(metrics, step) # type: ignore[arg-type]
@property
@override
@@ -427,7 +562,7 @@ def save_dir(self) -> Optional[str]:
The path to the save directory.
"""
- return self.logger_impl.save_dir
+ return self._save_dir
@property
@override
@@ -439,7 +574,7 @@ def name(self) -> Optional[str]:
name. To access wandb's internal experiment name, use ``logger.experiment.name`` instead.
"""
- return self.logger_impl.name
+ return self._project
@property
@override
@@ -451,12 +586,15 @@ def version(self) -> Optional[str]:
"""
# don't create an experiment if we don't have one
- return self.logger_impl.version
+ return self._experiment.id if self._experiment else self._id
@override
def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
# log checkpoints as artifacts
- return self.logger_impl.after_save_checkpoint(checkpoint_callback)
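+ # "all" logs every new checkpoint immediately; `True` defers logging to `finalize`,
+ # unless every checkpoint is being kept (save_top_k == -1)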
+ if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
+ self._scan_and_log_checkpoints(checkpoint_callback)
+ elif self._log_model is True:
+ self._checkpoint_callbacks[id(checkpoint_callback)] = checkpoint_callback
@staticmethod
@rank_zero_only
@@ -478,10 +616,16 @@ def download_artifact(
The path to the downloaded artifact.
"""
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.loggers.wandb import WandbLogger as EnterpriseWandbLogger
+ import wandb
+
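+ # prefer the active run so the artifact gets linked to it; otherwise fall back to the public API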
+ if wandb.run is not None and use_artifact:
+ artifact = wandb.run.use_artifact(artifact)
+ else:
+ api = wandb.Api()
+ artifact = api.artifact(artifact, type=artifact_type)
- return EnterpriseWandbLogger.download_artifact(artifact, save_dir, artifact_type, use_artifact)
+ save_dir = None if save_dir is None else os.fspath(save_dir)
+ return artifact.download(root=save_dir)
def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "Artifact":
"""Logs to the wandb dashboard that the mentioned artifact is used by the run.
@@ -494,9 +638,49 @@ def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "A
wandb Artifact object for the artifact.
"""
- return self.logger_impl.use_artifact(artifact, artifact_type)
+ return self.experiment.use_artifact(artifact, type=artifact_type)
@override
@rank_zero_only
def finalize(self, status: str) -> None:
- return self.logger_impl.finalize(status)
+ if status != "success":
+ # Currently, checkpoints only get logged on success
+ return
+ # log checkpoints as artifacts
+ if self._experiment is not None:
+ for checkpoint_callback in self._checkpoint_callbacks.values():
+ self._scan_and_log_checkpoints(checkpoint_callback)
+
+ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
+ import wandb
+
+ # get checkpoints to be saved with associated score
+ checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)
+
+ # log iteratively all new checkpoints
+ for t, p, s, tag in checkpoints:
+ metadata = {
+ "score": s.item() if isinstance(s, Tensor) else s,
+ "original_filename": Path(p).name,
+ checkpoint_callback.__class__.__name__: {
+ k: getattr(checkpoint_callback, k)
+ for k in [
+ "monitor",
+ "mode",
+ "save_last",
+ "save_top_k",
+ "save_weights_only",
+ "_every_n_train_steps",
+ ]
+ # ensure it does not break if `ModelCheckpoint` args change
+ if hasattr(checkpoint_callback, k)
+ },
+ }
+ if not self._checkpoint_name:
+ self._checkpoint_name = f"model-{self.experiment.id}"
+ artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
+ artifact.add_file(p, name="model.ckpt", policy=self.add_file_policy)
+ aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
+ self.experiment.log_artifact(artifact, aliases=aliases)
+ # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
+ self._logged_model_time[p] = t
diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py
index 29834f5d519d0..9225e3bb9e7be 100644
--- a/src/lightning/pytorch/plugins/precision/deepspeed.py
+++ b/src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -11,21 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from contextlib import AbstractContextManager
-from typing import Any, Callable, Optional, Union
+from contextlib import AbstractContextManager, nullcontext
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
+from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import Module
-from torch.optim import Optimizer
-from typing_extensions import override
+from torch.optim import LBFGS, Optimizer
+from typing_extensions import get_args, override
import lightning.pytorch as pl
-from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT, _PRECISION_INPUT_STR
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT
+from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
from lightning.fabric.utilities.types import Steppable
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities import GradClipAlgorithmType
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.rank_zero import WarningCache
+
+if TYPE_CHECKING:
+ import deepspeed
+
+warning_cache = WarningCache()
class DeepSpeedPrecision(Precision):
@@ -44,29 +53,41 @@ class DeepSpeedPrecision(Precision):
"""
def __init__(self, precision: _PRECISION_INPUT) -> None:
- super().__init__()
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.precision.deepspeed import (
- DeepSpeedPrecisionTrainer as EnterpriseDeepSpeedPrecision,
- )
-
- self.deepspeed_precision_impl = EnterpriseDeepSpeedPrecision(outer_object=self, precision=precision)
+ supported_precision = get_args(_PRECISION_INPUT)
+ if precision not in supported_precision:
+ raise ValueError(
+ f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported."
+ f" `precision` must be one of: {supported_precision}."
+ )
+ self.precision = precision
+ precision_to_type = {
+ "bf16-mixed": torch.bfloat16,
+ "16-mixed": torch.float16,
+ "bf16-true": torch.bfloat16,
+ "16-true": torch.float16,
+ "32-true": torch.float32,
+ }
+ self._desired_dtype = precision_to_type[self.precision]
@override
def convert_module(self, module: Module) -> Module:
- return self.deepspeed_precision_impl.convert_module(module=module)
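+ # only the "*-true" precisions cast the module's parameters; "*-mixed" keeps them in float32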
+ if "true" in self.precision:
+ return module.to(dtype=self._desired_dtype)
+ return module
@override
def convert_input(self, data: Any) -> Any:
- return self.deepspeed_precision_impl.convert_input(data=data)
+ return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
@override
def tensor_init_context(self) -> AbstractContextManager:
- return self.deepspeed_precision_impl.tensor_init_context()
+ if "true" not in self.precision:
+ return nullcontext()
+ return _DtypeContextManager(self._desired_dtype)
@override
def module_init_context(self) -> AbstractContextManager:
- return self.deepspeed_precision_impl.module_init_context()
+ return self.tensor_init_context()
@override
def backward( # type: ignore[override]
@@ -77,7 +98,7 @@ def backward( # type: ignore[override]
*args: Any,
**kwargs: Any,
) -> None:
- r"""Performs back-propagation.
+ r"""Performs back-propagation using DeepSpeed's engine.
Args:
tensor: the loss tensor
@@ -87,7 +108,13 @@ def backward( # type: ignore[override]
\**kwargs: additional keyword arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call
"""
- return self.deepspeed_precision_impl.backward(tensor=tensor, model=model, optimizer=optimizer, *args, **kwargs)
+ if is_overridden("backward", model):
+ warning_cache.warn(
+ "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
+ " the backward logic internally."
+ )
+ deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model
+ deepspeed_engine.backward(tensor, *args, **kwargs)
@override
def optimizer_step( # type: ignore[override]
@@ -97,7 +124,19 @@ def optimizer_step( # type: ignore[override]
closure: Callable[[], Any],
**kwargs: Any,
) -> Any:
- return self.deepspeed_precision_impl.optimizer_step(optimizer=optimizer, model=model, closure=closure, **kwargs)
+ if isinstance(optimizer, LBFGS):
+ raise MisconfigurationException("DeepSpeed and the LBFGS optimizer are not compatible.")
+ closure_result = closure()
+ self._after_closure(model, optimizer)
+ skipped_backward = closure_result is None
+ # in manual optimization, the closure does not return a value
+ if model.automatic_optimization and skipped_backward:
+ raise MisconfigurationException(
+ "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
+ )
+ # DeepSpeed handles the optimizer step internally
+ deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model
+ return deepspeed_engine.step(**kwargs)
@override
def clip_gradients(
@@ -106,22 +145,4 @@ def clip_gradients(
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
- return self.deepspeed_precision_impl.clip_gradients(
- optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm
- )
-
- @property
- def precision(self) -> _PRECISION_INPUT_STR:
- return self.deepspeed_precision_impl.precision
-
- @precision.setter
- def precision(self, value: _PRECISION_INPUT_STR) -> None:
- self.deepspeed_precision_impl.precision = value
-
- @property
- def _desired_dtype(self) -> torch.dtype:
- return self.deepspeed_precision_impl._desired_dtype
-
- @_desired_dtype.setter
- def _desired_dtype(self, value: torch.dtype) -> None:
- self.deepspeed_precision_impl._desired_dtype = value
+ """DeepSpeed handles gradient clipping internally."""
diff --git a/src/lightning/pytorch/plugins/precision/xla.py b/src/lightning/pytorch/plugins/precision/xla.py
index 5d9a818aecf85..6890cc4c1d825 100644
--- a/src/lightning/pytorch/plugins/precision/xla.py
+++ b/src/lightning/pytorch/plugins/precision/xla.py
@@ -11,16 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
+from functools import partial
from typing import Any, Callable
import torch
-from typing_extensions import override
+from typing_extensions import get_args, override
import lightning.pytorch as pl
-from lightning.fabric.plugins.precision.xla import _PRECISION_INPUT, _PRECISION_INPUT_STR
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
+from lightning.fabric.plugins.precision.xla import _PRECISION_INPUT
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision import Precision
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
class XLAPrecision(Precision):
@@ -36,12 +39,25 @@ class XLAPrecision(Precision):
"""
def __init__(self, precision: _PRECISION_INPUT = "32-true") -> None:
- super().__init__()
-
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.plugins.precision.xla import XLAPrecision as EnterpriseXLAPrecision
-
- self.xla_impl = EnterpriseXLAPrecision(precision)
+ if not _XLA_AVAILABLE:
+ raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+
+ supported_precision = get_args(_PRECISION_INPUT)
+ if precision not in supported_precision:
+ raise ValueError(
+ f"`precision={precision!r})` is not supported in XLA."
+ f" `precision` must be one of: {supported_precision}."
+ )
+ self.precision = precision
+
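+ # torch_xla downcasts float32 tensors when these environment variables are set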
+ if precision == "16-true":
+ os.environ["XLA_USE_F16"] = "1"
+ self._desired_dtype = torch.float16
+ elif precision == "bf16-true":
+ os.environ["XLA_USE_BF16"] = "1"
+ self._desired_dtype = torch.bfloat16
+ else:
+ self._desired_dtype = torch.float32
@override
def optimizer_step( # type: ignore[override]
@@ -51,24 +67,31 @@ def optimizer_step( # type: ignore[override]
closure: Callable[[], Any],
**kwargs: Any,
) -> Any:
- return self.xla_impl.optimizer_step(optimizer, model, closure, **kwargs)
-
- @property
- def precision(self) -> _PRECISION_INPUT_STR:
- return self.xla_impl.precision
-
- @precision.setter
- def precision(self, precision: _PRECISION_INPUT_STR) -> None:
- self.xla_impl.precision = precision
-
- @property
- def _desired_dtype(self) -> torch.dtype:
- return self.xla_impl._desired_dtype
-
- @_desired_dtype.setter
- def _desired_dtype(self, dtype: torch.dtype) -> None:
- self.xla_impl._desired_dtype = dtype
+ import torch_xla.core.xla_model as xm
+
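+ # innermost wrapper runs the closure and all-reduces gradients across XLA devices;
+ # the outer wrapper adds the standard precision-plugin closure hooks around it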
+ closure = partial(self._xla_wrap_closure, optimizer, closure)
+ closure = partial(self._wrap_closure, model, optimizer, closure)
+ closure_result = optimizer.step(closure=closure, **kwargs)
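+ # mark_step flushes the lazily recorded XLA graph so the optimizer step actually executes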
+ xm.mark_step()
+ skipped_backward = closure_result is None
+ # in manual optimization, the closure does not return a value
+ if model.automatic_optimization and skipped_backward:
+ # we lack coverage here so disable this - something to explore if there's demand
+ raise MisconfigurationException(
+ "Skipping backward by returning `None` from your `training_step` is not implemented with XLA."
+ " Please, open an issue in `https://github.com/Lightning-AI/pytorch-lightning/issues`"
+ " requesting this feature."
+ )
+ return closure_result
@override
def teardown(self) -> None:
- return self.xla_impl.teardown()
+ os.environ.pop("XLA_USE_BF16", None)
+ os.environ.pop("XLA_USE_F16", None)
+
+ def _xla_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
+ import torch_xla.core.xla_model as xm
+
+ closure_result = closure()
+ xm.reduce_gradients(optimizer)
+ return closure_result
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 8bad6a4c6d40b..1fce7b06887cd 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -11,27 +11,52 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import argparse
+import json
import logging
+import os
+import platform
from collections import OrderedDict
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from datetime import timedelta
+from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
from torch.nn import Module
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau
from typing_extensions import override
import lightning.pytorch as pl
from lightning.fabric.plugins import ClusterEnvironment
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.strategies import _StrategyRegistry
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.strategies.deepspeed import (
+ _DEEPSPEED_AVAILABLE,
+ _DEEPSPEED_GREATER_EQUAL_0_16,
+ _format_precision_config,
+ _validate_checkpoint_directory,
+ _validate_device_index_selection,
+)
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
+from lightning.fabric.utilities.optimizer import _optimizers_to_device
+from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH
+from lightning.pytorch.accelerators.cuda import CUDAAccelerator
+from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers
from lightning.pytorch.plugins.precision import Precision
from lightning.pytorch.strategies.ddp import DDPStrategy
+from lightning.pytorch.trainer.states import TrainerFn
+from lightning.pytorch.utilities import GradClipAlgorithmType
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn
+from lightning.pytorch.utilities.types import LRSchedulerConfig
+
+log = logging.getLogger(__name__)
+warning_cache = WarningCache()
if TYPE_CHECKING:
import deepspeed
@@ -233,84 +258,150 @@ def __init__(
exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints.
"""
- super().__init__(
- accelerator=accelerator,
- parallel_devices=parallel_devices,
- cluster_environment=cluster_environment,
- precision_plugin=precision_plugin,
- process_group_backend=process_group_backend,
- )
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.strategies.deepspeed import (
- DeepSpeedStrategyTrainer as EnterpriseDeepSpeedStrategy,
- )
+ if not _DEEPSPEED_AVAILABLE:
+ raise MisconfigurationException(
+ "To use the `DeepSpeedStrategy`, you must have DeepSpeed installed."
+ " Install it by running `pip install -U deepspeed`."
+ )
- self.deepspeed_strategy_impl = EnterpriseDeepSpeedStrategy(
- outer_object=self,
+ if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16:
+ # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints.
+ # DeepSpeed added support for this behavior in version 0.16.0.
+ import deepspeed
+
+ deepspeed_version = deepspeed.__version__
+
+ raise ImportError(
+ f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. "
+ f"Detected DeepSpeed version: {deepspeed_version}. "
+ "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`."
+ )
+
+ super().__init__(
accelerator=accelerator,
- zero_optimization=zero_optimization,
- stage=stage,
- remote_device=remote_device,
- offload_optimizer=offload_optimizer,
- offload_parameters=offload_parameters,
- offload_params_device=offload_params_device,
- nvme_path=nvme_path,
- params_buffer_count=params_buffer_count,
- params_buffer_size=params_buffer_size,
- max_in_cpu=max_in_cpu,
- offload_optimizer_device=offload_optimizer_device,
- optimizer_buffer_count=optimizer_buffer_count,
- block_size=block_size,
- queue_depth=queue_depth,
- single_submit=single_submit,
- overlap_events=overlap_events,
- thread_count=thread_count,
- pin_memory=pin_memory,
- sub_group_size=sub_group_size,
- contiguous_gradients=contiguous_gradients,
- overlap_comm=overlap_comm,
- allgather_partitions=allgather_partitions,
- reduce_scatter=reduce_scatter,
- allgather_bucket_size=allgather_bucket_size,
- reduce_bucket_size=reduce_bucket_size,
- zero_allow_untested_optimizer=zero_allow_untested_optimizer,
- logging_batch_size_per_gpu=logging_batch_size_per_gpu,
- config=config,
- logging_level=logging_level,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
- loss_scale=loss_scale,
- initial_scale_power=initial_scale_power,
- loss_scale_window=loss_scale_window,
- hysteresis=hysteresis,
- min_loss_scale=min_loss_scale,
- partition_activations=partition_activations,
- cpu_checkpointing=cpu_checkpointing,
- contiguous_memory_optimization=contiguous_memory_optimization,
- synchronize_checkpoint_boundary=synchronize_checkpoint_boundary,
- load_full_weights=load_full_weights,
precision_plugin=precision_plugin,
process_group_backend=process_group_backend,
- timeout=timeout,
- exclude_frozen_parameters=exclude_frozen_parameters,
)
+ self._timeout: Optional[timedelta] = timeout
+
+ self.config = self._load_config(config)
+ if self.config is None:
+ # User has not overridden config, set defaults
+ self.config = self._create_default_config(
+ zero_optimization,
+ zero_allow_untested_optimizer,
+ logging_batch_size_per_gpu,
+ offload_optimizer=offload_optimizer,
+ offload_parameters=offload_parameters,
+ nvme_path=nvme_path,
+ offload_params_device=offload_params_device,
+ params_buffer_count=params_buffer_count,
+ params_buffer_size=params_buffer_size,
+ max_in_cpu=max_in_cpu,
+ pin_memory=pin_memory,
+ offload_optimizer_device=offload_optimizer_device,
+ optimizer_buffer_count=optimizer_buffer_count,
+ block_size=block_size,
+ queue_depth=queue_depth,
+ single_submit=single_submit,
+ overlap_events=overlap_events,
+ thread_count=thread_count,
+ partition_activations=partition_activations,
+ cpu_checkpointing=cpu_checkpointing,
+ contiguous_memory_optimization=contiguous_memory_optimization,
+ synchronize_checkpoint_boundary=synchronize_checkpoint_boundary,
+ stage=stage,
+ contiguous_gradients=contiguous_gradients,
+ overlap_comm=overlap_comm,
+ allgather_partitions=allgather_partitions,
+ reduce_scatter=reduce_scatter,
+ allgather_bucket_size=allgather_bucket_size,
+ reduce_bucket_size=reduce_bucket_size,
+ sub_group_size=sub_group_size,
+ )
+ import deepspeed
+
+ self._config_initialized = False
+ deepspeed.utils.logging.logger.setLevel(logging_level)
+
+ self.remote_device = remote_device
+ self.load_full_weights = load_full_weights
+ self.exclude_frozen_parameters = exclude_frozen_parameters
+
+ # default FP16 parameters.
+ self.loss_scale = loss_scale
+ self.initial_scale_power = initial_scale_power
+ self.loss_scale_window = loss_scale_window
+ self.hysteresis = hysteresis
+ self.min_loss_scale = min_loss_scale
@override
def setup_environment(self) -> None:
- return self.deepspeed_strategy_impl.setup_environment()
+ if not isinstance(self.accelerator, CUDAAccelerator):
+ raise RuntimeError(
+ f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
+ " is used."
+ )
+ super().setup_environment()
@override
def setup_distributed(self) -> None:
- return self.deepspeed_strategy_impl.setup_distributed()
+ assert self.parallel_devices is not None
+ _validate_device_index_selection(self.parallel_devices)
+ reset_seed()
+ self.set_world_ranks()
+ self._init_deepspeed_distributed()
@override
def setup(self, trainer: "pl.Trainer") -> None:
- return self.deepspeed_strategy_impl.setup(trainer=trainer)
+ self._init_config_if_needed()
+ assert self.accelerator is not None
+ self.accelerator.setup(trainer)
+
+ assert self.model is not None
+ self.model = self.precision_plugin.convert_module(self.model)
+ self.model = self._setup_model(self.model)
+
+ if trainer.state.fn == TrainerFn.FITTING:
+ self.setup_optimizers(trainer)
+ self.setup_precision_plugin()
+ if trainer.state.fn == TrainerFn.FITTING:
+ _optimizers_to_device(self.optimizers, self.root_device)
+
+ self.init_deepspeed()
+ self.barrier()
+
+ def _init_deepspeed_distributed(self) -> None:
+ import deepspeed
+
+ assert self.cluster_environment is not None
+ if platform.system() != "Windows":
+ # do not set env variables on windows, allow deepspeed to control setup
+ self._set_node_environment_variables()
+ log.info(
+ "initializing deepspeed distributed: "
+ f"GLOBAL_RANK: {self.global_rank}, "
+ f"MEMBER: {self.global_rank + 1}/{self.world_size}"
+ )
+ self._process_group_backend = self._get_process_group_backend()
+ deepspeed.init_distributed(
+ self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+ )
+
+ def _set_node_environment_variables(self) -> None:
+ assert self.cluster_environment is not None
+ os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
+ os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
+ os.environ["RANK"] = str(self.global_rank)
+ os.environ["WORLD_SIZE"] = str(self.world_size)
+ os.environ["LOCAL_RANK"] = str(self.local_rank)
@property
@override
def restore_checkpoint_after_setup(self) -> bool:
- return self.deepspeed_strategy_impl.restore_checkpoint_after_setup
+ return True
@override
def _setup_model_and_optimizers(
@@ -325,28 +416,188 @@ def _setup_model_and_optimizers(
deepspeed optimizer.
"""
- return self.deepspeed_strategy_impl._setup_model_and_optimizers(model=model, optimizers=optimizers)
+ if len(optimizers) != 1:
+ raise ValueError(
+ f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead."
+ )
+
+ # train_micro_batch_size_per_gpu is used for throughput logging purposes
+ # normally we set this to the batch size, but it is not available here unless the user provides it
+ # as part of the config
+ assert self.config is not None
+ self.config.setdefault("train_micro_batch_size_per_gpu", 1)
+ self.model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
+ self._set_deepspeed_activation_checkpointing()
+ return self.model, [optimizer]
+
+ def _setup_model_and_optimizer(
+ self,
+ model: Module,
+ optimizer: Optional[Optimizer],
+ lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None,
+ ) -> tuple["deepspeed.DeepSpeedEngine", Optimizer]:
+ """Initialize one model and one optimizer with an optional learning rate scheduler.
+
+ This calls ``deepspeed.initialize`` internally.
+
+ """
+ import deepspeed
+
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+ deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
+ args=argparse.Namespace(device_rank=self.root_device.index),
+ config=self.config,
+ model=model,
+ model_parameters=model_parameters,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ dist_init_required=False,
+ )
+ return deepspeed_engine, deepspeed_optimizer
+
+ def init_deepspeed(self) -> None:
+ assert self.lightning_module is not None
+ # deepspeed handles gradient clipping internally
+ if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
+ rank_zero_warn(
+ "Since DeepSpeed handles gradient clipping internally, the default"
+ " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients."
+ " The hook will still be called. Consider setting"
+ " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
+ " which will use the internal mechanism."
+ )
+
+ if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
+ raise MisconfigurationException("DeepSpeed does not support clipping gradients by value.")
+
+ assert isinstance(self.model, pl.LightningModule)
+ if self.lightning_module.trainer and self.lightning_module.trainer.training:
+ self._initialize_deepspeed_train(self.model)
+ else:
+ self._initialize_deepspeed_inference(self.model)
+
+ def _init_optimizers(self) -> tuple[Optimizer, Optional[LRSchedulerConfig]]:
+ assert self.lightning_module is not None
+ optimizers, lr_schedulers = _init_optimizers_and_lr_schedulers(self.lightning_module)
+ if len(optimizers) > 1 or len(lr_schedulers) > 1:
+ raise MisconfigurationException(
+ "DeepSpeed currently only supports single optimizer, single optional scheduler."
+ )
+ return optimizers[0], lr_schedulers[0] if lr_schedulers else None
@property
def zero_stage_3(self) -> bool:
- return self.deepspeed_strategy_impl.zero_stage_3
+ assert isinstance(self.config, dict)
+ zero_optimization = self.config.get("zero_optimization")
+ return zero_optimization is not None and zero_optimization.get("stage") == 3
+
+ def _initialize_deepspeed_train(self, model: Module) -> None:
+ optimizer, scheduler = None, None
+ assert isinstance(self.config, dict)
+ if "optimizer" in self.config:
+ rank_zero_info(
+ "You have specified an optimizer and/or scheduler within the DeepSpeed config."
+ " It is recommended to define it in `LightningModule.configure_optimizers`."
+ )
+ lr_scheduler = None
+ else:
+ (
+ optimizer,
+ lr_scheduler,
+ ) = self._init_optimizers()
+ if lr_scheduler is not None:
+ scheduler = lr_scheduler.scheduler
+
+ model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
+ self._set_deepspeed_activation_checkpointing()
+
+ # although we set these here, deepspeed manages the specific optimizer logic
+ self.optimizers = [deepspeed_optimizer]
+
+ deepspeed_scheduler = model.lr_scheduler
+ if deepspeed_scheduler is not None:
+ # disable deepspeed lr scheduling as lightning manages scheduling
+ model.lr_scheduler = None
+ if lr_scheduler is None:
+ lr_scheduler = LRSchedulerConfig(deepspeed_scheduler, interval="step")
+ else:
+ lr_scheduler.scheduler = deepspeed_scheduler
+ self.lr_scheduler_configs = [lr_scheduler]
+ self.model = model
@contextmanager
@override
def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
- with self.deepspeed_strategy_impl.tensor_init_context(empty_init=empty_init):
+ if self.zero_stage_3:
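+ # ZeRO-3 materializes sharded parameters via `deepspeed.zero.Init`,
+ # so forcing `empty_init=False` cannot be honored here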
+ if empty_init is False:
+ raise NotImplementedError(
+ f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled."
+ )
+ yield
+ return
+ with super().tensor_init_context(empty_init=empty_init):
yield
@contextmanager
@override
def model_sharded_context(self) -> Generator[None, None, None]:
- with self.deepspeed_strategy_impl.model_sharded_context():
+ import deepspeed
+
+ self._init_config_if_needed()
+ with deepspeed.zero.Init(
+ enabled=self.zero_stage_3,
+ remote_device=self.remote_device,
+ config_dict_or_path=self.config,
+ ):
yield
+ def _set_deepspeed_activation_checkpointing(self) -> None:
+ import deepspeed
+
+ assert isinstance(self.config, dict)
+ if self.config.get("activation_checkpointing"):
+ checkpoint_config = self.config["activation_checkpointing"]
+ deepspeed.checkpointing.configure(
+ mpu_=None,
+ partition_activations=checkpoint_config.get("partition_activations"),
+ contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
+ checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
+ profile=checkpoint_config.get("profile"),
+ )
+
+ def _initialize_deepspeed_inference(self, model: Module) -> None:
+ import deepspeed
+
+ assert isinstance(self.config, dict)
+
+ # todo: this is required for DeepSpeed throughput timers
+ inference_config = {"train_micro_batch_size_per_gpu": 1}
+ if "fp16" in self.config:
+ inference_config.update({"fp16": self.config["fp16"]})
+ if "bf16" in self.config:
+ inference_config.update({"bf16": self.config["bf16"]})
+ if self.zero_stage_3:
+ inference_config.update({
+ "zero_allow_untested_optimizer": self.config["zero_allow_untested_optimizer"],
+ "zero_optimization": self.config["zero_optimization"],
+ })
+ # Remove all module hooks before initializing new model
+ remove_module_hooks(model)
+ model, _, _, _ = deepspeed.initialize(
+ args=argparse.Namespace(device_rank=self.root_device.index),
+ config=inference_config,
+ model=model,
+ optimizer=None,
+ lr_scheduler=None,
+ model_parameters=[],
+ dist_init_required=False,
+ )
+ self.model = model
+
@property
@override
def distributed_sampler_kwargs(self) -> dict[str, int]:
- return self.deepspeed_strategy_impl.distributed_sampler_kwargs
+ return {"num_replicas": self.world_size, "rank": self.global_rank}
@override
def setup_optimizers(self, trainer: "pl.Trainer") -> None:
@@ -356,21 +607,29 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
trainer: the Trainer, these optimizers should be connected to
"""
- return self.deepspeed_strategy_impl.setup_optimizers(trainer=trainer)
+ # Skip initializing optimizers here: DeepSpeed creates and manages the optimizer via its config.
+ # The user may instead have defined an optimizer in `configure_optimizers`; that case is handled
+ # in `_initialize_deepspeed_train`.
+ # Empty out the optimizers and schedulers for now.
+ self.optimizers = []
+ self.lr_scheduler_configs = []
+
+ def _setup_model(self, model: Module) -> Module: # type: ignore[override]
+ return model
@property
@override
def handles_gradient_accumulation(self) -> bool:
"""Whether the strategy handles gradient accumulation internally."""
- return self.deepspeed_strategy_impl.handles_gradient_accumulation
+ return True
@property
def deepspeed_engine(self) -> "deepspeed.DeepSpeedEngine":
- return self.deepspeed_strategy_impl.deepspeed_engine
+ return self.model
@property
def _multi_device(self) -> bool:
- return self.deepspeed_strategy_impl._multi_device
+ return self.num_processes > 1 or self.num_nodes > 1
@override
def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None:
@@ -386,27 +645,135 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op
If ``storage_options`` arg is passed in
"""
- return self.deepspeed_strategy_impl.save_checkpoint(
- checkpoint=checkpoint, filepath=filepath, storage_options=storage_options
+ # broadcast the filepath from rank 0 to ensure all the states are saved in a common filepath
+ filepath = self.broadcast(filepath)
+
+ if storage_options is not None:
+ raise TypeError(
+ "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
+ f" is not supported for `{self.__class__.__name__}` as `CheckpointIO` is not used."
+ )
+
+ if self.zero_stage_3 and self._multi_device and self.is_global_zero:
+ warning_cache.warn(
+ "When saving the DeepSpeed Stage 3 checkpoint, "
+ "each worker will save a shard of the checkpoint within a directory. "
+ "If a single file is required after training, "
+ "see https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#"
+ "deepspeed-zero-stage-3-single-file for instructions."
+ )
+ # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
+ # dump states as a checkpoint dictionary object
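+ # model and optimizer states are saved by the deepspeed engine itself,
+ # so drop Lightning's copies from the client state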
+ _exclude_keys = ["state_dict", "optimizer_states"]
+ checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
+ self.deepspeed_engine.save_checkpoint(
+ filepath,
+ client_state=checkpoint,
+ tag="checkpoint",
+ exclude_frozen_parameters=self.exclude_frozen_parameters,
)
@override
def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
- return self.deepspeed_strategy_impl.load_checkpoint(checkpoint_path=checkpoint_path, weights_only=weights_only)
+ if self.load_full_weights and self.zero_stage_3:
+ # Broadcast to ensure we load from the rank 0 checkpoint
+ # This doesn't have to be the case when using deepspeed sharded checkpointing
+ checkpoint_path = self.broadcast(checkpoint_path)
+ return super().load_checkpoint(checkpoint_path, weights_only)
+
+ _validate_checkpoint_directory(checkpoint_path)
+
+ # Rely on deepspeed to load the checkpoint and necessary information
+ assert self.lightning_module is not None
+
+ from lightning.pytorch.trainer.states import TrainerFn
+
+ is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
+
+ _, client_state = self.deepspeed_engine.load_checkpoint(
+ checkpoint_path,
+ load_optimizer_states=is_fitting,
+ load_lr_scheduler_states=False,
+ load_module_strict=self.lightning_module.strict_loading,
+ )
+ if client_state is None:
+ raise MisconfigurationException(
+ "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint "
+ "or a single checkpoint file with `Trainer(strategy=DeepSpeedStrategy(load_full_weights=True))`."
+ )
+ return client_state
@property
@override
def lightning_restore_optimizer(self) -> bool:
- return self.deepspeed_strategy_impl.lightning_restore_optimizer
+ assert self.lightning_module is not None
+ # managed by DeepSpeed
+ if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+ rank_zero_warn(
+ "A single checkpoint file has been given. This means optimizer states cannot be restored."
+ " If you'd like to restore these states, you must provide a path to the originally saved DeepSpeed"
+ " checkpoint. When using ZeRO 3, the original path should be a directory."
+ )
+ return False
@override
def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
- return self.deepspeed_strategy_impl.load_model_state_dict(checkpoint=checkpoint, strict=strict)
+ if self.load_full_weights and self.zero_stage_3:
+ self.model_to_device()
+ self._restore_zero_state(checkpoint, strict=strict)
+
+ def _restore_zero_state(self, ckpt: Mapping[str, Any], strict: bool) -> None:
+ """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be sharded
+ across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced
+ across processes.
+
+ Args:
+ ckpt: The loaded checkpoint dictionary.
+ strict: Whether to strictly enforce that the keys in the state dict match the module's own keys.
+
+ """
+ import deepspeed
+
+ assert self.lightning_module is not None
+
+ def load(module: torch.nn.Module, prefix: str = "") -> None:
+ missing_keys: list[str] = []
+ unexpected_keys: list[str] = []
+ error_msgs: list[str] = []
+ state_dict = ckpt["state_dict"]
+
+ # copy state_dict so _load_from_state_dict can modify it
+ metadata = getattr(state_dict, "_metadata", None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ # because zero3 puts placeholders in model params, this context
+ # manager gathers (unpartitions) the params of the current layer, then loads from
+ # the state dict and then re-partitions them again
+ with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
+ if self.is_global_zero:
+ module._load_from_state_dict(
+ state_dict=state_dict,
+ prefix=prefix,
+ local_metadata=local_metadata,
+ strict=strict,
+ missing_keys=missing_keys,
+ unexpected_keys=unexpected_keys,
+ error_msgs=error_msgs,
+ )
+
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ load(self.lightning_module, prefix="")
@override
def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
- return self.deepspeed_strategy_impl.load_optimizer_state_dict(checkpoint=checkpoint)
+ # Override to do nothing, the deepspeed engine already loaded the states in `load_checkpoint()`
+ pass
@classmethod
@override
@@ -442,10 +809,137 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
offload_optimizer_device="nvme",
)
- @property
- def config(self) -> dict[str, Any]:
- return self.deepspeed_strategy_impl.config
+ def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]:
+ if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
+ rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
+ config = os.environ[self.DEEPSPEED_ENV_VAR]
+ if isinstance(config, (str, Path)):
+ if not os.path.isfile(config):
+ raise MisconfigurationException(
+ f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
+ )
+ with open(config) as f:
+ config = json.load(f)
+ assert isinstance(config, dict) or config is None
+ return config
+
+ def _init_config_if_needed(self) -> None:
+ if not self._config_initialized:
+ self._format_config()
+ self._config_initialized = True
+
+ def _format_config(self) -> None:
+ if self.config is None:
+ raise MisconfigurationException(
+ "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
+ " See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed"
+ )
+ self._format_batch_size_and_grad_accum_config()
+ _format_precision_config(
+ config=self.config,
+ precision=self.precision_plugin.precision,
+ loss_scale=self.loss_scale,
+ loss_scale_window=self.loss_scale_window,
+ min_loss_scale=self.min_loss_scale,
+ initial_scale_power=self.initial_scale_power,
+ hysteresis=self.hysteresis,
+ )
- @property
- def load_full_weights(self) -> bool:
- return self.deepspeed_strategy_impl.load_full_weights
+ def _create_default_config(
+ self,
+ zero_optimization: bool,
+ zero_allow_untested_optimizer: bool,
+ logging_batch_size_per_gpu: Union[str, int],
+ partition_activations: bool,
+ cpu_checkpointing: bool,
+ contiguous_memory_optimization: bool,
+ synchronize_checkpoint_boundary: bool,
+ offload_optimizer: bool,
+ offload_parameters: bool,
+ nvme_path: str,
+ offload_params_device: str,
+ params_buffer_count: int,
+ params_buffer_size: int,
+ max_in_cpu: int,
+ offload_optimizer_device: str,
+ optimizer_buffer_count: int,
+ pin_memory: bool,
+ block_size: int,
+ queue_depth: int,
+ single_submit: bool,
+ overlap_events: bool,
+ thread_count: int,
+ **zero_kwargs: Any,
+ ) -> dict:
+ cfg = {
+ "activation_checkpointing": {
+ "partition_activations": partition_activations,
+ "cpu_checkpointing": cpu_checkpointing,
+ "contiguous_memory_optimization": contiguous_memory_optimization,
+ "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary,
+ },
+ "aio": {
+ "block_size": block_size,
+ "queue_depth": queue_depth,
+ "single_submit": single_submit,
+ "overlap_events": overlap_events,
+ "thread_count": thread_count,
+ },
+ }
+ if zero_optimization:
+ zero_config = zero_kwargs
+
+ if offload_optimizer:
+ zero_config["offload_optimizer"] = {
+ "device": offload_optimizer_device,
+ "nvme_path": nvme_path,
+ "buffer_count": optimizer_buffer_count,
+ "pin_memory": pin_memory,
+ }
+ if offload_parameters:
+ zero_config["offload_param"] = {
+ "device": offload_params_device,
+ "nvme_path": nvme_path,
+ "buffer_count": params_buffer_count,
+ "buffer_size": params_buffer_size,
+ "max_in_cpu": max_in_cpu,
+ "pin_memory": pin_memory,
+ }
+ cfg = {
+ "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
+ "zero_optimization": zero_config,
+ **cfg,
+ }
+ if logging_batch_size_per_gpu != "auto":
+ cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
+ return cfg
+
+ def _format_batch_size_and_grad_accum_config(self) -> None:
+ # TODO: Using Fabric, we do not support these variables within the config
+ assert isinstance(self.config, dict)
+ if self.lightning_module is None:
+ return
+
+ if "gradient_accumulation_steps" in self.config:
+ raise MisconfigurationException(
+ "Do not set `gradient_accumulation_steps` in the DeepSpeed config"
+ " as this will be set with the `accumulate_grad_batches` argument passed via the Lightning Trainer."
+ )
+ self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
+ if "train_micro_batch_size_per_gpu" not in self.config:
+ batch_size = self._auto_select_batch_size()
+ self.config["train_micro_batch_size_per_gpu"] = batch_size
+ if "gradient_clipping" not in self.config:
+ self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0
+
+ def _auto_select_batch_size(self) -> int:
+ # train_micro_batch_size_per_gpu is used for throughput logging purposes
+ # by default we try to use the batch size of the loader
+ assert self.lightning_module is not None
+ batch_size = 1
+ data_source = self.lightning_module.trainer.fit_loop._data_source
+ if data_source.is_defined():
+ train_dataloader = data_source.dataloader()
+ if hasattr(train_dataloader, "batch_sampler"):
+ batch_size = train_dataloader.batch_sampler.batch_size
+ return batch_size
diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py
index 63842d5557d7e..066fecc79f208 100644
--- a/src/lightning/pytorch/strategies/launchers/xla.py
+++ b/src/lightning/pytorch/strategies/launchers/xla.py
@@ -11,18 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
import queue
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch.multiprocessing as mp
from typing_extensions import override
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
+from lightning.fabric.strategies.launchers.xla import _rank_teardown
+from lightning.fabric.utilities import move_data_to_device
from lightning.pytorch.strategies.launchers.multiprocessing import (
_GlobalStateSnapshot,
_MultiProcessingLauncher,
_WorkerOutput,
)
+from lightning.pytorch.trainer.states import TrainerFn
+from lightning.pytorch.utilities.rank_zero import rank_zero_debug
if TYPE_CHECKING:
import lightning.pytorch as pl
@@ -46,16 +51,14 @@ class _XLALauncher(_MultiProcessingLauncher):
"""
def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None:
- super().__init__(strategy)
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.strategies.xla.launcher import _XLALauncherTrainer as EnterpriseXLALauncher
-
- self.xla_launcher_impl = EnterpriseXLALauncher(strategy)
+ if not _XLA_AVAILABLE:
+ raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+ super().__init__(strategy=strategy, start_method="fork")
@property
@override
def is_interactive_compatible(self) -> bool:
- return self.xla_launcher_impl.is_interactive_compatible
+ return True
@override
def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
@@ -72,7 +75,46 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
**kwargs: Optional keyword arguments to be passed to the given function.
"""
- return self.xla_launcher_impl.launch(function, *args, trainer=trainer, **kwargs)
+ if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING:
+ # resolving https://github.com/Lightning-AI/pytorch-lightning/issues/18775 will lift this restriction
+ raise NotImplementedError(
+ "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
+ " supported. You can work around this by creating a new Trainer instance and passing the"
+ " `fit(ckpt_path=...)` argument."
+ )
+
+ # pjrt requires that the queue is serializable
+ return_queue = mp.Manager().Queue()
+
+ import torch_xla.distributed.xla_multiprocessing as xmp
+
+ spawn_kwargs = {}
+ nprocs = self._strategy.num_processes
+ if nprocs == 1:
+ # avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly.
+ # otherwise it will use all devices
+ spawn_kwargs["nprocs"] = nprocs
+
+ process_context = xmp.spawn(
+ self._wrapping_function,
+ args=(trainer, function, args, kwargs, return_queue),
+ start_method=self._start_method,
+ join=False, # we will join ourselves to get the process references
+ **spawn_kwargs,
+ )
+ # xla will not actually create processes if only 1 device
+ if process_context is not None:
+ self.procs = process_context.processes
+ while not process_context.join():
+ pass
+
+ worker_output = return_queue.get()
+ if trainer is None:
+ return worker_output
+
+ self._already_fit |= trainer.state.fn == TrainerFn.FITTING
+ self._recover_results_in_main_process(worker_output, trainer)
+ return worker_output.trainer_results
@override
def _wrapping_function(
@@ -87,10 +129,48 @@ def _wrapping_function(
return_queue: Union[mp.SimpleQueue, queue.Queue],
global_states: Optional[_GlobalStateSnapshot] = None,
) -> None:
- return self.xla_launcher_impl._wrapping_function(
- process_idx, trainer, function, args, kwargs, return_queue, global_states
- )
+ import torch_xla.core.xla_model as xm
+
+ if len(xm.get_xla_supported_devices()) > 1:
+ # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for v4)
+ # so when there's more than one (multithreading), objects need to be deep-copied
+ import copy
+
+ trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs))
+
+ results = function(*args, **kwargs)
+
+ if trainer is not None:
+ results = self._collect_rank_zero_results(trainer, results)
+
+ if self._strategy.local_rank == 0:
+ return_queue.put(move_data_to_device(results, "cpu"))
+
+ _rank_teardown(self._strategy.local_rank)
@override
def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
- return self.xla_launcher_impl._collect_rank_zero_results(trainer, results)
+ rank_zero_debug("Collecting results from rank 0 process.")
+ checkpoint_callback = trainer.checkpoint_callback
+ best_model_path = (
+ checkpoint_callback.best_model_path
+ if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
+ else None
+ )
+
+ # save the last weights
+ weights_path = None
+ if trainer.state.fn == TrainerFn.FITTING:
+ # the state_dict must be computed on all processes in case Metrics are present
+ state_dict = self._strategy.lightning_module_state_dict()
+ weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
+ self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)
+
+ # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
+ if self._strategy.local_rank != 0:
+ return None
+
+ # add extra result data from trainer to send to main process
+ extra = self.get_extra_results(trainer)
+
+ return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)
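For orientation, the hand-off object built here has roughly this shape; `_WorkerOutput` is Lightning's internal container, and the plain-Python stand-in below is only a sketch:

    from typing import Any, NamedTuple, Optional

    class WorkerOutputSketch(NamedTuple):
        best_model_path: Optional[str]
        weights_path: Optional[str]  # the temporary ".temp.ckpt" written during fitting
        trainer_state: Any
        trainer_results: Any
        extra: Any  # extra results collected from the trainer

    # rank 0 puts this on the managed queue; all other ranks return None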
diff --git a/src/lightning/pytorch/strategies/single_xla.py b/src/lightning/pytorch/strategies/single_xla.py
index 6e687ac815639..2a5e2f3a85b96 100644
--- a/src/lightning/pytorch/strategies/single_xla.py
+++ b/src/lightning/pytorch/strategies/single_xla.py
@@ -11,18 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
from typing import Optional, Union
+import torch
from typing_extensions import override
import lightning.pytorch as pl
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO
from lightning.fabric.strategies import _StrategyRegistry
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
from lightning.pytorch.plugins.precision.xla import XLAPrecision
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
+from lightning.pytorch.trainer.states import TrainerFn
+from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters
class SingleDeviceXLAStrategy(SingleDeviceStrategy):
@@ -36,18 +41,20 @@ def __init__(
precision_plugin: Optional[XLAPrecision] = None,
debug: bool = False,
):
+ if not _XLA_AVAILABLE:
+ raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+ if isinstance(device, torch.device):
+ # unwrap the `torch.device` in favor of `xla_device`
+ device = device.index
+ import torch_xla.core.xla_model as xm
+
super().__init__(
accelerator=accelerator,
- device=device,
+ device=xm.xla_device(device),
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.strategies.xla.single import (
- SingleDeviceXLAStrategyTrainer as EnterpriseSingleDeviceXLAStrategy,
- )
-
- self.single_xla_strategy_impl = EnterpriseSingleDeviceXLAStrategy(outer_object=self, device=device, debug=debug)
+ self.debug = debug
@property
@override
@@ -83,7 +90,26 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
@override
def setup(self, trainer: "pl.Trainer") -> None:
- return self.single_xla_strategy_impl.setup(trainer=trainer)
+ if self.debug:
+ os.environ["PT_XLA_DEBUG"] = str(1)
+
+ assert self.accelerator is not None
+ self.accelerator.setup(trainer)
+
+ assert self.model is not None
+ self.precision_plugin.convert_module(self.model)
+
+ shared_params = find_shared_parameters(self.model)
+ self.model_to_device()
+ set_shared_parameters(self.model, shared_params)
+
+ self.model = self._setup_model(self.model)
+
+ if trainer.state.fn == TrainerFn.FITTING:
+ self.setup_optimizers(trainer)
+ self.setup_precision_plugin()
+ if trainer.state.fn == TrainerFn.FITTING:
+ _optimizers_to_device(self.optimizers, self.root_device)
@classmethod
@override
@@ -92,4 +118,5 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
@override
def teardown(self) -> None:
- return self.single_xla_strategy_impl.teardown()
+ super().teardown()
+ os.environ.pop("PT_XLA_DEBUG", None)
diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py
index e57c627f374d7..cbdc890a1ca32 100644
--- a/src/lightning/pytorch/strategies/xla.py
+++ b/src/lightning/pytorch/strategies/xla.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import io
+import os
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
@@ -19,16 +21,20 @@
from typing_extensions import override
import lightning.pytorch as pl
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1
from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO
from lightning.fabric.plugins.environments import XLAEnvironment
from lightning.fabric.strategies import _StrategyRegistry
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.types import _PATH, ReduceOp
from lightning.pytorch.plugins import XLAPrecision
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.strategies.launchers.xla import _XLALauncher
from lightning.pytorch.strategies.strategy import TBroadcast
+from lightning.pytorch.trainer.states import TrainerFn
+from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters
+from lightning.pytorch.utilities.rank_zero import rank_zero_only
if TYPE_CHECKING:
from torch_xla.distributed.parallel_loader import MpDeviceLoader
@@ -50,6 +56,8 @@ def __init__(
sync_module_states: bool = True,
**_: Any,
) -> None:
+ if not _XLA_AVAILABLE:
+ raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
@@ -58,12 +66,9 @@ def __init__(
precision_plugin=precision_plugin,
start_method="fork",
)
- _raise_enterprise_not_available()
- from pytorch_lightning_enterprise.strategies.xla.ddp import XLAStrategyTrainer as EnterpriseXLAStrategy
-
- self.xla_strategy_impl = EnterpriseXLAStrategy(
- outer_object=self, debug=debug, sync_module_states=sync_module_states
- )
+ self.debug = debug
+ self._launched = False
+ self._sync_module_states = sync_module_states
@property
@override
@@ -100,27 +105,31 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
@property
@override
def root_device(self) -> torch.device:
- return self.xla_strategy_impl.root_device
+ if not self._launched:
+ raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
+ import torch_xla.core.xla_model as xm
+
+ return xm.xla_device()
@property
@override
def global_rank(self) -> int:
- return self.xla_strategy_impl.global_rank
+ return super().global_rank if self._launched else 0
@property
@override
def local_rank(self) -> int:
- return self.xla_strategy_impl.local_rank
+ return super().local_rank if self._launched else 0
@property
@override
def node_rank(self) -> int:
- return self.xla_strategy_impl.node_rank
+ return super().node_rank if self._launched else 0
@property
@override
def world_size(self) -> int:
- return self.xla_strategy_impl.world_size
+ return super().world_size if self._launched else 1
@override
def _configure_launcher(self) -> None:
@@ -128,36 +137,113 @@ def _configure_launcher(self) -> None:
@override
def setup(self, trainer: "pl.Trainer") -> None:
- return self.xla_strategy_impl.setup(trainer=trainer)
+ assert self.accelerator is not None
+ self.accelerator.setup(trainer)
+
+ if self.debug:
+ os.environ["PT_XLA_DEBUG"] = "1"
+
+ assert self.model is not None
+ self.precision_plugin.convert_module(self.model)
+
+ shared_params = find_shared_parameters(self.model)
+ self.model_to_device()
+ set_shared_parameters(self.model, shared_params)
+
+ self.model = self._setup_model(self.model)
+
+ if self._sync_module_states:
+ if _XLA_GREATER_EQUAL_2_1:
+ from torch_xla.core.xla_model import broadcast_master_param
+ else:
+ from torch_xla.experimental.pjrt import broadcast_master_param
+
+ broadcast_master_param(self.model)
+
+ if trainer.state.fn == TrainerFn.FITTING:
+ self.setup_optimizers(trainer)
+ self.setup_precision_plugin()
+ if trainer.state.fn == TrainerFn.FITTING:
+ _optimizers_to_device(self.optimizers, self.root_device)
@override
def _setup_model(self, model: Module) -> Module: # type: ignore
- return self.xla_strategy_impl._setup_model(model=model)
+ return model
@property
@override
def distributed_sampler_kwargs(self) -> dict[str, int]:
- return self.xla_strategy_impl.distributed_sampler_kwargs
+ return {"num_replicas": self.world_size, "rank": self.global_rank}
@override
def process_dataloader(self, dataloader: object) -> "MpDeviceLoader":
- return self.xla_strategy_impl.process_dataloader(dataloader=dataloader)
+ from torch_xla.distributed.parallel_loader import MpDeviceLoader
+
+ if isinstance(dataloader, MpDeviceLoader):
+ # dataloader is already wrapped by MpDeviceLoader
+ return dataloader
+
+ dataloader = MpDeviceLoader(dataloader, self.root_device)
+ # Mimic interface to torch.utils.data.DataLoader
+ dataloader.dataset = dataloader._loader.dataset
+ dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
+ return dataloader
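For reference, wrapping a plain dataloader for XLA looks like this, assuming torch_xla is installed and `xm.xla_device()` resolves to a real device:

    import torch_xla.core.xla_model as xm
    from torch.utils.data import DataLoader
    from torch_xla.distributed.parallel_loader import MpDeviceLoader

    loader = DataLoader(range(10), batch_size=2)
    device_loader = MpDeviceLoader(loader, xm.xla_device())
    for batch in device_loader:  # batches arrive already on the XLA device
        ...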
@override
def configure_ddp(self) -> None:
- return self.xla_strategy_impl.configure_ddp()
+ pass
@override
def model_to_device(self) -> None:
- return self.xla_strategy_impl.model_to_device()
+ assert self.model is not None
+ self.model = self.model.to(self.root_device)
@override
def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
- return self.xla_strategy_impl.barrier(name=name, *args, **kwargs)
+ if not self._launched:
+ return
+
+ import torch_xla.core.xla_model as xm
+
+ if name is None:
+ # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
+ name = ""
+ xm.rendezvous(name)
@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
- return self.xla_strategy_impl.broadcast(obj=obj, src=src)
+ if not self._launched:
+ return obj
+
+ import torch_xla.core.xla_model as xm
+
+ is_tensor = isinstance(obj, Tensor)
+ if is_tensor:
+ if obj.dim() == 0:
+ obj = obj.unsqueeze(0)
+ original_device = obj.device
+ # XLA distributed requires that the data is on the XLA device
+ obj = obj.to(self.root_device)
+ else:
+ # support for arbitrary pickle-ables
+ buffer = io.BytesIO()
+ torch.save(obj, buffer)
+ obj = torch.tensor( # type: ignore[assignment]
+ bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
+ )
+
+ obj = [obj]
+ xm.collective_broadcast(obj, root_ordinal=src)
+ obj = obj[0]
+
+ if not is_tensor:
+ # this will preserve the dtype and device of any tensors
+ buffer = io.BytesIO(obj.cpu().byte().numpy())
+ obj = torch.load(buffer)
+ else:
+ obj = obj.to(original_device)
+
+ return obj
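The non-tensor path above serializes arbitrary picklable objects into a float tensor so they can ride the collective. The round trip, shown standalone on CPU (no torch_xla needed):

    import io

    import torch

    obj = {"epoch": 3, "score": 0.91}
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    # encode the pickled bytes as a tensor that collectives can move
    payload = torch.tensor(bytearray(buffer.getbuffer()), dtype=torch.float)
    # ...after the broadcast, decode back into the original object...
    restored = torch.load(io.BytesIO(payload.cpu().byte().numpy().tobytes()))
    assert restored == obj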
@override
def reduce(
@@ -166,27 +252,60 @@ def reduce(
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = "mean",
) -> Tensor:
- return self.xla_strategy_impl.reduce(output=output, group=group, reduce_op=reduce_op)
+ if not isinstance(output, Tensor):
+ output = torch.tensor(output, device=self.root_device)
+
+ invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
+ invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
+ if invalid_reduce_op or invalid_reduce_op_str:
+ raise ValueError(
+ "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
+ f" {reduce_op}"
+ )
+
+ import torch_xla.core.xla_model as xm
+
+ output = xm.mesh_reduce("reduce", output, sum)
+
+ if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
+ output = output / self.world_size
+
+ return output
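Semantically, `mesh_reduce` gathers one value per replica and applies the reduction function; the mean variants then divide by world size. A CPU analogue of the logic above:

    import torch

    replica_values = [torch.tensor(2.0), torch.tensor(4.0)]  # one value per replica
    world_size = len(replica_values)
    summed = sum(replica_values)   # what mesh_reduce("reduce", output, sum) yields
    mean = summed / world_size     # applied when reduce_op is "mean" or "avg"
    assert mean.item() == 3.0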
@override
def setup_environment(self) -> None:
- return self.xla_strategy_impl.setup_environment()
+ self._launched = True
+ super().setup_environment()
@override
def setup_distributed(self) -> None:
- return self.xla_strategy_impl.setup_distributed()
+ assert self.parallel_devices is not None
+ if len(self.parallel_devices) == 1:
+ # spawning only 1 device with PjRT is not supported:
+ # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732
+ raise NotImplementedError(
+ "The `XLAStrategy` does not support running on a single device with the PjRT runtime."
+ " Try using all devices or the `SingleDeviceXLAStrategy` strategy"
+ )
+ rank_zero_only.rank = self.global_rank
@override
def set_world_ranks(self) -> None:
- return self.xla_strategy_impl.set_world_ranks()
+ # Accessing global_rank will initialize the XLA computation client. Since this is called outside the spawned
+ # processes (by the accelerator connector), we cannot run the code that would normally go here.
+ # Instead, it is done in `setup_distributed`.
+ pass
@override
def save_checkpoint(
self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
) -> None:
- return self.xla_strategy_impl.save_checkpoint(
- checkpoint=checkpoint, filepath=filepath, storage_options=storage_options
- )
+ import torch_xla.core.xla_model as xm
+
+ # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
+ xm.mark_step()
+ # save on global rank zero only
+ super().save_checkpoint(checkpoint, filepath, storage_options=storage_options)
@override
def remove_checkpoint(self, filepath: _PATH) -> None:
@@ -196,7 +315,8 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
filepath: Path to checkpoint
"""
- return self.xla_strategy_impl.remove_checkpoint(filepath=filepath)
+ if self.local_rank == 0:
+ self.checkpoint_io.remove_checkpoint(filepath)
@override
def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -210,11 +330,29 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
A tensor of shape (world_size, ...)
"""
- return self.xla_strategy_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads)
+ if not self._launched:
+ return tensor
+ if not isinstance(tensor, Tensor):
+ raise NotImplementedError(
+ f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
+ )
+ if tensor.dim() == 0:
+ tensor = tensor.unsqueeze(0)
+ original_device = tensor.device
+ tensor = tensor.to(self.root_device)
+
+ import torch_xla.core.functions as xf
+ import torch_xla.core.xla_model as xm
+
+ tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+ tensor = tensor.to(original_device)
+ return tensor
@override
def teardown(self) -> None:
- return self.xla_strategy_impl.teardown()
+ super().teardown()
+ self._launched = False # after the Trainer finishes, we aren't inside the spawned region
+ os.environ.pop("PT_XLA_DEBUG", None)
@classmethod
@override
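A usage sketch for the multi-device path, assuming torch_xla and more than one XLA device; "xla" is the registered strategy alias:

    import lightning.pytorch as pl

    trainer = pl.Trainer(accelerator="tpu", devices=8, strategy="xla")
    # trainer.fit(model) then spawns one process per device via the XLA launcher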
diff --git a/src/lightning/pytorch/utilities/deepspeed.py b/src/lightning/pytorch/utilities/deepspeed.py
index 21a9190552f43..20b418437c681 100644
--- a/src/lightning/pytorch/utilities/deepspeed.py
+++ b/src/lightning/pytorch/utilities/deepspeed.py
@@ -20,8 +20,8 @@
import torch
-from lightning.fabric.utilities.imports import _DEEPSPEED_AVAILABLE
from lightning.fabric.utilities.types import _PATH
+from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE
CPU_DEVICE = torch.device("cpu")
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index 0564632b315a7..9d4a0b9462f2e 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -19,13 +19,9 @@
from unittest.mock import Mock
import pytest
-import pytorch_lightning_enterprise.utils.imports
import torch.distributed
import lightning.fabric
-import lightning.fabric.plugins.environments.xla
-import lightning.fabric.plugins.io.xla
-import lightning.fabric.plugins.precision.xla
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver
from lightning.fabric.utilities.distributed import _destroy_dist_connection
@@ -73,8 +69,6 @@ def restore_env_variables():
# set by torchdynamo
"TRITON_CACHE_DIR",
"TORCHINDUCTOR_CACHE_DIR",
- "TQDM_MININTERVAL", # set by our platform
- "TQDM_POSITION", # set by our platform
}
leaked_vars.difference_update(allowlist)
assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
@@ -150,33 +144,17 @@ def reset_cudnn_benchmark():
def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
- # First, mock torch_xla modules in sys.modules so imports succeed
+ monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.fabric.plugins.precision.xla, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.fabric.strategies.single_xla, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.fabric.strategies.xla_fsdp, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)
monkeypatch.setitem(sys.modules, "torch_xla", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla._internal", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla._internal.tpu", Mock())
-
- # Then patch the _XLA_AVAILABLE flags in various modules
- monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_1", value)
- monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_5", value)
- monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", value)
- monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_5", value)
- # Patch in the modules where they're used after import
- monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_1", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_5", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_AVAILABLE", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.plugins.precision.xla._XLA_AVAILABLE", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.single._XLA_AVAILABLE", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.ddp._XLA_AVAILABLE", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.ddp._XLA_GREATER_EQUAL_2_1", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.fsdp._XLA_AVAILABLE", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.launcher._XLA_AVAILABLE", value)
@pytest.fixture
@@ -188,14 +166,6 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
mock_xla_available(monkeypatch, value)
monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
- # Also mock the enterprise XLAAccelerator methods
- import pytorch_lightning_enterprise.accelerators.xla
-
- monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
- monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
- monkeypatch.setitem(sys.modules, "torch_xla", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
@pytest.fixture
diff --git a/tests/tests_fabric/graveyard/test_tpu.py b/tests/tests_fabric/graveyard/test_tpu.py
index 089720a381cca..5b72d60491df7 100644
--- a/tests/tests_fabric/graveyard/test_tpu.py
+++ b/tests/tests_fabric/graveyard/test_tpu.py
@@ -11,11 +11,11 @@
("lightning.fabric.strategies.single_tpu", "SingleTPUStrategy"),
],
)
-def test_graveyard_single_tpu(import_path, name, tpu_available):
+def test_graveyard_single_tpu(import_path, name):
module = import_module(import_path)
cls = getattr(module, name)
device = torch.device("cpu")
- with pytest.deprecated_call(match="is deprecated"):
+ with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"):
cls(device)
@@ -34,12 +34,8 @@ def test_graveyard_single_tpu(import_path, name, tpu_available):
("lightning.fabric.plugins.precision.xlabf16", "XLABf16Precision"),
],
)
-def test_graveyard_no_device(import_path, name, tpu_available):
+def test_graveyard_no_device(import_path, name):
module = import_module(import_path)
cls = getattr(module, name)
- with pytest.deprecated_call(match="is deprecated"):
- _instance = cls()
-
- # required to prevent env-var leakage
- if hasattr(cls, "teardown"):
- cls.teardown(_instance)
+ with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"):
+ cls()
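The combined context managers assert that instantiation both emits the deprecation warning and then fails on the missing dependency. The same pattern standalone, with a hypothetical `legacy_thing`:

    import warnings

    import pytest

    def legacy_thing():
        warnings.warn("legacy_thing is deprecated", DeprecationWarning)
        raise ModuleNotFoundError("No module named 'torch_xla'")

    def test_legacy_thing():
        with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"):
            legacy_thing()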
diff --git a/tests/tests_fabric/plugins/environments/test_kubeflow.py b/tests/tests_fabric/plugins/environments/test_kubeflow.py
index 72fb527d3facb..3436adc9ce2aa 100644
--- a/tests/tests_fabric/plugins/environments/test_kubeflow.py
+++ b/tests/tests_fabric/plugins/environments/test_kubeflow.py
@@ -61,14 +61,14 @@ def test_attributes_from_environment_variables(caplog):
assert env.local_rank() == 0
assert env.node_rank() == 1
# setter should be no-op
- with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.kubeflow"):
+ with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
env.set_global_rank(100)
assert env.global_rank() == 1
assert "setting global rank is not allowed" in caplog.text
caplog.clear()
- with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.kubeflow"):
+ with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
env.set_world_size(100)
assert env.world_size() == 20
assert "setting world size is not allowed" in caplog.text
diff --git a/tests/tests_fabric/plugins/environments/test_slurm.py b/tests/tests_fabric/plugins/environments/test_slurm.py
index 742fc44855083..b907c287faa5f 100644
--- a/tests/tests_fabric/plugins/environments/test_slurm.py
+++ b/tests/tests_fabric/plugins/environments/test_slurm.py
@@ -21,6 +21,7 @@
from lightning_utilities.test.warning import no_warning_call
from lightning.fabric.plugins.environments import SLURMEnvironment
+from lightning.fabric.utilities.warnings import PossibleUserWarning
from tests_fabric.helpers.runif import RunIf
@@ -71,14 +72,14 @@ def test_attributes_from_environment_variables(caplog):
assert env.node_rank() == 3
assert env.job_name() == "JOB"
# setter should be no-op
- with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.slurm"):
+ with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
env.set_global_rank(100)
assert env.global_rank() == 1
assert "setting global rank is not allowed" in caplog.text
caplog.clear()
- with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.slurm"):
+ with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
env.set_world_size(100)
assert env.world_size() == 20
assert "setting world size is not allowed" in caplog.text
@@ -134,18 +135,18 @@ def test_detect():
def test_srun_available_and_not_used(monkeypatch):
"""Test that a warning is emitted if Lightning suspects the user forgot to run their script with `srun`."""
monkeypatch.setattr(sys, "argv", ["train.py", "--lr", "0.01"])
- expected = r"`srun` .* available .* but is not used. HINT: .* srun python\d* train.py --lr 0.01"
+ expected = "`srun` .* available .* but is not used. HINT: .* srun python train.py --lr 0.01"
# pretend `srun` is available
with mock.patch("lightning.fabric.plugins.environments.slurm.shutil.which", return_value="/usr/bin/srun"):
- with pytest.warns(UserWarning, match=expected):
+ with pytest.warns(PossibleUserWarning, match=expected):
SLURMEnvironment()
- with pytest.warns(UserWarning, match=expected):
+ with pytest.warns(PossibleUserWarning, match=expected):
SLURMEnvironment.detect()
# no warning if `srun` is unavailable
- with no_warning_call(UserWarning, match=expected):
+ with no_warning_call(PossibleUserWarning, match=expected):
SLURMEnvironment()
assert not SLURMEnvironment.detect()
@@ -175,7 +176,7 @@ def test_validate_user_settings():
# in interactive mode, validation is skipped because processes get launched by Fabric/Trainer, not SLURM
with mock.patch(
- "pytorch_lightning_enterprise.plugins.environments.slurm.SLURMEnvironment.job_name", return_value="interactive"
+ "lightning.fabric.plugins.environments.slurm.SLURMEnvironment.job_name", return_value="interactive"
):
env = SLURMEnvironment()
env.validate_settings(num_devices=4, num_nodes=1) # no error
diff --git a/tests/tests_fabric/plugins/environments/test_torchelastic.py b/tests/tests_fabric/plugins/environments/test_torchelastic.py
index cfd196c231036..161d42894df30 100644
--- a/tests/tests_fabric/plugins/environments/test_torchelastic.py
+++ b/tests/tests_fabric/plugins/environments/test_torchelastic.py
@@ -58,14 +58,14 @@ def test_attributes_from_environment_variables(caplog):
assert env.local_rank() == 2
assert env.node_rank() == 3
# setter should be no-op
- with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.torchelastic"):
+ with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
env.set_global_rank(100)
assert env.global_rank() == 1
assert "setting global rank is not allowed" in caplog.text
caplog.clear()
- with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.torchelastic"):
+ with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
env.set_world_size(100)
assert env.world_size() == 20
assert "setting world size is not allowed" in caplog.text
diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py
index 0557f5620e8e1..76fd18012ee94 100644
--- a/tests/tests_fabric/plugins/environments/test_xla.py
+++ b/tests/tests_fabric/plugins/environments/test_xla.py
@@ -100,9 +100,8 @@ def test_detect(monkeypatch):
@mock.patch.dict(os.environ, {}, clear=True)
-@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True)
-@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True)
@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
+@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
def test_world_size_from_xla_runtime_greater_2_1(xla_available):
"""Test that world_size uses torch_xla.runtime when XLA >= 2.1."""
env = XLAEnvironment()
@@ -114,9 +113,8 @@ def test_world_size_from_xla_runtime_greater_2_1(xla_available):
@mock.patch.dict(os.environ, {}, clear=True)
-@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True)
-@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True)
@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
+@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
def test_global_rank_from_xla_runtime_greater_2_1(xla_available):
"""Test that global_rank uses torch_xla.runtime when XLA >= 2.1."""
env = XLAEnvironment()
@@ -128,9 +126,8 @@ def test_global_rank_from_xla_runtime_greater_2_1(xla_available):
@mock.patch.dict(os.environ, {}, clear=True)
-@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True)
-@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True)
@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
+@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
def test_local_rank_from_xla_runtime_greater_2_1(xla_available):
"""Test that local_rank uses torch_xla.runtime when XLA >= 2.1."""
env = XLAEnvironment()
@@ -142,9 +139,8 @@ def test_local_rank_from_xla_runtime_greater_2_1(xla_available):
@mock.patch.dict(os.environ, {}, clear=True)
-@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True)
-@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True)
@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
+@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
def test_setters_readonly_when_xla_runtime_greater_2_1(xla_available):
"""Test that set_world_size and set_global_rank don't affect values when using XLA runtime >= 2.1."""
env = XLAEnvironment()
diff --git a/tests/tests_fabric/plugins/precision/test_transformer_engine.py b/tests/tests_fabric/plugins/precision/test_transformer_engine.py
index 22908e107131b..ed7c984b1ae64 100644
--- a/tests/tests_fabric/plugins/precision/test_transformer_engine.py
+++ b/tests/tests_fabric/plugins/precision/test_transformer_engine.py
@@ -15,22 +15,19 @@
from unittest.mock import Mock
import pytest
-import pytorch_lightning_enterprise.utils.imports
import torch
import torch.distributed
+import lightning.fabric
from lightning.fabric.connector import _Connector
from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision
def test_transformer_engine_plugin(monkeypatch):
- module = pytorch_lightning_enterprise.utils.imports
+ module = lightning.fabric.plugins.precision.transformer_engine
if module._TRANSFORMER_ENGINE_AVAILABLE:
pytest.skip("Assumes transformer_engine is unavailable")
- monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", True)
- monkeypatch.setattr(
- "pytorch_lightning_enterprise.plugins.precision.transformer_engine._TRANSFORMER_ENGINE_AVAILABLE", True
- )
+ monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
transformer_engine_mock = Mock()
monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock)
monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock())
@@ -121,11 +118,8 @@ class TELayerNormMock(Mock): ...
def test_convert_module_handles_linear_without_bias(monkeypatch):
- module = pytorch_lightning_enterprise.utils.imports # Set up mock transformer_engine
- monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", True)
- monkeypatch.setattr(
- "pytorch_lightning_enterprise.plugins.precision.transformer_engine._TRANSFORMER_ENGINE_AVAILABLE", True
- )
+ module = lightning.fabric.plugins.precision.transformer_engine # Set up mock transformer_engine
+ monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
transformer_engine_mock = Mock()
monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock)
diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py
index 022e1320068ff..0194c7b87820a 100644
--- a/tests/tests_fabric/strategies/test_deepspeed.py
+++ b/tests/tests_fabric/strategies/test_deepspeed.py
@@ -75,7 +75,7 @@ def test_deepspeed_defaults():
strategy = DeepSpeedStrategy()
assert strategy.config is not None
assert isinstance(strategy.config["zero_optimization"], dict)
- assert strategy.deepspeed_impl._backward_sync_control is None
+ assert strategy._backward_sync_control is None
@RunIf(deepspeed=True)
@@ -230,7 +230,7 @@ def test_deepspeed_save_checkpoint_exclude_frozen_parameters(exclude_frozen_para
from deepspeed import DeepSpeedEngine
strategy = DeepSpeedStrategy(exclude_frozen_parameters=exclude_frozen_parameters)
- assert strategy.deepspeed_impl.exclude_frozen_parameters is exclude_frozen_parameters
+ assert strategy.exclude_frozen_parameters is exclude_frozen_parameters
model = Mock(spec=DeepSpeedEngine, optimizer=None)
model.modules.return_value = [model]
@@ -275,7 +275,7 @@ def test_deepspeed_load_checkpoint_no_state(tmp_path):
@RunIf(deepspeed=True)
-@mock.patch("pytorch_lightning_enterprise.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
+@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(_, tmp_path):
"""Test that the DeepSpeed strategy can only load one DeepSpeedEngine per checkpoint."""
from deepspeed import DeepSpeedEngine
@@ -317,7 +317,7 @@ def test_deepspeed_load_checkpoint_client_state_missing(tmp_path):
@RunIf(deepspeed=True)
-@mock.patch("pytorch_lightning_enterprise.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
+@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
def test_deepspeed_load_checkpoint_state_updated_with_client_state(_, tmp_path):
"""Test that the DeepSpeed strategy properly updates the state variables and returns additional metadata."""
from deepspeed import DeepSpeedEngine
@@ -342,7 +342,7 @@ def test_deepspeed_load_checkpoint_state_updated_with_client_state(_, tmp_path):
@RunIf(deepspeed=True)
@pytest.mark.parametrize("optimzer_state_requested", [True, False])
-@mock.patch("pytorch_lightning_enterprise.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
+@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
def test_deepspeed_load_checkpoint_optimzer_state_requested(_, optimzer_state_requested, tmp_path):
"""Test that the DeepSpeed strategy loads the optimizer state only when requested."""
from deepspeed import DeepSpeedEngine
@@ -427,7 +427,6 @@ def test_validate_parallel_devices_indices(device_indices):
"""
accelerator = Mock(spec=CUDAAccelerator)
- accelerator.name.return_value = "cuda"
strategy = DeepSpeedStrategy(
accelerator=accelerator, parallel_devices=[torch.device("cuda", i) for i in device_indices]
)
diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py
index 8f2e0034f5dda..2a04c6b27dded 100644
--- a/tests/tests_fabric/strategies/test_deepspeed_integration.py
+++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py
@@ -164,7 +164,7 @@ def test_deepspeed_custom_precision_params():
devices=1,
)
fabric.launch()
- assert fabric._strategy.deepspeed_impl._config_initialized
+ assert fabric._strategy._config_initialized
assert fabric._strategy.config["fp16"]["loss_scale"] == 10
assert fabric._strategy.config["fp16"]["initial_scale_power"] == 11
assert fabric._strategy.config["fp16"]["loss_scale_window"] == 12
@@ -251,7 +251,7 @@ def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform):
strategy = fabric._strategy
assert isinstance(strategy, DeepSpeedStrategy)
with mock.patch("platform.system", return_value=platform) as platform_mock:
- strategy.deepspeed_impl._init_deepspeed_distributed()
+ strategy._init_deepspeed_distributed()
deepspeed_dist_mock.assert_called()
platform_mock.assert_called()
if platform == "Windows":
diff --git a/tests/tests_fabric/strategies/test_xla.py b/tests/tests_fabric/strategies/test_xla.py
index 1884208dc2126..9ca21d8d8b894 100644
--- a/tests/tests_fabric/strategies/test_xla.py
+++ b/tests/tests_fabric/strategies/test_xla.py
@@ -194,14 +194,14 @@ def test_rank_properties_access(xla_available):
strategy.cluster_environment = Mock()
# we're in the main process, no processes have been launched yet
- assert not strategy.xla_strategy_impl._launched
+ assert not strategy._launched
assert strategy.global_rank == 0
assert strategy.local_rank == 0
assert strategy.node_rank == 0
assert strategy.world_size == 1
# simulate we're in a worker process
- strategy.xla_strategy_impl._launched = True
+ strategy._launched = True
assert strategy.global_rank == strategy.cluster_environment.global_rank()
assert strategy.local_rank == strategy.cluster_environment.local_rank()
assert strategy.node_rank == strategy.cluster_environment.node_rank()
diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py
index fa1021402590e..c2634283ad110 100644
--- a/tests/tests_fabric/strategies/test_xla_fsdp.py
+++ b/tests/tests_fabric/strategies/test_xla_fsdp.py
@@ -23,6 +23,7 @@
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.plugins import XLAPrecision
from lightning.fabric.strategies import XLAFSDPStrategy
+from lightning.fabric.strategies.xla_fsdp import _activation_checkpointing_auto_wrapper, _XLAFSDPBackwardSyncControl
from tests_fabric.helpers.runif import RunIf
@@ -42,7 +43,6 @@ def test_xla_fsdp_setup_optimizer_validation():
def test_xla_fsdp_no_backward_sync():
"""Test that the backward sync control calls `.no_sync()`, and only on a module wrapped in
XlaFullyShardedDataParallel."""
- from pytorch_lightning_enterprise.strategies.xla.fsdp import _XLAFSDPBackwardSyncControl
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel
strategy = XLAFSDPStrategy()
@@ -75,45 +75,41 @@ def test_xla_fsdp_grad_clipping_value_error():
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
-@mock.patch("pytorch_lightning_enterprise.strategies.xla.fsdp._XLA_AVAILABLE", True)
def test_rank_properties_access(xla_available):
"""Test that the strategy returns the expected values depending on whether we're in the main process or not."""
strategy = XLAFSDPStrategy()
strategy.cluster_environment = Mock()
# we're in the main process, no processes have been launched yet
- assert not strategy.xla_fsdp_impl._launched
+ assert not strategy._launched
assert strategy.global_rank == 0
assert strategy.local_rank == 0
assert strategy.node_rank == 0
assert strategy.world_size == 1
# simulate we're in a worker process
- strategy.xla_fsdp_impl._launched = True
+ strategy._launched = True
assert strategy.global_rank == strategy.cluster_environment.global_rank()
assert strategy.local_rank == strategy.cluster_environment.local_rank()
assert strategy.node_rank == strategy.cluster_environment.node_rank()
assert strategy.world_size == strategy.cluster_environment.world_size()
-@mock.patch("pytorch_lightning_enterprise.strategies.xla.fsdp._XLA_AVAILABLE", True)
def test_xla_fsdp_policy(xla_available):
- from pytorch_lightning_enterprise.strategies.xla import fsdp as fsdp_module
-
strategy = XLAFSDPStrategy(foo=1)
- assert strategy.xla_fsdp_impl._fsdp_kwargs == {"foo": 1}
+ assert strategy._fsdp_kwargs == {"foo": 1}
strategy = XLAFSDPStrategy(auto_wrap_policy={torch.nn.Linear})
- kwargs = strategy.xla_fsdp_impl._parse_fsdp_kwargs()
+ kwargs = strategy._parse_fsdp_kwargs()
assert set(kwargs) == {"auto_wrap_policy", "compute_dtype"}
assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"
assert kwargs["compute_dtype"] is torch.float32
strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
- _ = strategy.xla_fsdp_impl._parse_fsdp_kwargs()
- kwargs = strategy.xla_fsdp_impl._parse_fsdp_kwargs() # ensure it's idempotent
+ _ = strategy._parse_fsdp_kwargs()
+ kwargs = strategy._parse_fsdp_kwargs() # ensure it's idempotent
assert set(kwargs) == {"auto_wrapper_callable", "compute_dtype"}
- assert kwargs["auto_wrapper_callable"].func is fsdp_module._activation_checkpointing_auto_wrapper
+ assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper
assert kwargs["compute_dtype"] is torch.float32
strategy = XLAFSDPStrategy(
@@ -122,17 +118,17 @@ def test_xla_fsdp_policy(xla_available):
activation_checkpointing_policy={torch.nn.Linear},
precision=XLAPrecision("bf16-true"),
)
- kwargs = strategy.xla_fsdp_impl._parse_fsdp_kwargs()
+ kwargs = strategy._parse_fsdp_kwargs()
assert set(kwargs) == {"auto_wrap_policy", "auto_wrapper_callable", "compute_dtype"}
assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"
- assert kwargs["auto_wrapper_callable"].func is fsdp_module._activation_checkpointing_auto_wrapper
+ assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper
assert kwargs["compute_dtype"] is torch.bfloat16
strategy.teardown()
strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}, auto_wrapper_callable="foo")
with pytest.raises(ValueError, match="cannot set both"):
- strategy.xla_fsdp_impl._parse_fsdp_kwargs()
+ strategy._parse_fsdp_kwargs()
strategy = XLAFSDPStrategy(activation_checkpointing_policy="foo")
with pytest.raises(TypeError, match="must be a set"):
- strategy.xla_fsdp_impl._parse_fsdp_kwargs()
+ strategy._parse_fsdp_kwargs()
diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py
index de9f7b2d16395..1074789e71055 100644
--- a/tests/tests_fabric/test_connector.py
+++ b/tests/tests_fabric/test_connector.py
@@ -242,10 +242,8 @@ class TestStrategy(DDPStrategy):
),
],
)
-@mock.patch(
- "pytorch_lightning_enterprise.plugins.environments.lsf.LSFEnvironment._read_hosts", return_value=["node0", "node1"]
-)
-@mock.patch("pytorch_lightning_enterprise.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0)
+@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._read_hosts", return_value=["node0", "node1"])
+@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0)
def test_fallback_from_ddp_spawn_to_ddp_on_cluster(_, __, env_vars, expected_environment):
with mock.patch.dict(os.environ, env_vars, clear=True):
connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2)
@@ -343,8 +341,8 @@ def test_cuda_accelerator_can_not_run_on_system(_):
@pytest.mark.skipif(XLAAccelerator.is_available(), reason="test requires missing TPU")
-@mock.patch("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", True)
-@mock.patch("pytorch_lightning_enterprise.accelerators.xla._using_pjrt", return_value=True)
+@mock.patch("lightning.fabric.accelerators.xla._XLA_AVAILABLE", True)
+@mock.patch("lightning.fabric.accelerators.xla._using_pjrt", return_value=True)
def test_tpu_accelerator_can_not_run_on_system(_):
with pytest.raises(RuntimeError, match="XLAAccelerator` can not run on your system"):
_Connector(accelerator="tpu", devices=8)
@@ -890,7 +888,7 @@ def test_precision_selection_model_parallel(_, precision, raises):
def test_bitsandbytes_precision_cuda_required(monkeypatch):
- monkeypatch.setattr("pytorch_lightning_enterprise.plugins.precision.bitsandbytes._BITSANDBYTES_AVAILABLE", True)
+ monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"):
_Connector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))
diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py
index 785ec0a3ed2e9..a175fa97fd444 100644
--- a/tests/tests_fabric/utilities/test_throughput.py
+++ b/tests/tests_fabric/utilities/test_throughput.py
@@ -45,13 +45,7 @@ def test_get_available_flops(xla_available):
):
assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None
- # Import from the right module based on _XLA_GREATER_EQUAL_2_1
- from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
-
- if _XLA_GREATER_EQUAL_2_1:
- from torch_xla._internal import tpu
- else:
- from torch_xla.experimental import tpu
+ from torch_xla.experimental import tpu
assert isinstance(tpu, Mock)
diff --git a/tests/tests_pytorch/accelerators/test_xla.py b/tests/tests_pytorch/accelerators/test_xla.py
index 52434ee3a19d5..5e56d5c585c88 100644
--- a/tests/tests_pytorch/accelerators/test_xla.py
+++ b/tests/tests_pytorch/accelerators/test_xla.py
@@ -22,6 +22,7 @@
from torch import nn
from torch.utils.data import DataLoader
+import lightning.fabric
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import CPUAccelerator, XLAAccelerator
@@ -311,7 +312,7 @@ def test_warning_if_tpus_not_used(tpu_available):
)
@RunIf(min_python="3.9") # mocking issue
def test_trainer_config_device_ids(devices, expected_device_ids, tpu_available, monkeypatch):
- monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._using_pjrt", lambda: True)
+ monkeypatch.setattr(lightning.fabric.accelerators.xla, "_using_pjrt", lambda: True)
mock = DeviceMock()
monkeypatch.setattr(torch, "device", mock)
diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
index e764765ca75f7..0bd29b998c598 100644
--- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -480,9 +480,8 @@ def predict_step(self, *args, **kwargs):
return super().predict_step(*args, **kwargs)
-@mock.patch("builtins.print")
-@mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm.write")
-def test_tqdm_progress_bar_print(tqdm_write, mock_print, tmp_path):
+@mock.patch("tqdm.tqdm.write")
+def test_tqdm_progress_bar_print(tqdm_write, tmp_path):
"""Test that printing in the LightningModule redirects arguments to the progress bar."""
model = PrintModel()
bar = TQDMProgressBar()
@@ -507,9 +506,8 @@ def test_tqdm_progress_bar_print(tqdm_write, mock_print, tmp_path):
]
-@mock.patch("builtins.print")
-@mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm.write")
-def test_tqdm_progress_bar_print_no_train(tqdm_write, mock_print, tmp_path):
+@mock.patch("tqdm.tqdm.write")
+def test_tqdm_progress_bar_print_no_train(tqdm_write, tmp_path):
"""Test that printing in the LightningModule redirects arguments to the progress bar without training."""
model = PrintModel()
bar = TQDMProgressBar()
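Patching `tqdm.tqdm.write` at its source intercepts the call for every subclass, including Lightning's progress-bar Tqdm, so the test no longer needs to know the wrapper's import path. Standalone:

    from unittest import mock

    from tqdm import tqdm

    with mock.patch("tqdm.tqdm.write") as write_mock:
        tqdm.write("hello")  # subclass calls are routed to the same mock
        write_mock.assert_called_once_with("hello")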
diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py
index 53dd20e7e4399..878298c6bfd94 100644
--- a/tests/tests_pytorch/conftest.py
+++ b/tests/tests_pytorch/conftest.py
@@ -22,7 +22,6 @@
from unittest.mock import Mock
import pytest
-import pytorch_lightning_enterprise.utils.imports
import torch.distributed
from tqdm import TMonitor
@@ -102,8 +101,6 @@ def restore_env_variables():
"TPU_ML_PLATFORM_VERSION",
"LD_LIBRARY_PATH",
"ENABLE_RUNTIME_UPTIME_TELEMETRY",
- "TQDM_MININTERVAL", # set by our platform
- "TQDM_POSITION", # set by our platform
}
leaked_vars.difference_update(allowlist)
assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
@@ -223,47 +220,14 @@ def mps_count_1(monkeypatch):
def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
- # First, mock torch_xla modules in sys.modules so imports succeed
- monkeypatch.setitem(sys.modules, "torch_xla", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.core", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.core.functions", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.core.xla_env_vars", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.experimental.pjrt", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.experimental.tpu", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.distributed", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.distributed.parallel_loader", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.distributed.xla_multiprocessing", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.runtime", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.utils", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.utils.utils", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.debug", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla.debug.profiler", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla._internal", Mock())
- monkeypatch.setitem(sys.modules, "torch_xla._internal.tpu", Mock())
-
- # Then patch the _XLA_AVAILABLE flags in various modules
- monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_1", value)
- monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_5", value)
+ monkeypatch.setattr(lightning.pytorch.strategies.xla, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.pytorch.strategies.single_xla, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.pytorch.plugins.precision.xla, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.pytorch.strategies.launchers.xla, "_XLA_AVAILABLE", value)
monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
- monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", value)
- monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_5", value)
- # Patch in the modules where they're used after import
- monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_1", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_5", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_AVAILABLE", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.plugins.precision.xla._XLA_AVAILABLE", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.single._XLA_AVAILABLE", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.ddp._XLA_AVAILABLE", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.ddp._XLA_GREATER_EQUAL_2_1", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.fsdp._XLA_AVAILABLE", value)
- monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.launcher._XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value)
+ monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)
@pytest.fixture
@@ -273,14 +237,10 @@ def xla_available(monkeypatch: pytest.MonkeyPatch) -> None:
def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
mock_xla_available(monkeypatch, value)
+ monkeypatch.setattr(lightning.pytorch.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
+ monkeypatch.setattr(lightning.pytorch.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
- # Also mock the enterprise XLAAccelerator methods
- import pytorch_lightning_enterprise.accelerators.xla
-
- monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
- monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
-
monkeypatch.setitem(sys.modules, "torch_xla", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
index 7f47ad0be54bc..0b79b638534fa 100644
--- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
+++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
@@ -2,7 +2,6 @@
from unittest.mock import Mock
import pytest
-import pytorch_lightning_enterprise.plugins.precision.bitsandbytes
import torch.nn
import lightning.fabric
@@ -63,7 +62,6 @@ def test_fsdp_precision_plugin():
def test_bitsandbytes_precision_plugin(monkeypatch):
monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
- monkeypatch.setattr(pytorch_lightning_enterprise.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
bitsandbytes_mock = Mock()
monkeypatch.setitem(sys.modules, "bitsandbytes", bitsandbytes_mock)
@@ -109,9 +107,7 @@ def test_precision_plugin():
def test_transformer_engine_precision_plugin(monkeypatch):
- monkeypatch.setattr(
- pytorch_lightning_enterprise.plugins.precision.transformer_engine, "_TRANSFORMER_ENGINE_AVAILABLE", True
- )
+ monkeypatch.setattr(lightning.fabric.plugins.precision.transformer_engine, "_TRANSFORMER_ENGINE_AVAILABLE", True)
transformer_engine_mock = Mock()
monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock)
monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock())
diff --git a/tests/tests_pytorch/graveyard/test_tpu.py b/tests/tests_pytorch/graveyard/test_tpu.py
index eb8a7bfce07d4..0e010ca53ddcb 100644
--- a/tests/tests_pytorch/graveyard/test_tpu.py
+++ b/tests/tests_pytorch/graveyard/test_tpu.py
@@ -1,25 +1,9 @@
-import os
from importlib import import_module
import pytest
import torch
-# mimics `lightning_utilites.RequirementCache`
-class MockXLAAvailable:
- def __init__(self, available: bool, pkg_name: str = "torch_xla"):
- self.available = available
- self.pkg_name = pkg_name
-
- def __bool__(self):
- return self.available
-
- def __str__(self):
- if self.available:
- return f"Requirement '{self.pkg_name}' met"
- return f"Module not found: {self.pkg_name!r}. HINT: Try running `pip install -U {self.pkg_name}`"
-
-
@pytest.mark.parametrize(
("import_path", "name"),
[
@@ -50,17 +34,8 @@ def test_graveyard_single_tpu(import_path, name):
("lightning.pytorch.plugins.precision.xlabf16", "XLABf16PrecisionPlugin"),
],
)
-def test_graveyard_no_device(import_path, name, monkeypatch):
- monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", MockXLAAvailable(False))
- monkeypatch.setattr("pytorch_lightning_enterprise.plugins.precision.xla._XLA_AVAILABLE", MockXLAAvailable(False))
-
+def test_graveyard_no_device(import_path, name):
module = import_module(import_path)
cls = getattr(module, name)
with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"):
cls()
-
- # teardown
- # ideally, we should call the plugin's teardown method, but since the class
- # instantiation itself fails, we directly manipulate the env vars here
- os.environ.pop("XLA_USE_BF16", None)
- os.environ.pop("XLA_USE_F16", None)
diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py
index 79878b980900b..98819eb080eb8 100644
--- a/tests/tests_pytorch/loggers/conftest.py
+++ b/tests/tests_pytorch/loggers/conftest.py
@@ -38,10 +38,8 @@ def mlflow_mock(monkeypatch):
mlflow.tracking = mlflow_tracking
mlflow.entities = mlflow_entities
- monkeypatch.setattr("pytorch_lightning_enterprise.loggers.mlflow._MLFLOW_AVAILABLE", True)
- monkeypatch.setattr("pytorch_lightning_enterprise.loggers.mlflow._MLFLOW_SYNCHRONOUS_AVAILABLE", True)
- monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._MLFLOW_AVAILABLE", True)
- monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._MLFLOW_SYNCHRONOUS_AVAILABLE", True)
+ monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", True)
+ monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_SYNCHRONOUS_AVAILABLE", True)
return mlflow
@@ -88,8 +86,7 @@ class RunType: # to make isinstance checks pass
wandb.sdk.lib = wandb_sdk_lib
wandb.wandb_run = wandb_wandb_run
- monkeypatch.setattr("pytorch_lightning_enterprise.loggers.wandb._WANDB_AVAILABLE", True)
- monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._WANDB_AVAILABLE", True)
+ monkeypatch.setattr("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True)
return wandb
@@ -112,8 +109,7 @@ def comet_mock(monkeypatch):
comet.start = Mock(name="comet_ml.start", return_value=comet.Experiment())
comet.config = Mock()
- monkeypatch.setattr("pytorch_lightning_enterprise.loggers.comet._COMET_AVAILABLE", True)
- monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._COMET_AVAILABLE", True)
+ monkeypatch.setattr("lightning.pytorch.loggers.comet._COMET_AVAILABLE", True)
return comet
@@ -129,9 +125,6 @@ def __getitem__(self, item):
def __setitem__(self, key, value):
pass
- def wait(self):
- pass
-
run_mock = MagicMock(spec=RunType, exists=Mock(return_value=False), wait=Mock(), get_structure=MagicMock())
run_mock.get_root_object.return_value = run_mock
@@ -160,6 +153,5 @@ def wait(self):
neptune.types = neptune_types
neptune.utils = neptune_utils
- monkeypatch.setattr("pytorch_lightning_enterprise.loggers.neptune._NEPTUNE_AVAILABLE", True)
- monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._NEPTUNE_AVAILABLE", True)
+ monkeypatch.setattr("lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE", True)
return neptune
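All four fixtures now flip availability flags directly on the public lightning.pytorch.loggers.* modules. The dotted-string form of monkeypatch.setattr imports the module, swaps the attribute, and restores the original at teardown; a minimal sketch against a throwaway module (the fakelib name is invented):

import sys
import types


def test_flag_patching(monkeypatch):
    # plant a disposable module so the example is self-contained
    fakelib = types.ModuleType("fakelib")
    fakelib._FOO_AVAILABLE = False  # stands in for e.g. _MLFLOW_AVAILABLE
    monkeypatch.setitem(sys.modules, "fakelib", fakelib)

    # the string target resolves to the module attribute and is undone automatically
    monkeypatch.setattr("fakelib._FOO_AVAILABLE", True)

    import fakelib as flags  # served from the sys.modules entry planted above

    assert flags._FOO_AVAILABLE is True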
diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py
index 416676175cd63..a9cf9af79185b 100644
--- a/tests/tests_pytorch/loggers/test_all.py
+++ b/tests/tests_pytorch/loggers/test_all.py
@@ -70,7 +70,7 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs):
@mock.patch.dict(os.environ, {})
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES)
def test_loggers_fit_test_all(logger_class, mlflow_mock, wandb_mock, comet_mock, neptune_mock, tmp_path, monkeypatch):
"""Verify that basic functionality of all loggers."""
@@ -107,8 +107,8 @@ def log_metrics(self, metrics, step):
if logger_class == CometLogger:
logger.experiment.id = "foo"
- logger.logger_impl._comet_config.offline_directory = None
- logger.logger_impl._project_name = "bar"
+ logger._comet_config.offline_directory = None
+ logger._project_name = "bar"
logger.experiment.get_key.return_value = "SOME_KEY"
if logger_class == NeptuneLogger:
@@ -287,7 +287,7 @@ def _test_logger_initialization(tmp_path, logger_class):
@mock.patch.dict(os.environ, {})
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, neptune_mock, monkeypatch, tmp_path):
"""Test that prefix is added at the beginning of the metric keys."""
prefix = "tmp"
@@ -335,7 +335,7 @@ def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, neptune_moc
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_logger_default_name(mlflow_mock, monkeypatch, tmp_path):
"""Test that the default logger name is lightning_logs."""
# CSV
@@ -358,6 +358,6 @@ def test_logger_default_name(mlflow_mock, monkeypatch, tmp_path):
logger = _instantiate_logger(MLFlowLogger, save_dir=tmp_path)
_ = logger.experiment
- logger.logger_impl._mlflow_client.create_experiment.assert_called_with(name="lightning_logs", artifact_location=ANY)
+ logger._mlflow_client.create_experiment.assert_called_with(name="lightning_logs", artifact_location=ANY)
# on MLFlowLogger, `name` refers to the experiment id
# assert logger.experiment.get_experiment(logger.name).name == "lightning_logs"
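The retargeted @mock.patch decorators follow the standard rule: patch a name in the namespace that reads it at call time, which for _get_resolve_tags is now lightning.pytorch.loggers.mlflow. A compact sketch of the lookup-site rule using a stdlib function (nothing here is from the patch itself):

import os
from os import getcwd as bound_getcwd  # bound into this namespace at import time
from unittest import mock


@mock.patch("os.getcwd", return_value="/patched")
def test_patch_the_lookup_site(mock_getcwd):
    assert os.getcwd() == "/patched"     # late lookup through `os` sees the patch
    assert bound_getcwd() != "/patched"  # the early-bound name is untouched
    mock_getcwd.assert_called_once()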
diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py
index 4108d3c3b641c..dae8f617b873e 100644
--- a/tests/tests_pytorch/loggers/test_comet.py
+++ b/tests/tests_pytorch/loggers/test_comet.py
@@ -89,10 +89,10 @@ def test_comet_experiment_is_still_alive_after_training_complete(comet_mock):
logger.finalize("ended")
# Assert that data was saved to comet.com
- logger.logger_impl._experiment.flush.assert_called_once()
+ logger._experiment.flush.assert_called_once()
# Assert that it was not ended
- logger.logger_impl._experiment.end.assert_not_called()
+ logger._experiment.end.assert_not_called()
@mock.patch.dict(os.environ, {})
@@ -115,8 +115,8 @@ def test_comet_logger_experiment_name(comet_mock):
experiment_config=comet_mock.ExperimentConfig(),
)
# check that we saved "experiment name" in kwargs as the new "name" arg
- assert logger.logger_impl._kwargs["name"] == experiment_name
- assert "experiment_name" not in logger.logger_impl._kwargs
+ assert logger._kwargs["name"] == experiment_name
+ assert "experiment_name" not in logger._kwargs
# check that "experiment name" was passed to experiment config correctly
assert call(experiment_name=experiment_name) not in comet_mock.ExperimentConfig.call_args_list
@@ -130,10 +130,10 @@ def test_comet_version(comet_mock):
experiment_name = "My Name"
logger = CometLogger(api_key=api_key, name=experiment_name)
- assert logger.logger_impl._experiment is not None
+ assert logger._experiment is not None
_ = logger.version
- logger.logger_impl._experiment.get_key.assert_called()
+ logger._experiment.get_key.assert_called()
@mock.patch.dict(os.environ, {})
@@ -146,7 +146,7 @@ def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch):
{"test": 1},
epoch=1,
step=123,
- prefix=logger.logger_impl._prefix,
+ prefix=logger._prefix,
framework="pytorch-lightning",
)
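finalize("ended") is expected to push pending data without closing the run, which the test states purely through Mock call assertions on the flattened _experiment attribute. A two-line refresher on those assertions (invented mock, not Comet's API):

from unittest.mock import Mock

experiment = Mock()
experiment.flush()

experiment.flush.assert_called_once()  # data was pushed exactly once
experiment.end.assert_not_called()     # and the run was left alive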
diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py
index b82800fe22f27..c7f9dbe1fe2c6 100644
--- a/tests/tests_pytorch/loggers/test_mlflow.py
+++ b/tests/tests_pytorch/loggers/test_mlflow.py
@@ -16,13 +16,13 @@
from unittest.mock import MagicMock, Mock
import pytest
-import pytorch_lightning_enterprise.loggers # noqa: F401
-from pytorch_lightning_enterprise.utils.imports import _MLFLOW_AVAILABLE
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers.mlflow import (
+ _MLFLOW_AVAILABLE,
MLFlowLogger,
+ _get_resolve_tags,
)
@@ -30,13 +30,13 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r
"""Helper function to simulate mlflow client creating a new (or existing) experiment."""
run = MagicMock()
run.info.run_id = run_id
- logger.logger_impl._mlflow_client.get_experiment_by_name = MagicMock(return_value=experiment_name)
- logger.logger_impl._mlflow_client.create_experiment = MagicMock(return_value=experiment_id)
- logger.logger_impl._mlflow_client.create_run = MagicMock(return_value=run)
+ logger._mlflow_client.get_experiment_by_name = MagicMock(return_value=experiment_name)
+ logger._mlflow_client.create_experiment = MagicMock(return_value=experiment_id)
+ logger._mlflow_client.create_run = MagicMock(return_value=run)
return logger
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_exists(mlflow_mock, tmp_path):
"""Test launching three independent loggers with either same or different experiment name."""
client = mlflow_mock.tracking.MlflowClient
@@ -57,8 +57,8 @@ def test_mlflow_logger_exists(mlflow_mock, tmp_path):
client.return_value.create_run = MagicMock(return_value=run1)
logger = MLFlowLogger("test", save_dir=str(tmp_path))
- assert logger.logger_impl._experiment_id is None
- assert logger.logger_impl._run_id is None
+ assert logger._experiment_id is None
+ assert logger._run_id is None
_ = logger.experiment
assert logger.experiment_id == "exp-id-1"
assert logger.run_id == "run-id-1"
@@ -96,14 +96,13 @@ def test_mlflow_run_name_setting(tmp_path):
pytest.skip("test for explicit file creation requires the mlflow dependency to be installed.")
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
- from pytorch_lightning_enterprise.loggers.mlflow import _get_resolve_tags
resolve_tags = _get_resolve_tags()
tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"})
# run_name is appended to tags
logger = MLFlowLogger("test", run_name="run-name-1", save_dir=str(tmp_path))
- logger.logger_impl._mlflow_client = client = Mock()
+ logger._mlflow_client = client = Mock()
logger = mock_mlflow_run_creation(logger, experiment_id="exp-id")
_ = logger.experiment
@@ -123,7 +122,7 @@ def test_mlflow_run_name_setting(tmp_path):
client.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_run_id_setting(mlflow_mock, tmp_path):
"""Test that the run_id argument uses the provided run_id."""
client = mlflow_mock.tracking.MlflowClient
@@ -144,7 +143,7 @@ def test_mlflow_run_id_setting(mlflow_mock, tmp_path):
client.reset_mock(return_value=True)
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_log_dir(mlflow_mock, tmp_path):
"""Test that the trainer saves checkpoints in the logger's save dir."""
client = mlflow_mock.tracking.MlflowClient
@@ -212,8 +211,8 @@ def on_train_epoch_end(self, *args, **kwargs):
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"]
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
-@mock.patch("pytorch_lightning_enterprise.utils.imports._MLFLOW_AVAILABLE", return_value=True)
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
def test_mlflow_experiment_id_retrieved_once(_, mlflow_mock, tmp_path):
"""Test that the logger experiment_id retrieved only once."""
logger = MLFlowLogger("test", save_dir=str(tmp_path))
@@ -223,7 +222,7 @@ def test_mlflow_experiment_id_retrieved_once(_, mlflow_mock, tmp_path):
assert logger.experiment.get_experiment_by_name.call_count == 1
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_with_unexpected_characters(mlflow_mock, tmp_path):
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
logger = MLFlowLogger("test", save_dir=str(tmp_path))
@@ -233,7 +232,7 @@ def test_mlflow_logger_with_unexpected_characters(mlflow_mock, tmp_path):
logger.log_metrics(metrics)
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
"""Test that the logger calls methods on the mlflow experiment correctly."""
time = mlflow_mock.entities.time
@@ -243,7 +242,7 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
time.return_value = 1
logger = MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location")
- logger.logger_impl._mlflow_client.get_experiment_by_name.return_value = None
+ logger._mlflow_client.get_experiment_by_name.return_value = None
params = {"test": "test_param"}
logger.log_hyperparams(params)
@@ -261,13 +260,13 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
)
metric.assert_called_with(key="some_metric", value=10, timestamp=1000, step=0)
- logger.logger_impl._mlflow_client.create_experiment.assert_called_once_with(
+ logger._mlflow_client.create_experiment.assert_called_once_with(
name="test", artifact_location="my_artifact_location"
)
@pytest.mark.parametrize("synchronous", [False, True])
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_experiment_calls_with_synchronous(mlflow_mock, tmp_path, synchronous):
"""Test that the logger calls methods on the mlflow experiment with the specified synchronous flag."""
@@ -303,8 +302,8 @@ def test_mlflow_logger_experiment_calls_with_synchronous(mlflow_mock, tmp_path,
mlflow_client.create_experiment.assert_called_once_with(name="test", artifact_location="my_artifact_location")
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._MLFLOW_SYNCHRONOUS_AVAILABLE", False)
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch.dict("lightning.pytorch.loggers.mlflow.__dict__", {"_MLFLOW_SYNCHRONOUS_AVAILABLE": False})
def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path):
"""Test that the logger does not support synchronous flag."""
time = mlflow_mock.entities.time
@@ -316,7 +315,7 @@ def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path):
MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location", synchronous=True)
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_with_long_param_value(mlflow_mock, tmp_path):
"""Test that long parameter values are truncated to 250 characters."""
@@ -334,7 +333,7 @@ def _check_value_length(value, *args, **kwargs):
logger.experiment.log_batch.assert_called_once()
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_with_many_params(mlflow_mock, tmp_path):
"""Test that when logging more than 100 parameters, it will be split into batches of at most 100 parameters."""
logger = MLFlowLogger("test", save_dir=str(tmp_path))
@@ -353,37 +352,37 @@ def test_mlflow_logger_with_many_params(mlflow_mock, tmp_path):
("finished", "FINISHED"),
],
)
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_finalize(mlflow_mock, status, expected):
logger = MLFlowLogger("test")
# Pretend we are in a worker process and finalizing
_ = logger.experiment
- assert logger.logger_impl._initialized
+ assert logger._initialized
logger.finalize(status)
logger.experiment.set_terminated.assert_called_once_with(logger.run_id, expected)
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_finalize_when_exception(mlflow_mock):
logger = MLFlowLogger("test")
# Pretend we are in the main process and failing
- assert logger.logger_impl._mlflow_client
- assert not logger.logger_impl._initialized
+ assert logger._mlflow_client
+ assert not logger._initialized
logger.finalize("failed")
logger.experiment.set_terminated.assert_not_called()
# Pretend we are in a worker process and failing
_ = logger.experiment
- assert logger.logger_impl._initialized
+ assert logger._initialized
logger.finalize("failed")
logger.experiment.set_terminated.assert_called_once_with(logger.run_id, "FAILED")
@pytest.mark.parametrize("log_model", ["all", True, False])
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_log_model(mlflow_mock, log_model, tmp_path):
"""Test that the logger creates the folders and files in the right place."""
client = mlflow_mock.tracking.MlflowClient
@@ -421,7 +420,7 @@ def test_mlflow_log_model(mlflow_mock, log_model, tmp_path):
assert not client.return_value.log_artifacts.called
-@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_set_tracking_uri(mlflow_mock):
"""Test that the tracking uri is set for logging artifacts to MLFlow server."""
logger = MLFlowLogger(tracking_uri="the_tracking_uri")
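One hunk above swaps a plain mock.patch for mock.patch.dict on the module's __dict__: patch.dict snapshots the mapping, applies the override for the duration of the call, and restores the snapshot on exit, which is a convenient way to override a module global. A self-contained sketch (module name invented):

import sys
import types
from unittest import mock

demo = types.ModuleType("demo_flags")
demo._SYNCHRONOUS_AVAILABLE = True
sys.modules["demo_flags"] = demo


@mock.patch.dict("demo_flags.__dict__", {"_SYNCHRONOUS_AVAILABLE": False})
def check_flag_off():
    import demo_flags

    assert demo_flags._SYNCHRONOUS_AVAILABLE is False


check_flag_off()
assert demo._SYNCHRONOUS_AVAILABLE is True  # snapshot restored after the call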
diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py
index e50bcf967ef9a..572b85715da86 100644
--- a/tests/tests_pytorch/loggers/test_neptune.py
+++ b/tests/tests_pytorch/loggers/test_neptune.py
@@ -13,6 +13,7 @@
# limitations under the License.
import os
import pickle
+from collections import namedtuple
from unittest import mock
from unittest.mock import MagicMock, call
@@ -44,10 +45,10 @@ def _fit_and_test(logger, model, tmp_path):
def _get_logger_with_mocks(**kwargs):
logger = NeptuneLogger(**kwargs)
run_instance_mock = MagicMock()
- logger.logger_impl._run_instance = run_instance_mock
- logger.logger_impl._run_instance.__getitem__.return_value.fetch.return_value = "exp-name"
+ logger._run_instance = run_instance_mock
+ logger._run_instance.__getitem__.return_value.fetch.return_value = "exp-name"
run_attr_mock = MagicMock()
- logger.logger_impl._run_instance.__getitem__.return_value = run_attr_mock
+ logger._run_instance.__getitem__.return_value = run_attr_mock
return logger, run_instance_mock, run_attr_mock
@@ -59,7 +60,7 @@ def test_neptune_online(neptune_mock):
logger = NeptuneLogger(api_key="test", project="project")
created_run_mock = logger.run
- assert logger.logger_impl._run_instance == created_run_mock
+ assert logger._run_instance == created_run_mock
created_run_mock.exists.assert_called_once_with("sys/id")
assert logger.name == "Run test name"
assert logger.version == "TEST-1"
@@ -78,8 +79,8 @@ def test_neptune_offline(neptune_mock):
logger.experiment["foo"] = "bar"
created_run_mock.exists.assert_called_once_with("sys/id")
- assert logger.logger_impl._run_short_id == "OFFLINE"
- assert logger.logger_impl._run_name == "offline-name"
+ assert logger._run_short_id == "OFFLINE"
+ assert logger._run_name == "offline-name"
def test_online_with_custom_run(neptune_mock):
@@ -90,8 +91,8 @@ def test_online_with_custom_run(neptune_mock):
neptune_mock.init_run.reset_mock()
logger = NeptuneLogger(run=created_run)
- assert logger.logger_impl._run_instance == created_run
- assert logger.logger_impl._run_instance == created_run
+ assert logger._run_instance == created_run
+ assert logger._run_instance == created_run
assert logger.version == "TEST-1"
assert neptune_mock.init_run.call_count == 0
@@ -274,6 +275,38 @@ def test_save_dir(neptune_mock):
assert logger.save_dir == os.path.join(os.getcwd(), ".neptune")
+def test_get_full_model_name():
+ SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
+ test_input_data = [
+ ("key", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
+ (
+ "key/in/parts",
+ os.path.join("foo", "bar", "key/in/parts.ext"),
+ SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
+ ),
+ ("key", os.path.join("../foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("../foo", "bar"))),
+ ("key", os.path.join("foo", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("./foo", "bar/../"))),
+ ]
+
+ for expected_model_name, model_path, checkpoint in test_input_data:
+ assert NeptuneLogger._get_full_model_name(model_path, checkpoint) == expected_model_name
+
+
+def test_get_full_model_names_from_exp_structure():
+ input_dict = {
+ "foo": {
+ "bar": {
+ "lvl1_1": {"lvl2": {"lvl3_1": "some non important value", "lvl3_2": "some non important value"}},
+ "lvl1_2": "some non important value",
+ },
+ "other_non_important": {"val100": 100},
+ },
+ "other_non_important": {"val42": 42},
+ }
+ expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"}
+ assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys
+
+
def test_inactive_run(neptune_mock, tmp_path, monkeypatch):
from neptune.exceptions import InactiveRunException
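The new tests stub the checkpoint with a namedtuple because _get_full_model_name only ever reads .dirpath. To make the expected values above concrete, here is an illustrative reimplementation of the relative-key idea the test data encodes; the real NeptuneLogger helper may differ:

import os
from collections import namedtuple

# only the attribute the code under test reads needs to exist
StubCheckpoint = namedtuple("StubCheckpoint", ["dirpath"])


def model_key(model_path: str, checkpoint) -> str:
    # drop the checkpoint directory prefix and the file extension
    rel = os.path.relpath(model_path, checkpoint.dirpath)
    return os.path.splitext(rel)[0].replace(os.sep, "/")


ckpt = StubCheckpoint(dirpath=os.path.join("foo", "bar"))
assert model_key(os.path.join("foo", "bar", "key.ext"), ckpt) == "key"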
diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py
index 0409cbd7b8f2c..e9b9e9a8090b0 100644
--- a/tests/tests_pytorch/loggers/test_wandb.py
+++ b/tests/tests_pytorch/loggers/test_wandb.py
@@ -25,6 +25,7 @@
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
def test_wandb_project_name(wandb_mock):
@@ -100,7 +101,7 @@ def test_wandb_logger_init(wandb_mock):
_ = logger.experiment
# verify default resume value
- assert logger.logger_impl._wandb_init["resume"] == "allow"
+ assert logger._wandb_init["resume"] == "allow"
logger.log_metrics({"acc": 1.0}, step=3)
wandb_mock.init.assert_called_once()
@@ -144,9 +145,9 @@ def test_wandb_logger_sync_tensorboard_log_metrics(wandb_mock):
def test_wandb_logger_init_before_spawn(wandb_mock):
logger = WandbLogger()
- assert logger.logger_impl._experiment is None
- logger.logger_impl.__getstate__()
- assert logger.logger_impl._experiment is not None
+ assert logger._experiment is None
+ logger.__getstate__()
+ assert logger._experiment is not None
def test_wandb_logger_experiment_called_first(wandb_mock, tmp_path):
@@ -625,7 +626,7 @@ def test_wandb_log_media(wandb_mock, tmp_path):
def test_wandb_logger_offline_log_model(wandb_mock, tmp_path):
"""Test that log_model=True raises an error in offline mode."""
- with pytest.raises(ValueError, match="checkpoints cannot be uploaded in offline mode"):
+ with pytest.raises(MisconfigurationException, match="checkpoints cannot be uploaded in offline mode"):
_ = WandbLogger(save_dir=tmp_path, offline=True, log_model=True)
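The tightened expectation relies on two pytest.raises behaviors: the raised object must be an instance of the given class (subclasses count), and match is applied as a regular-expression search over str(exc). A quick sketch with an invented stand-in for MisconfigurationException:

import pytest


class ConfigError(Exception):
    pass  # invented stand-in for a misconfiguration error


def reject_offline_upload():
    raise ConfigError("checkpoints cannot be uploaded in offline mode")


def test_class_and_message():
    with pytest.raises(ConfigError, match="cannot be uploaded in offline mode"):
        reject_offline_upload()
    with pytest.raises(Exception):  # a base class also satisfies the check
        reject_offline_upload()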
diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py
index f1878a4a2650d..e74778ee32f4e 100644
--- a/tests/tests_pytorch/models/test_tpu.py
+++ b/tests/tests_pytorch/models/test_tpu.py
@@ -235,7 +235,7 @@ def test_tpu_misconfiguration(devices, tpu_available):
@pytest.mark.skipif(XLAAccelerator.is_available(), reason="test requires missing TPU")
-@mock.patch("pytorch_lightning_enterprise.accelerators.xla._using_pjrt", return_value=True)
+@mock.patch("lightning.fabric.accelerators.xla._using_pjrt", return_value=True)
def test_exception_when_no_tpu_found(_, xla_available):
"""Test if exception is thrown when xla devices are not available."""
with pytest.raises(MisconfigurationException, match="XLAAccelerator` can not run on your system"):
diff --git a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
index 86eb93c4542a0..da4ce6b89aaab 100644
--- a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
+++ b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
@@ -19,7 +19,7 @@
def test_invalid_precision_with_deepspeed_precision():
- with pytest.raises(ValueError, match=r"is not supported in DeepSpeed. `precision` must be one of"):
+ with pytest.raises(ValueError, match="is not supported. `precision` must be one of"):
DeepSpeedPrecision(precision="64-true")
diff --git a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
index 6055f9ecd6e60..a9967280e3f23 100644
--- a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
+++ b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
@@ -16,16 +16,16 @@
from unittest.mock import ANY, Mock
import pytest
-import pytorch_lightning_enterprise.utils.imports
import torch
+import lightning.fabric
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.plugins import TransformerEnginePrecision
from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector
def test_transformer_engine_precision_plugin(monkeypatch):
- module = pytorch_lightning_enterprise.plugins.precision.transformer_engine
+ module = lightning.fabric.plugins.precision.transformer_engine
if module._TRANSFORMER_ENGINE_AVAILABLE:
pytest.skip("Assumes transformer_engine is unavailable")
monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
@@ -44,7 +44,7 @@ def test_transformer_engine_precision_plugin(monkeypatch):
def test_configure_model(monkeypatch):
- module = pytorch_lightning_enterprise.plugins.precision.transformer_engine
+ module = lightning.fabric.plugins.precision.transformer_engine
if module._TRANSFORMER_ENGINE_AVAILABLE:
pytest.skip("Assumes transformer_engine is unavailable")
monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
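Both tests guard with pytest.skip when transformer_engine is actually installed, then force the availability flag truthy; the lambda passes the plain `if module._TRANSFORMER_ENGINE_AVAILABLE:` check only because any function object is truthy. The guard-then-force pattern in miniature (module invented):

import sys
import types

import pytest


def test_force_available(monkeypatch):
    flags = types.ModuleType("engine_flags")
    flags._ENGINE_AVAILABLE = False
    monkeypatch.setitem(sys.modules, "engine_flags", flags)

    if flags._ENGINE_AVAILABLE:
        pytest.skip("assumes the engine is unavailable")
    # any truthy replacement passes a bare `if` check; monkeypatch undoes it later
    monkeypatch.setattr(flags, "_ENGINE_AVAILABLE", lambda: True)
    assert flags._ENGINE_AVAILABLE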
diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py
index c4a1800743dd0..047e2d24f77a3 100644
--- a/tests/tests_pytorch/strategies/test_deepspeed.py
+++ b/tests/tests_pytorch/strategies/test_deepspeed.py
@@ -33,6 +33,7 @@
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.plugins import DeepSpeedPrecision
from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
@@ -152,7 +153,7 @@ def test_deepspeed_precision_choice(cuda_count_1, tmp_path):
def test_deepspeed_with_invalid_config_path():
"""Test to ensure if we pass an invalid config path we throw an exception."""
with pytest.raises(
- FileNotFoundError, match="You passed in a path to a DeepSpeed config but the path does not exist"
+ MisconfigurationException, match="You passed in a path to a DeepSpeed config but the path does not exist"
):
DeepSpeedStrategy(config="invalid_path.json")
@@ -1003,7 +1004,7 @@ def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmp_path,
strategy = trainer.strategy
assert isinstance(strategy, DeepSpeedStrategy)
with mock.patch("platform.system", return_value=platform) as mock_platform:
- strategy.deepspeed_strategy_impl._init_deepspeed_distributed()
+ strategy._init_deepspeed_distributed()
mock_deepspeed_distributed.assert_called()
mock_platform.assert_called()
if platform == "Windows":
@@ -1078,7 +1079,7 @@ def training_step(self, batch, batch_idx):
enable_progress_bar=False,
enable_model_summary=False,
)
- with pytest.raises(ValueError, match="returning `None` .* is not supported"):
+ with pytest.raises(MisconfigurationException, match="returning `None` .* is not supported"):
trainer.fit(model)
@@ -1157,7 +1158,7 @@ def test_deepspeed_gradient_clip_by_value(tmp_path):
enable_progress_bar=False,
enable_model_summary=False,
)
- with pytest.raises(ValueError, match="does not support clipping gradients by value"):
+ with pytest.raises(MisconfigurationException, match="does not support clipping gradients by value"):
trainer.fit(model)
@@ -1266,7 +1267,6 @@ def test_validate_parallel_devices_indices(device_indices):
"""
accelerator = Mock(spec=CUDAAccelerator)
- accelerator.name.return_value = "cuda"
strategy = DeepSpeedStrategy(
accelerator=accelerator, parallel_devices=[torch.device("cuda", i) for i in device_indices]
)
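The dropped line stubbed an attribute the code no longer reads; what Mock(spec=...) still guarantees is that the stub passes isinstance checks, since a spec'd mock reports the spec class as its __class__. A minimal demonstration (the Accelerator class here is invented):

from unittest.mock import Mock


class Accelerator:
    def setup(self) -> None: ...


acc = Mock(spec=Accelerator)
assert isinstance(acc, Accelerator)  # spec'd mocks satisfy isinstance checks

acc.setup()  # attributes defined on the spec are allowed...
try:
    acc.does_not_exist  # ...anything else raises AttributeError
except AttributeError:
    pass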
diff --git a/tests/tests_pytorch/strategies/test_xla.py b/tests/tests_pytorch/strategies/test_xla.py
index 0fd2afa98690b..3fde2600c9483 100644
--- a/tests/tests_pytorch/strategies/test_xla.py
+++ b/tests/tests_pytorch/strategies/test_xla.py
@@ -50,14 +50,14 @@ def test_rank_properties_access(xla_available):
strategy.cluster_environment = Mock()
# we're in the main process, no processes have been launched yet
- assert not strategy.xla_strategy_impl._launched
+ assert not strategy._launched
assert strategy.global_rank == 0
assert strategy.local_rank == 0
assert strategy.node_rank == 0
assert strategy.world_size == 1
# simulate we're in a worker process
- strategy.xla_strategy_impl._launched = True
+ strategy._launched = True
assert strategy.global_rank == strategy.cluster_environment.global_rank()
assert strategy.local_rank == strategy.cluster_environment.local_rank()
assert strategy.node_rank == strategy.cluster_environment.node_rank()
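The renamed attribute keeps the same protocol: rank queries answer as the main process until _launched flips, after which they delegate to the cluster environment. An invented miniature of that launch-gated pattern:

from unittest.mock import Mock


class LaunchGatedStrategy:
    def __init__(self, cluster_environment) -> None:
        self._launched = False  # flipped once worker processes exist
        self.cluster_environment = cluster_environment

    @property
    def global_rank(self) -> int:
        # before launch, every query answers as the main process
        return self.cluster_environment.global_rank() if self._launched else 0


env = Mock()
env.global_rank.return_value = 3
strategy = LaunchGatedStrategy(env)
assert strategy.global_rank == 0  # nothing launched yet
strategy._launched = True
assert strategy.global_rank == 3  # worker view delegates to the cluster env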
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
index 114fb02364a09..99ce4f5f975c8 100644
--- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -968,7 +968,6 @@ def test_precision_selection(precision_str, strategy_str, expected_precision_cls
def test_bitsandbytes_precision_cuda_required(monkeypatch):
monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
- monkeypatch.setattr("pytorch_lightning_enterprise.plugins.precision.bitsandbytes._BITSANDBYTES_AVAILABLE", True)
monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"):
_AcceleratorConnector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))
diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py
index b8b13af7b4710..41ca1a779f8a5 100644
--- a/tests/tests_pytorch/utilities/migration/test_utils.py
+++ b/tests/tests_pytorch/utilities/migration/test_utils.py
@@ -101,10 +101,7 @@ def test_patch_legacy_imports_unified(pl_version):
assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), (
f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}"
)
- assert not any(
- key.startswith("pytorch_lightning") and not key.startswith("pytorch_lightning_enterprise")
- for key in sys.modules
- ), (
+ assert not any(key.startswith("pytorch_lightning") for key in sys.modules), (
"Should not import standalone package, all imports should be redirected to the unified package;\n"
f" environment: {_list_sys_modules('pytorch_lightning')}"
)
@@ -125,10 +122,7 @@ def test_patch_legacy_imports_unified(pl_version):
assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), (
f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}"
)
- assert not any(
- key.startswith("pytorch_lightning") and not key.startswith("pytorch_lightning_enterprise")
- for key in sys.modules
- ), (
+ assert not any(key.startswith("pytorch_lightning") for key in sys.modules), (
"Should not import standalone package, all imports should be redirected to the unified package;\n"
f" environment: {_list_sys_modules('pytorch_lightning')}"
)
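With the enterprise carve-out removed, the assertion reduces to a plain prefix scan over sys.modules. A sketch of a helper along the lines of what _list_sys_modules presumably does (the body is an assumption; only the assert shape comes from the test):

import sys


def list_sys_modules(prefix: str) -> str:
    # hypothetical stand-in for the test helper: name already-imported modules by prefix
    return ", ".join(sorted(name for name in sys.modules if name.startswith(prefix)))


def assert_not_imported(prefix: str) -> None:
    assert not any(name.startswith(prefix) for name in sys.modules), (
        f"unexpected imports under {prefix!r}: {list_sys_modules(prefix)}"
    )


assert_not_imported("definitely_not_imported_pkg")  # passes; nothing matches the prefix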