diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml
index db2e4517b79e9..31af5cc30176f 100644
--- a/.github/workflows/docs-build.yml
+++ b/.github/workflows/docs-build.yml
@@ -125,10 +125,7 @@ jobs:
         working-directory: ./docs/source-${{ matrix.pkg-name }}
         # allow failing link check and doctest if you run with dispatch
         continue-on-error: ${{ (matrix.target == 'doctest' || matrix.target == 'linkcheck') && github.event_name == 'workflow_dispatch' }}
-        run: |
-          # temp fix: https://github.com/Lightning-AI/pytorch-lightning/actions/runs/19440502586/job/55622388642?pr=21354#step:11:4596
-          uv pip install -U fastapi
-          make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"
+        run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"

       - name: Keep artifact
         if: github.event_name == 'pull_request'
diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py
index 67b1c8bf7bfd6..db3f5334ed162 100644
--- a/docs/source-pytorch/conf.py
+++ b/docs/source-pytorch/conf.py
@@ -604,7 +604,10 @@ def package_list_from_file(file):
 from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
 from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
 from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
-from lightning.fabric.utilities.imports import _COMET_AVAILABLE, _MLFLOW_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE
+from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
+from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
+from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE
+from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
 """
 coverage_skip_undoc_in_source = True
diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt
index d187769117be6..fdc814c7f7fa1 100644
--- a/requirements/fabric/base.txt
+++ b/requirements/fabric/base.txt
@@ -6,4 +6,3 @@ fsspec[http] >=2022.5.0, <2025.11.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
 lightning-utilities >=0.10.0, <0.16.0
-pytorch-lightning-enterprise >=2.6.0
diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt
index 65a3f7484fb7d..014b223b1f012 100644
--- a/requirements/pytorch/base.txt
+++ b/requirements/pytorch/base.txt
@@ -9,4 +9,3 @@ torchmetrics >0.7.0, <1.9.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
 lightning-utilities >=0.10.0, <0.16.0
-pytorch-lightning-enterprise >=2.6.0
diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py
index 0cf42821424ee..db2cf2586e1ba 100644
--- a/src/lightning/fabric/accelerators/xla.py
+++ b/src/lightning/fabric/accelerators/xla.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-import warnings
 from typing import Any, Union

 import torch
@@ -21,11 +20,7 @@
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
-
-_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
-_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
-_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
+from lightning.fabric.utilities.device_parser import _check_data_type


 class XLAAccelerator(Accelerator):
@@ -36,38 +31,38 @@ class XLAAccelerator(Accelerator):
     """

     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        _raise_enterprise_not_available()
+        if not _XLA_AVAILABLE:
+            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+        if not _using_pjrt():
+            raise RuntimeError("The XLA XRT runtime is not supported anymore.")
         super().__init__(*args, **kwargs)

-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
-
-        self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs)
-
     @override
     def setup_device(self, device: torch.device) -> None:
-        return self.accelerator_impl.setup_device(device)
+        pass

     @override
     def teardown(self) -> None:
-        return self.accelerator_impl.teardown()
+        pass

     @staticmethod
     @override
     def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
         """Accelerator device parsing logic."""
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
-
-        return EnterpriseXLAAccelerator.parse_devices(devices)
+        return _parse_tpu_devices(devices)

     @staticmethod
     @override
     def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
         """Gets parallel devices for the Accelerator."""
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
-
-        return EnterpriseXLAAccelerator.get_parallel_devices(devices)
+        devices = _parse_tpu_devices(devices)
+        if isinstance(devices, int):
+            return [torch.device("xla", i) for i in range(devices)]
+        # list of devices is not supported, just a specific index, fine to access [0]
+        return [torch.device("xla", devices[0])]
+        # we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
+        # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
+        # it will be replaced with `xla_device` (also a `torch.device`, but with extra logic) in the strategy

     @staticmethod
     @override
@@ -76,10 +71,16 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
     @functools.lru_cache(maxsize=1)
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+        if not _XLA_AVAILABLE:
+            return 0
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla._internal import tpu
+
+            return tpu.num_available_devices()
+        from torch_xla.experimental import tpu

-        return EnterpriseXLAAccelerator.auto_device_count()
+        device_count_on_version = {2: 8, 3: 8, 4: 4}
+        return device_count_on_version.get(tpu.version(), 8)

     @staticmethod
     @override
@@ -91,9 +92,6 @@ def is_available() -> bool:
             # XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
             # when `torch_xla` is imported but not used
             return False
-        except ModuleNotFoundError as e:
-            warnings.warn(str(e))
-            return False

     @staticmethod
     @override
@@ -108,3 +106,74 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
             cls,
             description=cls.__name__,
         )
+
+
+# PJRT support requires this minimum version
+_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
+_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
+_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
+
+
+def _using_pjrt() -> bool:
+    # `using_pjrt` is removed in torch_xla 2.5
+    if _XLA_GREATER_EQUAL_2_5:
+        from torch_xla import runtime as xr
+
+        return xr.device_type() is not None
+    # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
+    if _XLA_GREATER_EQUAL_2_1:
+        from torch_xla import runtime as xr
+
+        return xr.using_pjrt()
+
+    from torch_xla.experimental import pjrt
+
+    return pjrt.using_pjrt()
+
+
+def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
+    """Parses the TPU devices given in the format accepted by the
+    :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.
+
+    Args:
+        devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used.
+            An int of 8 or string '8' indicates that all 8 cores with multi-processing should be used.
+            A single-element list of int or string can be used to indicate the specific TPU core to use.
+
+    Returns:
+        The number of TPU cores to be used, or a list with the index of the specific core.
+
+    """
+    _check_data_type(devices)
+    if isinstance(devices, str):
+        devices = _parse_tpu_devices_str(devices)
+    _check_tpu_devices_valid(devices)
+    return devices
+
+
+def _check_tpu_devices_valid(devices: object) -> None:
+    device_count = XLAAccelerator.auto_device_count()
+    if (
+        # support number of devices
+        isinstance(devices, int)
+        and devices in {1, device_count}
+        # support picking a specific device
+        or isinstance(devices, (list, tuple))
+        and len(devices) == 1
+        and 0 <= devices[0] <= device_count - 1
+    ):
+        return
+    raise ValueError(
+        f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
+    )
+
+
+def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
+    devices = devices.strip()
+    try:
+        return int(devices)
+    except ValueError:
+        try:
+            return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
+        except ValueError:
+            raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")
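# Reviewer note, not part of the patch: a minimal sketch of how the reinstated
# `_parse_tpu_devices` helper resolves the accepted `devices` formats. Assumes a
# host where `XLAAccelerator.auto_device_count()` returns 8.
from lightning.fabric.accelerators.xla import _parse_tpu_devices

_parse_tpu_devices(8)     # -> 8: use all eight cores with multi-processing
_parse_tpu_devices("1")   # -> 1: a single core
_parse_tpu_devices("3,")  # -> [3]: the comma form selects the specific core with index 3
_parse_tpu_devices([5])   # -> [5]: a one-element list also selects a specific core
_parse_tpu_devices(3)     # raises ValueError: only 'auto', 1, 8 or [<0-7>] are valid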
diff --git a/src/lightning/fabric/plugins/environments/kubeflow.py b/src/lightning/fabric/plugins/environments/kubeflow.py
index 8013ec2a65720..23a1c0d1753af 100644
--- a/src/lightning/fabric/plugins/environments/kubeflow.py
+++ b/src/lightning/fabric/plugins/environments/kubeflow.py
@@ -13,11 +13,11 @@
 # limitations under the License.
 import logging
+import os

 from typing_extensions import override

 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available

 log = logging.getLogger(__name__)
@@ -33,28 +25,20 @@ class KubeflowEnvironment(ClusterEnvironment):
     """

-    def __init__(self) -> None:
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.environments.kubeflow import (
-            KubeflowEnvironment as EnterpriseKubeflowEnvironment,
-        )
-
-        self.kubeflow_impl = EnterpriseKubeflowEnvironment()
-
     @property
     @override
     def creates_processes_externally(self) -> bool:
-        return self.kubeflow_impl.creates_processes_externally
+        return True

     @property
     @override
     def main_address(self) -> str:
-        return self.kubeflow_impl.main_address
+        return os.environ["MASTER_ADDR"]

     @property
     @override
     def main_port(self) -> int:
-        return self.kubeflow_impl.main_port
+        return int(os.environ["MASTER_PORT"])

     @staticmethod
     @override
@@ -63,24 +55,24 @@ def detect() -> bool:
     @override
     def world_size(self) -> int:
-        return self.kubeflow_impl.world_size()
+        return int(os.environ["WORLD_SIZE"])

     @override
     def set_world_size(self, size: int) -> None:
-        return self.kubeflow_impl.set_world_size(size)
+        log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

     @override
     def global_rank(self) -> int:
-        return self.kubeflow_impl.global_rank()
+        return int(os.environ["RANK"])

     @override
     def set_global_rank(self, rank: int) -> None:
-        return self.kubeflow_impl.set_global_rank(rank)
+        log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

     @override
     def local_rank(self) -> int:
-        return self.kubeflow_impl.local_rank()
+        return 0

     @override
     def node_rank(self) -> int:
-        return self.kubeflow_impl.node_rank()
+        return self.global_rank()
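# Reviewer note, not part of the patch: with the wrapper removed, KubeflowEnvironment
# reads the rendezvous info straight from the variables a PyTorchJob sets on each pod.
# The values below are made up for illustration.
import os

from lightning.fabric.plugins.environments import KubeflowEnvironment

os.environ.update({"MASTER_ADDR": "10.0.0.2", "MASTER_PORT": "23456", "WORLD_SIZE": "4", "RANK": "1"})
env = KubeflowEnvironment()
assert env.main_address == "10.0.0.2" and env.main_port == 23456
assert env.world_size() == 4 and env.global_rank() == 1
assert env.local_rank() == 0  # Kubeflow schedules one process per pod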
diff --git a/src/lightning/fabric/plugins/environments/lsf.py b/src/lightning/fabric/plugins/environments/lsf.py
index 0b62bf502303a..f0a07d61d9f03 100644
--- a/src/lightning/fabric/plugins/environments/lsf.py
+++ b/src/lightning/fabric/plugins/environments/lsf.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 import logging
 import os
+import socket

 from typing_extensions import override

 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.cloud_io import get_filesystem

 log = logging.getLogger(__name__)
@@ -49,32 +50,36 @@ class LSFEnvironment(ClusterEnvironment):

     def __init__(self) -> None:
         super().__init__()
-
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.environments.lsf import (
-            LSFEnvironment as EnterpriseLSFEnvironment,
-        )
-
-        self.lsf_impl = EnterpriseLSFEnvironment()
+        self._main_address = self._get_main_address()
+        self._main_port = self._get_main_port()
+        self._node_rank = self._get_node_rank()
+        self._set_init_process_group_env_vars()
+
+    def _set_init_process_group_env_vars(self) -> None:
+        # set environment variables needed for initializing torch distributed process group
+        os.environ["MASTER_ADDR"] = str(self._main_address)
+        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
+        os.environ["MASTER_PORT"] = str(self._main_port)
+        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

     @property
     @override
     def creates_processes_externally(self) -> bool:
         """LSF creates subprocesses, i.e., PyTorch Lightning does not need to spawn them."""
-        return self.lsf_impl.creates_processes_externally
+        return True

     @property
     @override
     def main_address(self) -> str:
         """The main address is read from an OpenMPI host rank file in the environment variable
         ``LSB_DJOB_RANKFILE``."""
-        return self.lsf_impl.main_address
+        return self._main_address

     @property
     @override
     def main_port(self) -> int:
         """The main port is calculated from the LSF job ID."""
-        return self.lsf_impl.main_port
+        return self._main_port

     @staticmethod
     @override
@@ -86,28 +91,110 @@ def detect() -> bool:
     @override
     def world_size(self) -> int:
         """The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``."""
-        return self.lsf_impl.world_size()
+        world_size = os.environ.get("JSM_NAMESPACE_SIZE")
+        if world_size is None:
+            raise ValueError(
+                "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
+                " Make sure you run your executable with `jsrun`."
+            )
+        return int(world_size)

     @override
     def set_world_size(self, size: int) -> None:
-        return self.lsf_impl.set_world_size(size)
+        log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

     @override
     def global_rank(self) -> int:
         """The global rank is read from the environment variable ``JSM_NAMESPACE_RANK``."""
-        return self.lsf_impl.global_rank()
+        global_rank = os.environ.get("JSM_NAMESPACE_RANK")
+        if global_rank is None:
+            raise ValueError(
+                "Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
+                " Make sure you run your executable with `jsrun`."
+            )
+        return int(global_rank)

     @override
     def set_global_rank(self, rank: int) -> None:
-        return self.lsf_impl.set_global_rank(rank)
+        log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

     @override
     def local_rank(self) -> int:
         """The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
-        return self.lsf_impl.local_rank()
+        local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK")
+        if local_rank is None:
+            raise ValueError(
+                "Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
+                " Make sure you run your executable with `jsrun`."
+            )
+        return int(local_rank)

     @override
     def node_rank(self) -> int:
         """The node rank is determined by the position of the current hostname in the OpenMPI host rank file
         stored in ``LSB_DJOB_RANKFILE``."""
-        return self.lsf_impl.node_rank()
+        return self._node_rank
+
+    def _get_node_rank(self) -> int:
+        """A helper method for getting the node rank.
+
+        The node rank is determined by the position of the current node in the list of hosts used in the job. This is
+        calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list.
+
+        """
+        hosts = self._read_hosts()
+        count: dict[str, int] = {}
+        for host in hosts:
+            if host not in count:
+                count[host] = len(count)
+        return count[socket.gethostname()]
+
+    @staticmethod
+    def _read_hosts() -> list[str]:
+        """Read compute hosts that are a part of the compute job.
+
+        LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes.
+        Each job is assigned a launch node. This launch node will be the first node in the list contained in
+        ``LSB_DJOB_RANKFILE``.
+
+        """
+        var = "LSB_DJOB_RANKFILE"
+        rankfile = os.environ.get(var)
+        if rankfile is None:
+            raise ValueError("Did not find the environment variable `LSB_DJOB_RANKFILE`")
+        if not rankfile:
+            raise ValueError("The environment variable `LSB_DJOB_RANKFILE` is empty")
+
+        fs = get_filesystem(rankfile)
+        with fs.open(rankfile, "r") as f:
+            ret = [line.strip() for line in f]
+        # remove the launch node (i.e. the first node in LSB_DJOB_RANKFILE) from the list
+        return ret[1:]
+
+    def _get_main_address(self) -> str:
+        """A helper for getting the main address.
+
+        The main address is assigned to the first node in the list of nodes used for the job.
+
+        """
+        hosts = self._read_hosts()
+        return hosts[0]
+
+    @staticmethod
+    def _get_main_port() -> int:
+        """A helper function for accessing the main port.
+
+        Uses the LSF job ID so all ranks can compute the main port.
+
+        """
+        # check for user-specified main port
+        if "MASTER_PORT" in os.environ:
+            log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
+            return int(os.environ["MASTER_PORT"])
+        if "LSB_JOBID" in os.environ:
+            port = int(os.environ["LSB_JOBID"])
+            # all ports should be in the 10k+ range
+            port = port % 1000 + 10000
+            log.debug(f"calculated LSF main port: {port}")
+            return port
+        raise ValueError("Could not find job id in environment variable LSB_JOBID")
diff --git a/src/lightning/fabric/plugins/environments/slurm.py b/src/lightning/fabric/plugins/environments/slurm.py
index 4ac79ce0014e9..4d98b7ed6a8eb 100644
--- a/src/lightning/fabric/plugins/environments/slurm.py
+++ b/src/lightning/fabric/plugins/environments/slurm.py
@@ -14,6 +14,7 @@
 import logging
 import os
+import re
 import shutil
 import signal
 import sys
@@ -22,7 +23,7 @@
 from typing_extensions import override

 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.imports import _IS_WINDOWS, _raise_enterprise_not_available
+from lightning.fabric.utilities.imports import _IS_WINDOWS
 from lightning.fabric.utilities.rank_zero import rank_zero_warn
 from lightning.fabric.utilities.warnings import PossibleUserWarning
@@ -45,35 +46,57 @@ class SLURMEnvironment(ClusterEnvironment):

     def __init__(self, auto_requeue: bool = True, requeue_signal: Optional[signal.Signals] = None) -> None:
         super().__init__()
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.environments.slurm import (
-            SLURMEnvironment as EnterpriseSLURMEnvironment,
-        )
-
-        self.slurm_impl = EnterpriseSLURMEnvironment(auto_requeue=auto_requeue, requeue_signal=requeue_signal)
-
-    @property
-    def auto_requeue(self) -> bool:
-        return self.slurm_impl.auto_requeue
-
-    @property
-    def requeue_signal(self) -> Optional[signal.Signals]:
-        return self.slurm_impl.requeue_signal
+        self.auto_requeue = auto_requeue
+        if requeue_signal is None and not _IS_WINDOWS:
+            requeue_signal = signal.SIGUSR1
+        self.requeue_signal = requeue_signal
+        self._validate_srun_used()
+        self._validate_srun_variables()

     @property
     @override
     def creates_processes_externally(self) -> bool:
-        return self.slurm_impl.creates_processes_externally
+        return True

     @property
     @override
     def main_address(self) -> str:
-        return self.slurm_impl.main_address
+        root_node = os.environ.get("MASTER_ADDR")
+        if root_node is None:
+            nodelist = os.environ.get("SLURM_NODELIST", "127.0.0.1")
+            root_node = self.resolve_root_node_address(nodelist)
+            os.environ["MASTER_ADDR"] = root_node
+
+        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
+        return root_node

     @property
     @override
     def main_port(self) -> int:
-        return self.slurm_impl.main_port
+        # -----------------------
+        # SLURM JOB = PORT number
+        # -----------------------
+        # this way every process knows what port to use
+        job_id = os.environ.get("SLURM_JOB_ID")
+        if job_id is not None:
+            # use the last 4 numbers in the job id as the id
+            default_port = job_id[-4:]
+            # all ports should be in the 10k+ range
+            default_port = int(default_port) + 15000
+        else:
+            default_port = 12910
+
+        # -----------------------
+        # PORT NUMBER = MASTER_PORT
+        # -----------------------
+        # in case the user passed it in
+        if "MASTER_PORT" in os.environ:
+            default_port = int(os.environ["MASTER_PORT"])
+        else:
+            os.environ["MASTER_PORT"] = str(default_port)
+
+        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
+        return default_port

     @staticmethod
     @override
@@ -95,40 +118,58 @@ def job_name() -> Optional[str]:

     @staticmethod
     def job_id() -> Optional[int]:
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.environments.slurm import (
-            SLURMEnvironment as EnterpriseSLURMEnvironment,
-        )
-
-        return EnterpriseSLURMEnvironment.job_id()
+        # in interactive mode, don't make logs use the same job id
+        if _is_slurm_interactive_mode():
+            return None
+
+        job_id = os.environ.get("SLURM_JOB_ID")
+        if job_id is None:
+            return None
+        try:
+            return int(job_id)
+        except ValueError:
+            return None

     @override
     def world_size(self) -> int:
-        return self.slurm_impl.world_size()
+        return int(os.environ["SLURM_NTASKS"])

     @override
     def set_world_size(self, size: int) -> None:
-        return self.slurm_impl.set_world_size(size)
+        log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

     @override
     def global_rank(self) -> int:
-        return self.slurm_impl.global_rank()
+        return int(os.environ["SLURM_PROCID"])

     @override
     def set_global_rank(self, rank: int) -> None:
-        return self.slurm_impl.set_global_rank(rank)
+        log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

     @override
     def local_rank(self) -> int:
-        return self.slurm_impl.local_rank()
+        return int(os.environ["SLURM_LOCALID"])

     @override
     def node_rank(self) -> int:
-        return self.slurm_impl.node_rank()
+        return int(os.environ["SLURM_NODEID"])

     @override
     def validate_settings(self, num_devices: int, num_nodes: int) -> None:
-        return self.slurm_impl.validate_settings(num_devices, num_nodes)
+        if _is_slurm_interactive_mode():
+            return
+        ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
+        if ntasks_per_node is not None and int(ntasks_per_node) != num_devices:
+            raise ValueError(
+                f"You set `devices={num_devices}` in Lightning, but the number of tasks per node configured in SLURM"
+                f" `--ntasks-per-node={ntasks_per_node}` does not match. HINT: Set `devices={ntasks_per_node}`."
+            )
+        nnodes = os.environ.get("SLURM_NNODES")
+        if nnodes is not None and int(nnodes) != num_nodes:
+            raise ValueError(
+                f"You set `num_nodes={num_nodes}` in Lightning, but the number of nodes configured in SLURM"
+                f" `--nodes={nnodes}` does not match. HINT: Set `num_nodes={nnodes}`."
+            )

     @staticmethod
     def resolve_root_node_address(nodes: str) -> str:
@@ -141,12 +182,9 @@ def resolve_root_node_address(nodes: str) -> str:
         - the range notation with brackets, e.g., 'host[5-9]' yields 'host5' as the root

         """
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.environments.slurm import (
-            SLURMEnvironment as EnterpriseSLURMEnvironment,
-        )
-
-        return EnterpriseSLURMEnvironment.resolve_root_node_address(nodes)
+        nodes = re.sub(r"\[(.*?)[,-].*\]", "\\1", nodes)  # Take the first node of every node range
+        nodes = re.sub(r"\[(.*?)\]", "\\1", nodes)  # handle special case where node range is single number
+        return nodes.split(" ")[0].split(",")[0]

     @staticmethod
     def _validate_srun_used() -> None:
@@ -178,12 +216,12 @@ def _validate_srun_variables() -> None:
         for a complete list of supported srun variables.

         """
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.environments.slurm import (
-            SLURMEnvironment as EnterpriseSLURMEnvironment,
-        )
-
-        return EnterpriseSLURMEnvironment._validate_srun_variables()
+        ntasks = int(os.environ.get("SLURM_NTASKS", "1"))
+        if ntasks > 1 and "SLURM_NTASKS_PER_NODE" not in os.environ:
+            raise RuntimeError(
+                f"You set `--ntasks={ntasks}` in your SLURM bash script, but this variable is not supported."
+                f" HINT: Use `--ntasks-per-node={ntasks}` instead."
+            )


 def _is_srun_used() -> bool:
diff --git a/src/lightning/fabric/plugins/environments/torchelastic.py b/src/lightning/fabric/plugins/environments/torchelastic.py
index 62ddd442df7b2..4003dd30dec09 100644
--- a/src/lightning/fabric/plugins/environments/torchelastic.py
+++ b/src/lightning/fabric/plugins/environments/torchelastic.py
@@ -13,12 +13,13 @@
 # limitations under the License.
 import logging
+import os

 import torch.distributed
 from typing_extensions import override

 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.rank_zero import rank_zero_warn

 log = logging.getLogger(__name__)
@@ -26,29 +27,29 @@ class TorchElasticEnvironment(ClusterEnvironment):
     """Environment for fault-tolerant and elastic training with `torchelastic <https://pytorch.org/elastic/>`_"""

-    def __init__(self) -> None:
-        super().__init__()
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.environments.torchelastic import (
-            TorchElasticEnvironment as EnterpriseTorchElasticEnvironment,
-        )
-
-        self.torchelastic_impl = EnterpriseTorchElasticEnvironment()
-
     @property
     @override
     def creates_processes_externally(self) -> bool:
-        return self.torchelastic_impl.creates_processes_externally
+        return True

     @property
     @override
     def main_address(self) -> str:
-        return self.torchelastic_impl.main_address
+        if "MASTER_ADDR" not in os.environ:
+            rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
+            os.environ["MASTER_ADDR"] = "127.0.0.1"
+        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
+        return os.environ["MASTER_ADDR"]

     @property
     @override
     def main_port(self) -> int:
-        return self.torchelastic_impl.main_port
+        if "MASTER_PORT" not in os.environ:
+            rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
+            os.environ["MASTER_PORT"] = "12910"
+        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
+
+        return int(os.environ["MASTER_PORT"])

     @staticmethod
     @override
@@ -59,28 +60,34 @@ def detect() -> bool:

     @override
     def world_size(self) -> int:
-        return self.torchelastic_impl.world_size()
+        return int(os.environ["WORLD_SIZE"])

     @override
     def set_world_size(self, size: int) -> None:
-        return self.torchelastic_impl.set_world_size(size)
+        log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

     @override
     def global_rank(self) -> int:
-        return self.torchelastic_impl.global_rank()
+        return int(os.environ["RANK"])

     @override
     def set_global_rank(self, rank: int) -> None:
-        return self.torchelastic_impl.set_global_rank(rank)
+        log.debug(
+            "TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
+        )

     @override
     def local_rank(self) -> int:
-        return self.torchelastic_impl.local_rank()
+        return int(os.environ["LOCAL_RANK"])

     @override
     def node_rank(self) -> int:
-        return self.torchelastic_impl.node_rank()
+        return int(os.environ.get("GROUP_RANK", 0))

     @override
     def validate_settings(self, num_devices: int, num_nodes: int) -> None:
-        return self.torchelastic_impl.validate_settings(num_devices, num_nodes)
+        if num_devices * num_nodes != self.world_size():
+            raise ValueError(
+                f"You set `devices={num_devices}` and `num_nodes={num_nodes}` in Lightning, but the product"
+                f" ({num_devices} * {num_nodes}) does not match the world size ({self.world_size()})."
+            )
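# Reviewer note, not part of the patch: torchrun exports these variables, and the
# restored methods simply read them back. Values are made up for illustration.
import os

from lightning.fabric.plugins.environments import TorchElasticEnvironment

os.environ.update({"WORLD_SIZE": "8", "RANK": "3", "LOCAL_RANK": "3", "GROUP_RANK": "0"})
env = TorchElasticEnvironment()
assert env.world_size() == 8 and env.global_rank() == 3
env.validate_settings(num_devices=4, num_nodes=2)  # 4 * 2 == 8: consistent
# env.validate_settings(num_devices=8, num_nodes=2) would raise: 16 != world size 8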
""" - return self.xla_impl.local_rank() + if _XLA_GREATER_EQUAL_2_1: + from torch_xla import runtime as xr + + return xr.local_ordinal() + + import torch_xla.core.xla_model as xm + + return xm.get_local_ordinal() @override @functools.lru_cache(maxsize=1) @@ -107,4 +125,11 @@ def node_rank(self) -> int: The output is cached for performance. """ - return self.xla_impl.node_rank() + if _XLA_GREATER_EQUAL_2_1: + from torch_xla import runtime as xr + + return xr.host_index() + import torch_xla.core.xla_env_vars as xenv + from torch_xla.utils.utils import getenv_as + + return getenv_as(xenv.HOST_ORDINAL, int, 0) diff --git a/src/lightning/fabric/plugins/io/xla.py b/src/lightning/fabric/plugins/io/xla.py index da6ea414cac97..146fa2f33b510 100644 --- a/src/lightning/fabric/plugins/io/xla.py +++ b/src/lightning/fabric/plugins/io/xla.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import os from typing import Any, Optional +import torch +from lightning_utilities.core.apply_func import apply_to_collection +from lightning_utilities.core.imports import RequirementCache from typing_extensions import override +from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO -from lightning.fabric.utilities.imports import _raise_enterprise_not_available +from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.types import _PATH log = logging.getLogger(__name__) @@ -31,11 +36,9 @@ class XLACheckpointIO(TorchCheckpointIO): """ def __init__(self, *args: Any, **kwargs: Any) -> None: + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__(*args, **kwargs) - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.plugins.io.xla import XLACheckpointIO as EnterpriseXLACheckpointIO - - self.xla_impl = EnterpriseXLACheckpointIO(*args, **kwargs) @override def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: @@ -51,4 +54,21 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio If ``storage_options`` arg is passed in """ - return self.xla_impl.save_checkpoint(checkpoint, path, storage_options) + if storage_options is not None: + raise TypeError( + "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg" + f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`" + " to define how you'd like to use `storage_options`." + ) + fs = get_filesystem(path) + fs.makedirs(os.path.dirname(path), exist_ok=True) + if RequirementCache("omegaconf"): + # workaround for https://github.com/pytorch/xla/issues/2773 + from omegaconf import DictConfig, ListConfig, OmegaConf + + checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container) + import torch_xla.core.xla_model as xm + + cpu_data = xm._maybe_convert_to_cpu(checkpoint, convert=True) + log.debug(f"Saving checkpoint: {path}") + torch.save(cpu_data, path) diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index 4ff20ee9fab1e..4c648f2b97181 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -11,15 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py
index 4ff20ee9fab1e..4c648f2b97181 100644
--- a/src/lightning/fabric/plugins/precision/bitsandbytes.py
+++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -11,15 +11,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import AbstractContextManager
-from typing import Any, Literal, Optional
+import functools
+import logging
+import math
+import os
+import warnings
+from collections import OrderedDict
+from contextlib import AbstractContextManager, ExitStack
+from functools import partial
+from types import ModuleType
+from typing import Any, Callable, Literal, Optional, cast

 import torch
+from lightning_utilities import apply_to_collection
 from lightning_utilities.core.imports import RequirementCache
-from typing_extensions import override
+from torch import Tensor
+from torch.nn import init
+from torch.nn.modules.module import _IncompatibleKeys
+from typing_extensions import Self, override

 from lightning.fabric.plugins.precision.precision import Precision
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.plugins.precision.utils import (
+    _ClassReplacementContextManager,
+    _convert_fp_tensor,
+    _DtypeContextManager,
+)
+from lightning.fabric.utilities.types import _DEVICE
+
+log = logging.getLogger(__name__)

 _BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes")
@@ -54,34 +73,376 @@ def __init__(
         dtype: Optional[torch.dtype] = None,
         ignore_modules: Optional[set[str]] = None,
     ) -> None:
-        super().__init__()
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.precision.bitsandbytes import (
-            BitsandbytesPrecision as EnterpriseBitsandbytesPrecision,
-        )
+        _import_bitsandbytes()
+
+        if dtype is None:
+            # try to be smart about the default selection
+            if mode.startswith("int8"):
+                dtype = torch.float16
+            else:
+                dtype = (
+                    torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+                )
+        if mode.startswith("int8") and dtype is not torch.float16:
+            # this limitation is mentioned in https://huggingface.co/blog/hf-bitsandbytes-integration#usage
+            raise ValueError(f"{mode!r} only works with `dtype=torch.float16`, but you chose `{dtype}`")

-        self.bitsandbytes_impl = EnterpriseBitsandbytesPrecision(mode=mode, dtype=dtype, ignore_modules=ignore_modules)
+        globals_ = globals()
+        mode_to_cls = {
+            "nf4": globals_["_NF4Linear"],
+            "nf4-dq": globals_["_NF4DQLinear"],
+            "fp4": globals_["_FP4Linear"],
+            "fp4-dq": globals_["_FP4DQLinear"],
+            "int8-training": globals_["_Linear8bitLt"],
+            "int8": globals_["_Int8LinearInference"],
+        }
+        self._linear_cls = mode_to_cls[mode]
+        self.dtype = dtype
+        self.ignore_modules = ignore_modules or set()

     @override
     def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
-        return self.bitsandbytes_impl.convert_module(module)
+        # avoid naive users thinking they quantized their model
+        if not any(isinstance(m, torch.nn.Linear) for m in module.modules()):
+            raise TypeError(
+                "You are using the bitsandbytes precision plugin, but your model has no Linear layers. This plugin"
+                " won't work for your model."
+            )
+
+        # convert modules if they haven't been converted already
+        bnb = _import_bitsandbytes()
+        if not any(isinstance(m, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)) for m in module.modules()):
+            # this will not quantize the model but only replace the layer classes
+            _convert_layers(module, self._linear_cls, self.ignore_modules)
+
+        # set the compute dtype if necessary
+        for m in module.modules():
+            if isinstance(m, bnb.nn.Linear4bit):
+                m.compute_dtype = self.dtype
+                m.compute_type_is_set = False
+        return module

     @override
     def tensor_init_context(self) -> AbstractContextManager:
-        return self.bitsandbytes_impl.tensor_init_context()
+        return _DtypeContextManager(self.dtype)

     @override
     def module_init_context(self) -> AbstractContextManager:
-        return self.bitsandbytes_impl.module_init_context()
+        if self.ignore_modules:
+            # cannot patch the Linear class if the user wants to skip some submodules
+            raise RuntimeError(
+                "Instantiating your model under the `init_module` context manager is not supported when used with"
+                f" `BitsandbytesPrecision(..., ignore_modules={self.ignore_modules})` as this"
+                " may initialize the layers on-device, defeating the purpose of quantization. You can remove"
+                " `ignore_modules` or remove the `init_module` context manager."
+            )
+        dtype_ctx = self.tensor_init_context()
+        # TODO: this could also support replacing `Embedding` and `Conv1D`
+        context_manager = _ClassReplacementContextManager({"torch.nn.Linear": self._linear_cls})
+        stack = ExitStack()
+        stack.enter_context(dtype_ctx)
+        stack.enter_context(context_manager)
+        return stack

     @override
     def forward_context(self) -> AbstractContextManager:
-        return self.bitsandbytes_impl.forward_context()
+        return _DtypeContextManager(self.dtype)

     @override
     def convert_input(self, data: Any) -> Any:
-        return self.bitsandbytes_impl.convert_input(data)
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.dtype)

     @override
     def convert_output(self, data: Any) -> Any:
-        return self.bitsandbytes_impl.convert_output(data)
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
+
+
+def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_dict: OrderedDict, *_: Any) -> None:
+    # There is only one key that ends with `*.weight`, the other one is the bias
+    weight_key = next((name for name in state_dict if name.endswith("weight")), None)
+    if weight_key is None:
+        return
+    # Load the weight from the state dict and re-quantize it
+    weight = state_dict.pop(weight_key)
+    quantize_fn(weight)
+
+
+def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _IncompatibleKeys) -> None:
+    # since we manually loaded the weight in the `_quantize_on_load_hook` hook, we need to avoid this missing key false
+    # positive
+    for key in reversed(incompatible_keys.missing_keys):
+        if key.endswith("weight"):
+            incompatible_keys.missing_keys.remove(key)
+
+
+def _replace_param(
+    param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[tuple] = None
+) -> torch.nn.Parameter:
+    bnb = _import_bitsandbytes()
+
+    # doing `param.data = weight` raises a RuntimeError if param.data was on meta-device, so
+    # we need to re-create the parameters instead of overwriting the data
+    if param.device.type == "meta":
+        if isinstance(param, bnb.nn.Params4bit):
+            return bnb.nn.Params4bit(
+                data=data,
+                requires_grad=data.requires_grad,
+                quant_state=quant_state,
+                blocksize=param.blocksize,
+                compress_statistics=param.compress_statistics,
+                quant_type=param.quant_type,
+                quant_storage=param.quant_storage,
+                module=param.module,
+                bnb_quantized=param.bnb_quantized,
+            )
+        return torch.nn.Parameter(data, requires_grad=data.requires_grad)
+    param.data = data
+    if isinstance(param, bnb.nn.Params4bit):
+        param.quant_state = quant_state
+    return param
+
+
+@functools.lru_cache(maxsize=1)
+def _import_bitsandbytes() -> ModuleType:
+    if not _BITSANDBYTES_AVAILABLE:
+        raise ModuleNotFoundError(str(_BITSANDBYTES_AVAILABLE))
+    # configuration for bitsandbytes before import
+    nowelcome_set = "BITSANDBYTES_NOWELCOME" in os.environ
+    if not nowelcome_set:
+        os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+    warnings.filterwarnings("ignore", message=r".*bitsandbytes was compiled without GPU support.*")
+    warnings.filterwarnings(
+        "ignore", message=r"MatMul8bitLt: inputs will be cast from .* to float16 during quantization"
+    )
+    import bitsandbytes as bnb
+
+    if not nowelcome_set:
+        del os.environ["BITSANDBYTES_NOWELCOME"]
+
+    class _Linear8bitLt(bnb.nn.Linear8bitLt):
+        """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and re-quantization when
+        loading the state dict."""
+
+        def __init__(self, *args: Any, device: Optional[_DEVICE] = None, threshold: float = 6.0, **kwargs: Any) -> None:
+            super().__init__(*args, device=device, threshold=threshold, **kwargs)
+            self.weight = cast(bnb.nn.Int8Params, self.weight)  # type: ignore[has-type]
+            self.bias: Optional[torch.nn.Parameter] = self.bias
+            # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
+            # filling the device memory with float32 weights which could lead to OOM
+            if torch.tensor(0, device=device).device.type == "cuda":
+                self.quantize_()
+            self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
+            self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
+
+        def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
+            """Inplace quantize."""
+            if weight is None:
+                weight = self.weight.data
+            if weight.data.dtype == torch.int8:
+                # already quantized
+                return
+            assert isinstance(self.weight, bnb.nn.Int8Params)
+            self.weight = self.quantize(self.weight, weight, device)
+
+        @staticmethod
+        def quantize(
+            int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: Optional[torch.device]
+        ) -> bnb.nn.Int8Params:
+            device = device or torch.device("cuda")
+            if device.type != "cuda":
+                raise RuntimeError(f"Unexpected device type: {device.type}")
+            # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302
+            B = weight.contiguous().to(device=device, dtype=torch.float16)
+            if int8params.has_fp16_weights:
+                int8params.data = B
+            else:
+                # bitsandbytes >= 0.45 supports an improved API
+                if hasattr(bnb.functional, "int8_vectorwise_quant"):
+                    CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
+                else:  # old method is deprecated in 0.45, removed in 0.46+.
+                    CB, _, SCB, _, _ = bnb.functional.double_quant(B)
+
+                int8params.data = CB
+                setattr(int8params, "CB", CB)
+                setattr(int8params, "SCB", SCB)
+            return int8params
+
+        def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
+            if self.weight.device.type == "meta":
+                # need custom logic if int8params is on meta device
+                raise NotImplementedError
+            if self.weight.dtype == torch.uint8:  # was quantized
+                # need the original shape here
+                raise NotImplementedError
+            device = torch.device(device)
+            weight = torch.empty_like(self.weight.data, device=device)
+            if device.type == "cuda":  # re-quantize
+                self.quantize_(weight, device)
+            else:
+                self.weight = _replace_param(self.weight, weight)
+            if self.bias is not None:
+                self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
+            return self
+
+        def reset_parameters(self) -> None:
+            # from `torch.nn.Linear.reset_parameters`
+            if self.bias is not None:
+                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                init.uniform_(self.bias, -bound, bound)
+
+            linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
+            if linear_init_finished and self.weight.dtype == torch.uint8:  # was quantized
+                # need the original shape here
+                raise NotImplementedError
+            weight = self.weight.data
+            torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+            if linear_init_finished:
+                if self.weight.device.type == "meta":
+                    # need custom logic if int8params is on meta device
+                    raise NotImplementedError
+                if self.weight.device.type == "cuda":  # re-quantize
+                    self.quantize_(weight)
+                else:
+                    self.weight = _replace_param(self.weight, weight)
+
+    class _Linear4bit(bnb.nn.Linear4bit):
+        """Wraps `bnb.nn.Linear4bit` to enable: instantiation directly on the device, re-quantization when loading
+        the state dict, meta-device initialization, and materialization."""
+
+        def __init__(self, *args: Any, device: Optional[_DEVICE] = None, **kwargs: Any) -> None:
+            super().__init__(*args, device=device, **kwargs)
+            self.weight = cast(bnb.nn.Params4bit, self.weight)  # type: ignore[has-type]
+            self.bias: Optional[torch.nn.Parameter] = self.bias
+            # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
+            # filling the device memory with float32 weights which could lead to OOM
+            if torch.tensor(0, device=device).device.type == "cuda":
+                self.quantize_()
+            self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
+            self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
+
+        def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
+            """Inplace quantize."""
+            if weight is None:
+                weight = self.weight.data
+            if weight.data.dtype == torch.uint8:
+                # already quantized
+                return
+            assert isinstance(self.weight, bnb.nn.Params4bit)
+            self.weight = self.quantize(self.weight, weight, device)
+            self.weight.bnb_quantized = True
+
+        @staticmethod
+        def quantize(
+            params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: Optional[torch.device]
+        ) -> bnb.nn.Params4bit:
+            device = device or torch.device("cuda")
+            if device.type != "cuda":
+                raise RuntimeError(f"Unexpected device type: {device.type}")
+            # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L156-L159
+            w = weight.contiguous().to(device=device, dtype=torch.half)
+            w_4bit, quant_state = bnb.functional.quantize_4bit(
+                w,
+                blocksize=params4bit.blocksize,
+                compress_statistics=params4bit.compress_statistics,
+                quant_type=params4bit.quant_type,
+                quant_storage=params4bit.quant_storage,
+            )
+            return _replace_param(params4bit, w_4bit, quant_state)
+
+        def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
+            if self.weight.dtype == torch.uint8:  # was quantized
+                # cannot init the quantized params directly
+                weight = torch.empty(self.weight.quant_state.shape, device=device, dtype=torch.half)
+            else:
+                weight = torch.empty_like(self.weight.data, device=device)
+            device = torch.device(device)
+            if device.type == "cuda":  # re-quantize
+                self.quantize_(weight, device)
+            else:
+                self.weight = _replace_param(self.weight, weight)
+            if self.bias is not None:
+                self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
+            return self
+
+        def reset_parameters(self) -> None:
+            # from `torch.nn.Linear.reset_parameters`
+            if self.bias is not None:
+                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                init.uniform_(self.bias, -bound, bound)
+
+            linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
+            if linear_init_finished and self.weight.dtype == torch.uint8:  # was quantized
+                # cannot init the quantized params directly
+                weight = torch.empty(self.weight.quant_state.shape, device=self.weight.device, dtype=torch.half)
+            else:
+                weight = self.weight.data
+            torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+            if linear_init_finished:
+                if self.weight.device.type == "cuda":  # re-quantize
+                    self.quantize_(weight)
+                else:
+                    self.weight = _replace_param(self.weight, weight)
+
+    # Use a class instead of `functools.partial` to respect `isinstance` checks and attribute accesses
+    class _Int8LinearInference(_Linear8bitLt):
+        def __init__(self, *args: Any, **kwargs: Any) -> None:
+            super().__init__(*args, has_fp16_weights=False, **kwargs)
+
+    class _FP4Linear(_Linear4bit):
+        def __init__(self, *args: Any, **kwargs: Any) -> None:
+            super().__init__(*args, quant_type="fp4", compress_statistics=False, **kwargs)
+
+    class _FP4DQLinear(_Linear4bit):
+        def __init__(self, *args: Any, **kwargs: Any) -> None:
+            super().__init__(*args, quant_type="fp4", compress_statistics=True, **kwargs)
+
+    class _NF4Linear(_Linear4bit):
+        def __init__(self, *args: Any, **kwargs: Any) -> None:
+            super().__init__(*args, quant_type="nf4", compress_statistics=False, **kwargs)
+
+    class _NF4DQLinear(_Linear4bit):
+        def __init__(self, *args: Any, **kwargs: Any) -> None:
+            super().__init__(*args, quant_type="nf4", compress_statistics=True, **kwargs)
+
+    # these classes are defined programmatically like this to avoid importing bitsandbytes in environments that have
+    # it available but will not use it
+    classes = {
+        "_Linear8bitLt": _Linear8bitLt,
+        "_Linear4bit": _Linear4bit,
+        "_Int8LinearInference": _Int8LinearInference,
+        "_FP4Linear": _FP4Linear,
+        "_FP4DQLinear": _FP4DQLinear,
+        "_NF4Linear": _NF4Linear,
+        "_NF4DQLinear": _NF4DQLinear,
+    }
+    globals().update(classes)
+
+    return bnb
+
+
+def _convert_layers(module: torch.nn.Module, linear_cls: type, ignore_modules: set[str], prefix: str = "") -> None:
+    for name, child in module.named_children():
+        fullname = f"{prefix}.{name}" if prefix else name
+        if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
+            log.debug(f"Replacing layer {fullname!r} with bitsandbytes equivalent")
+            has_bias = child.bias is not None
+            # since we are going to copy over the child's data, the device doesn't matter. CPU is chosen
+            # to avoid spiking CUDA memory even though initialization is slower
+            # 4bit layers support quantizing from meta-device params so this is only relevant for 8-bit
+            _Linear4bit = globals()["_Linear4bit"]
+            device = torch.device("meta" if issubclass(linear_cls, _Linear4bit) else "cpu")
+            replacement = linear_cls(
+                child.in_features,
+                child.out_features,
+                bias=has_bias,
+                device=device,
+            )
+            if has_bias:
+                replacement.bias = _replace_param(replacement.bias, child.bias.data.clone())
+            state = {"quant_state": replacement.weight.quant_state if issubclass(linear_cls, _Linear4bit) else None}
+            replacement.weight = _replace_param(replacement.weight, child.weight.data.clone(), **state)
+            module.__setattr__(name, replacement)
+        else:
+            _convert_layers(child, linear_cls, ignore_modules, prefix=fullname)
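# Reviewer note, not part of the patch: a minimal sketch of picking a quantization
# mode now that the implementation lives here again. Requires bitsandbytes and a
# CUDA device; `MyModel` stands in for any module containing Linear layers.
from lightning.fabric import Fabric
from lightning.fabric.plugins.precision.bitsandbytes import BitsandbytesPrecision

precision = BitsandbytesPrecision(mode="nf4-dq")  # dtype defaults to bf16 where supported, else fp16
fabric = Fabric(plugins=precision)
model = fabric.setup(MyModel())  # Linear layers are swapped for 4-bit bitsandbytes equivalents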
diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py
index b2f369dbf956d..526095008f376 100644
--- a/src/lightning/fabric/plugins/precision/deepspeed.py
+++ b/src/lightning/fabric/plugins/precision/deepspeed.py
@@ -11,16 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from typing import TYPE_CHECKING, Any, Literal

 import torch
+from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
-from typing_extensions import override
+from typing_extensions import get_args, override

-from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, Precision
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.plugins.precision.precision import Precision
+from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
 from lightning.fabric.utilities.types import Steppable

 if TYPE_CHECKING:
@@ -43,37 +44,51 @@ class DeepSpeedPrecision(Precision):
     """

     def __init__(self, precision: _PRECISION_INPUT) -> None:
-        super().__init__()
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.precision.deepspeed import (
-            DeepSpeedPrecisionFabric as EnterpriseDeepSpeedPrecision,
-        )
-
-        self.deepspeed_impl = EnterpriseDeepSpeedPrecision(precision=precision)
+        supported_precision = get_args(_PRECISION_INPUT)
+        if precision not in supported_precision:
+            raise ValueError(
+                f"`precision={precision!r}` is not supported in DeepSpeed."
+                f" `precision` must be one of: {supported_precision}."
+            )
+        self.precision = precision
+
+        precision_to_type = {
+            "bf16-mixed": torch.bfloat16,
+            "16-mixed": torch.float16,
+            "bf16-true": torch.bfloat16,
+            "16-true": torch.float16,
+            "32-true": torch.float32,
+        }
+        self._desired_dtype = precision_to_type[self.precision]

     @override
     def convert_module(self, module: Module) -> Module:
-        return self.deepspeed_impl.convert_module(module)
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_dtype)
+        return module

     @override
     def tensor_init_context(self) -> AbstractContextManager:
-        return self.deepspeed_impl.tensor_init_context()
+        if "true" not in self.precision:
+            return nullcontext()
+        return _DtypeContextManager(self._desired_dtype)

     @override
     def module_init_context(self) -> AbstractContextManager:
-        return self.deepspeed_impl.module_init_context()
+        return self.tensor_init_context()

     @override
     def convert_input(self, data: Any) -> Any:
-        return self.deepspeed_impl.convert_input(data)
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)

     @override
     def convert_output(self, data: Any) -> Any:
-        return self.deepspeed_impl.convert_output(data)
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

     @override
     def backward(self, tensor: Tensor, model: "DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
-        return self.deepspeed_impl.backward(tensor, model, *args, **kwargs)
+        """Performs back-propagation using DeepSpeed's engine."""
+        model.backward(tensor, *args, **kwargs)

     @override
     def optimizer_step(
@@ -81,20 +96,5 @@ def optimizer_step(
         self,
         optimizer: Steppable,
         **kwargs: Any,
     ) -> Any:
-        return self.deepspeed_impl.optimizer_step(optimizer, **kwargs)
-
-    @property
-    def precision(self) -> _PRECISION_INPUT_STR:
-        return self.deepspeed_impl.precision
-
-    @precision.setter
-    def precision(self, precision: _PRECISION_INPUT_STR) -> None:
-        self.deepspeed_impl.precision = precision
-
-    @property
-    def _desired_dtype(self) -> torch.dtype:
-        return self.deepspeed_impl._desired_dtype
-
-    @_desired_dtype.setter
-    def _desired_dtype(self, dtype: torch.dtype) -> None:
-        self.deepspeed_impl._desired_dtype = dtype
+        # DeepSpeed handles the optimizer step internally
+        return optimizer.step(**kwargs)
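# Reviewer note, not part of the patch: how the restored plugin maps precision flags
# to dtypes. Constructing the plugin does not require deepspeed itself.
import torch

from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision

plugin = DeepSpeedPrecision("bf16-true")
assert plugin._desired_dtype is torch.bfloat16
layer = plugin.convert_module(torch.nn.Linear(2, 2))  # "-true" converts the parameters
assert layer.weight.dtype is torch.bfloat16
mixed = DeepSpeedPrecision("16-mixed")  # "-mixed" leaves modules in fp32
assert mixed.convert_module(torch.nn.Linear(2, 2)).weight.dtype is torch.float32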
diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py
index 3d3d2491763f4..bf1e51ea6b2b0 100644
--- a/src/lightning/fabric/plugins/precision/transformer_engine.py
+++ b/src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -11,19 +11,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 from collections.abc import Mapping
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, ExitStack
 from typing import TYPE_CHECKING, Any, Literal, Optional, Union

 import torch
+from lightning_utilities import apply_to_collection
+from lightning_utilities.core.imports import RequirementCache
+from torch import Tensor
 from typing_extensions import override

 from lightning.fabric.plugins.precision.precision import Precision
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.plugins.precision.utils import (
+    _ClassReplacementContextManager,
+    _convert_fp_tensor,
+    _DtypeContextManager,
+)
+from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn

 if TYPE_CHECKING:
     from transformer_engine.common.recipe import DelayedScaling

+_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0")
+log = logging.getLogger(__name__)
+

 class TransformerEnginePrecision(Precision):
     """Plugin for training with fp8 precision via nvidia's
@@ -60,79 +72,111 @@ def __init__(
         replace_layers: Optional[bool] = None,
         fallback_compute_dtype: Optional[torch.dtype] = None,
     ) -> None:
-        super().__init__()
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.precision.transformer_engine import (
-            TransformerEnginePrecision as EnterpriseTransformerEnginePrecision,
-        )
-
-        self.transformer_engine_impl = EnterpriseTransformerEnginePrecision(
-            weights_dtype=weights_dtype,
-            recipe=recipe,
-            replace_layers=replace_layers,
-            fallback_compute_dtype=fallback_compute_dtype,
-        )
-
-    @property
-    def weights_dtype(self) -> torch.dtype:
-        return self.transformer_engine_impl.weights_dtype
-
-    @weights_dtype.setter
-    def weights_dtype(self, value: torch.dtype) -> None:
-        self.transformer_engine_impl.weights_dtype = value
-
-    @property
-    def recipe(self) -> Union[Mapping[str, Any], "DelayedScaling"]:
-        return self.transformer_engine_impl.recipe
-
-    @recipe.setter
-    def recipe(self, value: Union[Mapping[str, Any], "DelayedScaling"]) -> None:
-        self.transformer_engine_impl.recipe = value
-
-    @property
-    def replace_layers(self) -> bool:
-        return self.transformer_engine_impl.replace_layers
-
-    @replace_layers.setter
-    def replace_layers(self, value: bool) -> None:
-        self.transformer_engine_impl.replace_layers = value
-
-    @property
-    def fallback_compute_dtype(self) -> torch.dtype:
-        return self.transformer_engine_impl.fallback_compute_dtype
-
-    @fallback_compute_dtype.setter
-    def fallback_compute_dtype(self, value: torch.dtype) -> None:
-        self.transformer_engine_impl.fallback_compute_dtype = value
+        if not _TRANSFORMER_ENGINE_AVAILABLE:
+            raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
+        from transformer_engine.common.recipe import DelayedScaling
+
+        if recipe is None:
+            recipe = DelayedScaling()
+        elif isinstance(recipe, Mapping):
+            recipe = dict(recipe)  # copy
+            if "fp8_format" in recipe:
+                from transformer_engine.common.recipe import Format
+
+                recipe["fp8_format"] = getattr(Format, recipe["fp8_format"])
+            recipe = DelayedScaling(**recipe)
+
+        self.weights_dtype = weights_dtype
+        self.recipe = recipe
+        self.replace_layers = replace_layers
+        self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype

     @override
     def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
-        return self.transformer_engine_impl.convert_module(module)
+        # avoid converting if any is found; assume the user took care of it
+        if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
+            if self.replace_layers is True:
+                # info level because this is expected with `init_module`
+                rank_zero_info(
+                    "`TransformerEnginePrecision(replace_layers=True)` is set but the model already contains"
+                    " TransformerEngine layers. Skipping"
+                )
+        elif self.replace_layers in (None, True):
+            _convert_layers(module)
+        module = module.to(dtype=self.weights_dtype)
+        return module

     @override
     def tensor_init_context(self) -> AbstractContextManager:
-        return self.transformer_engine_impl.tensor_init_context()
+        return _DtypeContextManager(self.weights_dtype)

     @override
     def module_init_context(self) -> AbstractContextManager:
-        return self.transformer_engine_impl.module_init_context()
+        dtype_ctx = self.tensor_init_context()
+        stack = ExitStack()
+        if self.replace_layers:
+            import transformer_engine.pytorch as te
+
+            context_manager = _ClassReplacementContextManager({
+                "torch.nn.Linear": te.Linear,
+                "torch.nn.LayerNorm": te.LayerNorm,
+            })
+            stack.enter_context(context_manager)
+        stack.enter_context(dtype_ctx)
+        return stack

     @override
     def forward_context(self) -> AbstractContextManager:
-        return self.transformer_engine_impl.forward_context()
+        dtype_ctx = _DtypeContextManager(self.weights_dtype)
+        fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
+        import transformer_engine.pytorch as te
+
+        autocast_ctx = te.fp8_autocast(enabled=True, fp8_recipe=self.recipe)
+        stack = ExitStack()
+        stack.enter_context(dtype_ctx)
+        # enable an outer fallback autocast for operations that do not support fp8
+        stack.enter_context(fallback_autocast_ctx)
+        stack.enter_context(autocast_ctx)
+        return stack

     @override
     def convert_input(self, data: Any) -> Any:
-        return self.transformer_engine_impl.convert_input(data)
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)

     @override
     def convert_output(self, data: Any) -> Any:
-        return self.transformer_engine_impl.convert_output(data)
-
-    @property
-    def _desired_dtype(self) -> torch.dtype:
-        return self.transformer_engine_impl._desired_dtype
-
-    @_desired_dtype.setter
-    def _desired_dtype(self, dtype: torch.dtype) -> None:
-        self.transformer_engine_impl._desired_dtype = dtype
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
+
+
+def _convert_layers(module: torch.nn.Module) -> None:
+    import transformer_engine.pytorch as te
+
+    for name, child in module.named_children():
+        if isinstance(child, torch.nn.Linear):
+            if child.in_features % 8 != 0 or child.out_features % 16 != 0:
+                # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting
+                rank_zero_warn(
+                    "Support for FP8 in the linear layers with this plugin is currently limited to"
+                    " tensors with shapes where the dimensions are divisible by 8 and 16 respectively."
+                    f" The layer {name!r} does not fit this criterion. You might want to add padding to your inputs."
+                )
+                continue
+            has_bias = child.bias is not None
+            replacement = te.Linear(child.in_features, child.out_features, bias=has_bias)
+            replacement.weight.data = child.weight.data.clone()
+            if has_bias:
+                replacement.bias.data = child.bias.data.clone()
+            log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
+            module.__setattr__(name, replacement)
+        elif isinstance(child, torch.nn.LayerNorm):
+            replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
+            replacement.weight.data = child.weight.data.clone()
+            # Check if bias exists before attempting to clone its data
+            if child.bias is not None and replacement.bias is not None:
+                replacement.bias.data = child.bias.data.clone()
+            log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
+            module.__setattr__(name, replacement)
+        else:
+            # there are other transformer engine layers that we could convert but require fusion. full list at:
+            # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html
+            _convert_layers(child)
+                )
+                continue
+            has_bias = child.bias is not None
+            replacement = te.Linear(child.in_features, child.out_features, bias=has_bias)
+            replacement.weight.data = child.weight.data.clone()
+            if has_bias:
+                replacement.bias.data = child.bias.data.clone()
+            log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
+            module.__setattr__(name, replacement)
+        elif isinstance(child, torch.nn.LayerNorm):
+            replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
+            replacement.weight.data = child.weight.data.clone()
+            # Check if bias exists before attempting to clone its data
+            if child.bias is not None and replacement.bias is not None:
+                replacement.bias.data = child.bias.data.clone()
+            log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
+            module.__setattr__(name, replacement)
+        else:
+            # there are other transformer engine layers that we could convert but require fusion. full list at:
+            # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html
+            _convert_layers(child)
diff --git a/src/lightning/fabric/plugins/precision/xla.py b/src/lightning/fabric/plugins/precision/xla.py
index 0946f394ca954..fdb30032b3cdd 100644
--- a/src/lightning/fabric/plugins/precision/xla.py
+++ b/src/lightning/fabric/plugins/precision/xla.py
@@ -11,13 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from typing import Any, Literal

 import torch
-from typing_extensions import override
+from typing_extensions import get_args, override

-from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, Precision
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
+from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.utilities.types import Optimizable

 _PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true"]
@@ -36,11 +37,24 @@ class XLAPrecision(Precision):

     """

     def __init__(self, precision: _PRECISION_INPUT) -> None:
-        super().__init__()
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.precision.xla import XLAPrecision as EnterpriseXLAPrecision
-
-        self.xla_impl = EnterpriseXLAPrecision(precision=precision)
+        if not _XLA_AVAILABLE:
+            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+        supported_precision = get_args(_PRECISION_INPUT)
+        if precision not in supported_precision:
+            raise ValueError(
+                f"`precision={precision!r}` is not supported in XLA."
+                f" `precision` must be one of: {supported_precision}."
+            )
+        self.precision = precision
+
+        if precision == "16-true":
+            os.environ["XLA_USE_F16"] = "1"
+            self._desired_dtype = torch.float16
+        elif precision == "bf16-true":
+            os.environ["XLA_USE_BF16"] = "1"
+            self._desired_dtype = torch.bfloat16
+        else:
+            self._desired_dtype = torch.float32

     @override
     def optimizer_step(
@@ -48,24 +62,12 @@ def optimizer_step(
         optimizer: Optimizable,
         **kwargs: Any,
     ) -> Any:
-        return self.xla_impl.optimizer_step(optimizer, **kwargs)
+        import torch_xla.core.xla_model as xm
+
+        # you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True`
+        return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True)

     @override
     def teardown(self) -> None:
-        return self.xla_impl.teardown()
-
-    @property
-    def _desired_dtype(self) -> torch.dtype:
-        return self.xla_impl._desired_dtype
-
-    @_desired_dtype.setter
-    def _desired_dtype(self, dtype: torch.dtype) -> None:
-        self.xla_impl._desired_dtype = dtype
-
-    @property
-    def precision(self) -> _PRECISION_INPUT_STR:
-        return self.xla_impl.precision
-
-    @precision.setter
-    def precision(self, precision: _PRECISION_INPUT_STR) -> None:
-        self.xla_impl.precision = precision
+        os.environ.pop("XLA_USE_BF16", None)
+        os.environ.pop("XLA_USE_F16", None)
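A small usage sketch of the environment-variable lifecycle implemented above (not part of the patch; assumes `torch_xla` is installed so that constructing the plugin does not raise):

    import os
    from lightning.fabric.plugins.precision.xla import XLAPrecision

    plugin = XLAPrecision(precision="bf16-true")
    assert os.environ["XLA_USE_BF16"] == "1"  # set eagerly in __init__
    # ... run the training loop; XLA lowers float32 ops to bfloat16 ...
    plugin.teardown()
    assert "XLA_USE_BF16" not in os.environ  # popped so later runs are unaffected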
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 93d05016868dd..883546fea1f2d 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -11,30 +11,45 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import argparse
+import json
 import logging
-from contextlib import AbstractContextManager
+import os
+import platform
+from collections.abc import Mapping
+from contextlib import AbstractContextManager, ExitStack
 from datetime import timedelta
+from itertools import chain
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import torch
+from lightning_utilities.core.imports import RequirementCache
 from torch.nn import Module
 from torch.optim import Optimizer
 from typing_extensions import override

-from lightning.fabric.accelerators import Accelerator
+from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
 from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies.ddp import DDPStrategy
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _Sharded
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.utilities.distributed import log
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
+from lightning.fabric.utilities.load import _move_state_into
+from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
+from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH

 if TYPE_CHECKING:
     from deepspeed import DeepSpeedEngine
     from torch.optim.lr_scheduler import _LRScheduler

+_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
+_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")
+
 # TODO(fabric): Links in the docstrings to
PL-specific deepspeed user docs need to be replaced. class DeepSpeedStrategy(DDPStrategy, _Sharded): @@ -220,11 +235,24 @@ def __init__( exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints. """ + if not _DEEPSPEED_AVAILABLE: + raise ImportError( + "To use the `DeepSpeedStrategy`, you must have DeepSpeed installed." + " Install it by running `pip install -U deepspeed`." + ) - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.strategies.deepspeed import ( - DeepSpeedStrategyFabric as EnterpriseDeepSpeedStrategy, - ) + if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16: + # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints. + # DeepSpeed added support for this behavior in version 0.16.0. + import deepspeed + + deepspeed_version = deepspeed.__version__ + + raise ImportError( + f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. " + f"Detected DeepSpeed version: {deepspeed_version}. " + "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`." + ) super().__init__( accelerator=accelerator, @@ -233,68 +261,77 @@ def __init__( precision=precision, process_group_backend=process_group_backend, ) - self.deepspeed_impl = EnterpriseDeepSpeedStrategy( - outer_object=self, - accelerator=accelerator, - zero_optimization=zero_optimization, - stage=stage, - remote_device=remote_device, - offload_optimizer=offload_optimizer, - offload_parameters=offload_parameters, - offload_params_device=offload_params_device, - nvme_path=nvme_path, - params_buffer_count=params_buffer_count, - params_buffer_size=params_buffer_size, - max_in_cpu=max_in_cpu, - offload_optimizer_device=offload_optimizer_device, - optimizer_buffer_count=optimizer_buffer_count, - block_size=block_size, - queue_depth=queue_depth, - single_submit=single_submit, - overlap_events=overlap_events, - thread_count=thread_count, - pin_memory=pin_memory, - sub_group_size=sub_group_size, - contiguous_gradients=contiguous_gradients, - overlap_comm=overlap_comm, - allgather_partitions=allgather_partitions, - reduce_scatter=reduce_scatter, - allgather_bucket_size=allgather_bucket_size, - reduce_bucket_size=reduce_bucket_size, - zero_allow_untested_optimizer=zero_allow_untested_optimizer, - logging_batch_size_per_gpu=logging_batch_size_per_gpu, - config=config, - logging_level=logging_level, - parallel_devices=parallel_devices, - cluster_environment=cluster_environment, - loss_scale=loss_scale, - initial_scale_power=initial_scale_power, - loss_scale_window=loss_scale_window, - hysteresis=hysteresis, - min_loss_scale=min_loss_scale, - partition_activations=partition_activations, - cpu_checkpointing=cpu_checkpointing, - contiguous_memory_optimization=contiguous_memory_optimization, - synchronize_checkpoint_boundary=synchronize_checkpoint_boundary, - load_full_weights=load_full_weights, - precision=precision, - process_group_backend=process_group_backend, - timeout=timeout, - exclude_frozen_parameters=exclude_frozen_parameters, - ) + self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally + self._timeout: Optional[timedelta] = timeout + + self.config = self._load_config(config) + if self.config is None: + # User has not overridden config, set defaults + self.config = self._create_default_config( + zero_optimization, + zero_allow_untested_optimizer, + logging_batch_size_per_gpu, + offload_optimizer=offload_optimizer, + offload_parameters=offload_parameters, + nvme_path=nvme_path, + 
offload_params_device=offload_params_device, + params_buffer_count=params_buffer_count, + params_buffer_size=params_buffer_size, + max_in_cpu=max_in_cpu, + pin_memory=pin_memory, + offload_optimizer_device=offload_optimizer_device, + optimizer_buffer_count=optimizer_buffer_count, + block_size=block_size, + queue_depth=queue_depth, + single_submit=single_submit, + overlap_events=overlap_events, + thread_count=thread_count, + partition_activations=partition_activations, + cpu_checkpointing=cpu_checkpointing, + contiguous_memory_optimization=contiguous_memory_optimization, + synchronize_checkpoint_boundary=synchronize_checkpoint_boundary, + stage=stage, + contiguous_gradients=contiguous_gradients, + overlap_comm=overlap_comm, + allgather_partitions=allgather_partitions, + reduce_scatter=reduce_scatter, + allgather_bucket_size=allgather_bucket_size, + reduce_bucket_size=reduce_bucket_size, + sub_group_size=sub_group_size, + ) + + import deepspeed + + self._config_initialized = False + deepspeed.utils.logging.logger.setLevel(logging_level) + + self.remote_device = remote_device + self.load_full_weights = load_full_weights + self.exclude_frozen_parameters = exclude_frozen_parameters + + # default FP16 parameters. + self.loss_scale = loss_scale + self.initial_scale_power = initial_scale_power + self.loss_scale_window = loss_scale_window + self.hysteresis = hysteresis + self.min_loss_scale = min_loss_scale + + self._deepspeed_engine: Optional[DeepSpeedEngine] = None @property def zero_stage_3(self) -> bool: - return self.deepspeed_impl.zero_stage_3 + assert isinstance(self.config, dict) + zero_optimization = self.config.get("zero_optimization") + return zero_optimization is not None and zero_optimization.get("stage") == 3 @property @override def distributed_sampler_kwargs(self) -> dict[str, int]: - return self.deepspeed_impl.distributed_sampler_kwargs + return {"num_replicas": self.world_size, "rank": self.global_rank} @property def model(self) -> "DeepSpeedEngine": - return self.deepspeed_impl._deepspeed_engine + return self._deepspeed_engine @override def setup_module_and_optimizers( @@ -308,9 +345,14 @@ def setup_module_and_optimizers( deepspeed optimizer, and an optional learning rate scheduler. """ - return self.deepspeed_impl.setup_module_and_optimizers( - module=module, optimizers=optimizers, scheduler=scheduler - ) + if len(optimizers) != 1: + raise ValueError( + f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead." + ) + + self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler) + self._set_deepspeed_activation_checkpointing() + return self._deepspeed_engine, [optimizer], scheduler @override def setup_module(self, module: Module) -> "DeepSpeedEngine": @@ -319,7 +361,8 @@ def setup_module(self, module: Module) -> "DeepSpeedEngine": For training, see :meth:`setup_module_and_optimizers`. """ - return self.deepspeed_impl.setup_module(module=module) + self._deepspeed_engine, _, _ = self._initialize_engine(module) + return self._deepspeed_engine @override def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: @@ -328,15 +371,34 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer together. 
""" - return self.deepspeed_impl.setup_optimizer(optimizer=optimizer) + raise NotImplementedError(self._err_msg_joint_setup_required()) @override def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: - return self.deepspeed_impl.module_init_context(empty_init=empty_init) + if self.zero_stage_3 and empty_init is False: + raise NotImplementedError( + f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled." + ) + module_sharded_ctx = self.module_sharded_context() + stack = ExitStack() + if not self.zero_stage_3: + stack.enter_context(super().module_init_context(empty_init=empty_init)) + stack.enter_context(module_sharded_ctx) + return stack @override def module_sharded_context(self) -> AbstractContextManager: - return self.deepspeed_impl.module_sharded_context() + # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context + # manager. Later modifications through e.g. `Fabric.setup()` won't have an effect here. + + import deepspeed + + assert self._config_initialized + return deepspeed.zero.Init( + enabled=self.zero_stage_3, + remote_device=self.remote_device, + config_dict_or_path=self.config, + ) @override def save_checkpoint( @@ -363,8 +425,46 @@ def save_checkpoint( :class:`deepspeed.DeepSpeedEngine` objects were found. """ - return self.deepspeed_impl.save_checkpoint( - path=path, state=state, storage_options=storage_options, filter=filter + if storage_options is not None: + raise TypeError( + "`DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not supported because" + " `DeepSpeedStrategy` does not use the `CheckpointIO`." + ) + if filter is not None: + raise TypeError( + "`DeepSpeedStrategy.save_checkpoint(..., filter=...)` is not supported because" + " `DeepSpeedStrategy` manages the state serialization internally." + ) + + engines = _get_deepspeed_engines_from_state(state) + if len(engines) == 0: + raise ValueError( + "Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as" + " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure" + " you set up the model (and optimizers if any) through the strategy before saving the checkpoint." + ) + if len(engines) > 1: + raise ValueError( + "Found multiple DeepSpeed engine modules in the given state. Saving checkpoints with DeepSpeed is" + " currently limited to a single model per checkpoint. To save multiple models, call the" + " save method for each model separately with a different path." 
+            )
+        engine = engines[0]
+
+        # broadcast the path from rank 0 to ensure all the states are saved in a common path
+        path = self.broadcast(path)
+
+        # split the checkpoint into two parts:
+        # 1) the deepspeed engine encapsulating both the model and optionally the optimizer(s)
+        # 2) the rest of the user's state, which in deepspeed is called `client state`
+        excluded_objects = (engine, engine.optimizer) if engine.optimizer is not None else (engine,)
+        state = {k: v for k, v in state.items() if v not in excluded_objects}
+        _validate_state_keys(state)
+        # there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
+        state = self._convert_stateful_objects_in_state(state, filter={})
+        # use deepspeed's internal checkpointing function to handle partitioned weights across processes
+        engine.save_checkpoint(
+            path, client_state=state, tag="checkpoint", exclude_frozen_parameters=self.exclude_frozen_parameters
         )

     @override
@@ -395,7 +495,59 @@ def load_checkpoint(
             not in the expected DeepSpeed format.

         """
-        return self.deepspeed_impl.load_checkpoint(path=path, state=state, strict=strict)
+        if isinstance(state, (Module, Optimizer)) or self.load_full_weights and self.zero_stage_3:
+            # This code path enables loading a checkpoint from a non-deepspeed checkpoint or from
+            # a consolidated checkpoint
+            path = self.broadcast(path)
+            return super().load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only)
+
+        if not state:
+            raise ValueError(
+                f"Got DeepSpeedStrategy.load_checkpoint(..., state={state!r}) but a state with at least"
+                f" a model instance to reload is required. Pass it in like so:"
+                " DeepSpeedStrategy.load_checkpoint(..., state={'model': model, ...})"
+            )
+        _validate_checkpoint_directory(path)
+
+        engines = _get_deepspeed_engines_from_state(state)
+        if len(engines) == 0:
+            raise ValueError(
+                "Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
+                " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
+                " you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
+            )
+        if len(engines) > 1:
+            raise ValueError(
+                "Found multiple DeepSpeed engine modules in the given state. Saving and loading checkpoints"
+                " with DeepSpeed is currently limited to a single model per checkpoint. To load multiple model"
+                " states, call the load method for each model checkpoint separately."
+            )
+        engine = engines[0]
+
+        from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
+
+        optimizer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())
+
+        torch.cuda.empty_cache()
+        _, client_state = engine.load_checkpoint(
+            path,
+            tag="checkpoint",
+            load_optimizer_states=optimizer_state_requested,
+            load_lr_scheduler_states=False,
+            load_module_strict=strict,
+        )
+
+        if client_state is None:
+            raise RuntimeError(
+                "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
+                " or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`."
+            )
+
+        # `Engine.load_checkpoint` adds useless keys 'optimizer' and 'lr_scheduler' to the client state; remove
+        # them to avoid name collision with user state
+        keys = set(client_state) & set(state) - {"optimizer", "lr_scheduler"}
+        _move_state_into(source=client_state, destination=state, keys=keys)
+        return client_state
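For orientation, a sketch of the resulting checkpointing workflow at the Fabric level (hypothetical model, optimizer, and path; assumes a Fabric instance configured with this strategy):

    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cuda", devices=2, strategy="deepspeed")
    fabric.launch()
    model, optimizer = fabric.setup(model, optimizer)  # `model`/`optimizer` assumed defined

    # saving: DeepSpeed partitions the engine state (model + optimizer) itself;
    # everything else in `state` travels as client state
    state = {"model": model, "optimizer": optimizer, "epoch": 7}
    fabric.save("checkpoints/epoch-7.ckpt", state)

    # loading: engine tensors are restored in place, and the remaining client
    # state entries are moved back into the passed-in dict
    state = {"model": model, "optimizer": optimizer, "epoch": 0}
    fabric.load("checkpoints/epoch-7.ckpt", state)
    assert state["epoch"] == 7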

     @override
     def clip_gradients_norm(
@@ -406,19 +558,19 @@
         norm_type: Union[float, int] = 2.0,
         error_if_nonfinite: bool = True,
     ) -> torch.Tensor:
-        return self.deepspeed_impl.clip_gradients_norm(
-            module=module,
-            optimizer=optimizer,
-            max_norm=max_norm,
-            norm_type=norm_type,
-            error_if_nonfinite=error_if_nonfinite,
+        raise NotImplementedError(
+            "DeepSpeed handles gradient clipping automatically within the optimizer. "
+            "Make sure to set the `gradient_clipping` value in your DeepSpeed config."
         )

     @override
     def clip_gradients_value(
         self, module: "DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int]
     ) -> None:
-        return self.deepspeed_impl.clip_gradients_value(module=module, optimizer=optimizer, clip_val=clip_val)
+        raise NotImplementedError(
+            "DeepSpeed handles gradient clipping automatically within the optimizer. "
+            "Make sure to set the `gradient_clipping` value in your DeepSpeed config."
+        )

     @classmethod
     @override
@@ -461,22 +613,338 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
             offload_optimizer_device="nvme",
         )

+    def _initialize_engine(
+        self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple["DeepSpeedEngine", Optimizer, Any]:
+        """Initialize one model and one optimizer with an optional learning rate scheduler.
+
+        This calls ``deepspeed.initialize`` internally.
+
+        """
+        import deepspeed
+
+        model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+        deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
+            args=argparse.Namespace(device_rank=self.root_device.index),
+            config=self.config,
+            model=model,
+            model_parameters=model_parameters,
+            optimizer=optimizer,
+            lr_scheduler=scheduler,
+            dist_init_required=False,
+        )
+        return deepspeed_engine, deepspeed_optimizer, deepspeed_scheduler

     @override
     def setup_environment(self) -> None:
-        return self.deepspeed_impl.setup_environment()
+        if not isinstance(self.accelerator, CUDAAccelerator):
+            raise RuntimeError(
+                f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
+                " is used."
+ ) + super().setup_environment() @override def _setup_distributed(self) -> None: - return self.deepspeed_impl._setup_distributed() + assert self.parallel_devices is not None + _validate_device_index_selection(self.parallel_devices) + reset_seed() + self._set_world_ranks() + self._init_deepspeed_distributed() + if not self._config_initialized: + self._format_config() + self._config_initialized = True + + def _init_deepspeed_distributed(self) -> None: + import deepspeed + + assert self.cluster_environment is not None + if platform.system() != "Windows": + # do not set env variables on windows, allow deepspeed to control setup + self._set_node_environment_variables() + log.info( + "initializing deepspeed distributed: " + f"GLOBAL_RANK: {self.global_rank}, " + f"MEMBER: {self.global_rank + 1}/{self.world_size}" + ) + self._process_group_backend = self._get_process_group_backend() + deepspeed.init_distributed( + self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout + ) - @property - def config(self) -> dict[str, Any]: - return self.deepspeed_impl.config + def _set_node_environment_variables(self) -> None: + assert self.cluster_environment is not None + os.environ["MASTER_ADDR"] = self.cluster_environment.main_address + os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) + os.environ["RANK"] = str(self.global_rank) + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["LOCAL_RANK"] = str(self.local_rank) + + def _set_deepspeed_activation_checkpointing(self) -> None: + import deepspeed + + assert isinstance(self.config, dict) + if self.config.get("activation_checkpointing"): + checkpoint_config = self.config["activation_checkpointing"] + deepspeed.checkpointing.configure( + mpu_=None, + partition_activations=checkpoint_config.get("partition_activations"), + contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"), + checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"), + profile=checkpoint_config.get("profile"), + ) + + def _format_config(self) -> None: + if self.config is None: + raise ValueError( + "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config." 
+ " See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed" + ) + + self.config.setdefault("train_micro_batch_size_per_gpu", 1) + _format_precision_config( + config=self.config, + precision=self.precision.precision, + loss_scale=self.loss_scale, + loss_scale_window=self.loss_scale_window, + min_loss_scale=self.min_loss_scale, + initial_scale_power=self.initial_scale_power, + hysteresis=self.hysteresis, + ) - @config.setter - def config(self, config: dict[str, Any]) -> None: - self.deepspeed_impl.config = config + def _create_default_config( + self, + zero_optimization: bool, + zero_allow_untested_optimizer: bool, + logging_batch_size_per_gpu: Optional[int], + partition_activations: bool, + cpu_checkpointing: bool, + contiguous_memory_optimization: bool, + synchronize_checkpoint_boundary: bool, + offload_optimizer: bool, + offload_parameters: bool, + nvme_path: str, + offload_params_device: str, + params_buffer_count: int, + params_buffer_size: int, + max_in_cpu: int, + offload_optimizer_device: str, + optimizer_buffer_count: int, + pin_memory: bool, + block_size: int, + queue_depth: int, + single_submit: bool, + overlap_events: bool, + thread_count: int, + **zero_kwargs: Any, + ) -> dict: + cfg = { + "activation_checkpointing": { + "partition_activations": partition_activations, + "cpu_checkpointing": cpu_checkpointing, + "contiguous_memory_optimization": contiguous_memory_optimization, + "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary, + }, + "aio": { + "block_size": block_size, + "queue_depth": queue_depth, + "single_submit": single_submit, + "overlap_events": overlap_events, + "thread_count": thread_count, + }, + } + if zero_optimization: + zero_config = zero_kwargs + + if offload_optimizer: + zero_config["offload_optimizer"] = { + "device": offload_optimizer_device, + "nvme_path": nvme_path, + "buffer_count": optimizer_buffer_count, + "pin_memory": pin_memory, + } + if offload_parameters: + zero_config["offload_param"] = { + "device": offload_params_device, + "nvme_path": nvme_path, + "buffer_count": params_buffer_count, + "buffer_size": params_buffer_size, + "max_in_cpu": max_in_cpu, + "pin_memory": pin_memory, + } + cfg.update({ + "zero_allow_untested_optimizer": zero_allow_untested_optimizer, + "zero_optimization": zero_config, + }) + if logging_batch_size_per_gpu: + cfg["train_micro_batch_size_per_gpu"] = logging_batch_size_per_gpu + return cfg + + def _restore_zero_state(self, module: Module, ckpt: Mapping[str, Any]) -> None: + """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be sharded + across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced + across processes. - @property - def load_full_weights(self) -> bool: - return self.deepspeed_impl.load_full_weights + Args: + ckpt: The ckpt file. 
+ + """ + import deepspeed + + def load(module: torch.nn.Module, prefix: str = "") -> None: + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + error_msgs: list[str] = [] + state_dict = ckpt["state_dict"] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if self.is_global_zero: + module._load_from_state_dict( + state_dict=state_dict, + prefix=prefix, + local_metadata=local_metadata, + strict=True, + missing_keys=missing_keys, + unexpected_keys=unexpected_keys, + error_msgs=error_msgs, + ) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(module, prefix="") + + def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]: + if config is None and self.DEEPSPEED_ENV_VAR in os.environ: + rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") + config = os.environ[self.DEEPSPEED_ENV_VAR] + if isinstance(config, (str, Path)): + if not os.path.isfile(config): + raise FileNotFoundError( + f"You passed in a path to a DeepSpeed config but the path does not exist: {config}" + ) + with open(config) as f: + config = json.load(f) + assert isinstance(config, dict) or config is None + return config + + +def _get_deepspeed_engines_from_state(state: dict[str, Any]) -> list["DeepSpeedEngine"]: + from deepspeed import DeepSpeedEngine + + modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module))) + return [engine for engine in modules if isinstance(engine, DeepSpeedEngine)] + + +def _validate_state_keys(state: dict[str, Any]) -> None: + # DeepSpeed merges the client state into its internal engine state when saving, but it does not check for + # colliding keys from the user. We explicitly check it here: + deepspeed_internal_keys = { + "module", + "buffer_names", + "optimizer", + "param_shapes", + "lr_scheduler", + "sparse_tensor_module_names", + "skipped_steps", + "global_steps", + "global_samples", + "dp_world_size", + "mp_world_size", + "ds_config", + "ds_version", + } + colliding_keys = deepspeed_internal_keys.intersection(state.keys()) + if colliding_keys: + rank_zero_warn( + "Your state has keys that collide with DeepSpeed's internal engine state. This could result in your" + " values being overwritten by DeepSpeed. Consider changing the name of these keys to something else: " + + ", ".join(colliding_keys) + ) + + +def _validate_device_index_selection(parallel_devices: list[torch.device]) -> None: + selected_device_indices = [device.index for device in parallel_devices] + expected_device_indices = list(range(len(parallel_devices))) + if selected_device_indices != expected_device_indices: + raise RuntimeError( + f"The selected device indices {selected_device_indices!r} don't match the local rank values of processes." + " If you need to select GPUs at a specific index, set the `CUDA_VISIBLE_DEVICES` environment variable" + f" instead. 
For example: `CUDA_VISIBLE_DEVICES={','.join(str(i) for i in selected_device_indices)}`." + ) + + +def _is_deepspeed_checkpoint(path: Path) -> bool: + """Heuristic check whether the path points to a top-level DeepSpeed checkpoint directory.""" + return path.is_dir() and (path / "checkpoint").is_dir() + + +def _validate_checkpoint_directory(path: _PATH) -> None: + """Validates that the path points to a DeepSpeed checkpoint directory and suggests fixes for user error.""" + # Example DeepSpeed checkpoint directory: + # + # epoch=5-step=10999.ckpt + # ├── checkpoint + # │ ├── zero_pp_rank_0_mp_rank_00_model_states.pt + # │ ├── zero_pp_rank_0_mp_rank_00_optim_states.pt + # │ ├── zero_pp_rank_1_mp_rank_00_model_states.pt + # │ └── zero_pp_rank_1_mp_rank_00_optim_states.pt + # ├── latest + # └── zero_to_fp32.py + + path = Path(path) + path_is_ds_checkpoint = _is_deepspeed_checkpoint(path) + default_message = f"The provided path is not a valid DeepSpeed checkpoint: {path}" + + if not path_is_ds_checkpoint: + # Case 1: User may have accidentally passed the subfolder "checkpoint" + parent_is_ds_checkpoint = _is_deepspeed_checkpoint(path.parent) + if parent_is_ds_checkpoint: + raise FileNotFoundError( + f"{default_message}. It looks like you passed the path to a subfolder." + f" Try to load using this parent directory instead: {path.parent}" + ) + # Case 2: User may have accidentally passed the path to a file inside the "checkpoint" subfolder + parent_parent_is_ds_checkpoint = path.is_file() and _is_deepspeed_checkpoint(path.parent.parent) + if parent_parent_is_ds_checkpoint: + raise FileNotFoundError( + f"{default_message}. It looks like you passed the path to a file inside a DeepSpeed checkpoint folder." + f" Try to load using this parent directory instead: {path.parent.parent}" + ) + raise FileNotFoundError(default_message) + + +def _format_precision_config( + config: dict[str, Any], + precision: str, + loss_scale: float, + loss_scale_window: int, + min_loss_scale: int, + initial_scale_power: int, + hysteresis: int, +) -> None: + if "fp16" not in config and precision in ("16-mixed", "16-true"): + # FP16 is a DeepSpeed standalone AMP implementation + rank_zero_info("Enabling DeepSpeed FP16. Model parameters and inputs will be cast to `float16`.") + config["fp16"] = { + "enabled": True, + "loss_scale": loss_scale, + "initial_scale_power": initial_scale_power, + "loss_scale_window": loss_scale_window, + "hysteresis": hysteresis, + "min_loss_scale": min_loss_scale, + } + elif "bf16" not in config and precision in ("bf16-mixed", "bf16-true"): + rank_zero_info("Enabling DeepSpeed BF16. Model parameters and inputs will be cast to `bfloat16`.") + config["bf16"] = {"enabled": True} diff --git a/src/lightning/fabric/strategies/launchers/xla.py b/src/lightning/fabric/strategies/launchers/xla.py index 566312754339b..639de55805646 100644 --- a/src/lightning/fabric/strategies/launchers/xla.py +++ b/src/lightning/fabric/strategies/launchers/xla.py @@ -11,12 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
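To make `_format_precision_config` above concrete, a short sketch of the mutation it performs (argument values are arbitrary examples, not mandated defaults; the helper is imported from `lightning.fabric.strategies.deepspeed` as patched here):

    from lightning.fabric.strategies.deepspeed import _format_precision_config

    config: dict = {"train_micro_batch_size_per_gpu": 1}
    _format_precision_config(
        config=config,
        precision="16-mixed",
        loss_scale=0,  # 0 selects DeepSpeed's dynamic loss scaling
        loss_scale_window=1000,
        min_loss_scale=1,
        initial_scale_power=16,
        hysteresis=2,
    )
    assert config["fp16"] == {
        "enabled": True,
        "loss_scale": 0,
        "initial_scale_power": 16,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1,
    }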
-from typing import TYPE_CHECKING, Any, Callable, Union +import queue +import time +from typing import TYPE_CHECKING, Any, Callable, Optional, Union +import torch.multiprocessing as mp from typing_extensions import override +from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.strategies.launchers.launcher import _Launcher -from lightning.fabric.utilities.imports import _raise_enterprise_not_available +from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot +from lightning.fabric.utilities.apply_func import move_data_to_device if TYPE_CHECKING: from lightning.fabric.strategies import XLAFSDPStrategy, XLAStrategy @@ -40,16 +45,15 @@ class _XLALauncher(_Launcher): """ def __init__(self, strategy: Union["XLAStrategy", "XLAFSDPStrategy"]) -> None: - super().__init__() - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.strategies.xla.launcher import XLALauncherFabric as EnterpriseXLALauncher - - self.xla_impl = EnterpriseXLALauncher(strategy=strategy) + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) + self._strategy = strategy + self._start_method = "fork" @property @override def is_interactive_compatible(self) -> bool: - return self.xla_impl.is_interactive_compatible + return True @override def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: @@ -64,12 +68,61 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: **kwargs: Optional keyword arguments to be passed to the given function. """ - return self.xla_impl.launch(function=function, *args, **kwargs) - - @property - def _start_method(self) -> str: - return self.xla_impl._start_method - - @_start_method.setter - def _start_method(self, start_method: str) -> None: - self.xla_impl._start_method = start_method + return_queue: Union[queue.Queue, mp.SimpleQueue] + return_queue = mp.Manager().Queue() + + import torch_xla.distributed.xla_multiprocessing as xmp + + spawn_kwargs = {} + nprocs = self._strategy.num_processes + if nprocs == 1: + # avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly. 
+ # otherwise it will use all devices + spawn_kwargs["nprocs"] = nprocs + + xmp.spawn( + self._wrapping_function, + args=(function, args, kwargs, return_queue), + start_method=self._start_method, + **spawn_kwargs, + ) + return return_queue.get() + + def _wrapping_function( + self, + # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing + # https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/distributed/xla_multiprocessing.py#L321 + process_idx: int, + function: Callable, + args: Any, + kwargs: Any, + return_queue: Union[mp.SimpleQueue, queue.Queue], + global_states: Optional[_GlobalStateSnapshot] = None, + ) -> None: + import torch_xla.core.xla_model as xm + + if len(xm.get_xla_supported_devices()) > 1: + # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for v4) + # so when there's more than one (multithreading), objects need to be deep-copied + import copy + + function, args, kwargs = copy.deepcopy((function, args, kwargs)) + + results = function(*args, **kwargs) + + if self._strategy.local_rank == 0: + return_queue.put(move_data_to_device(results, "cpu")) + + _rank_teardown(self._strategy.local_rank) + + +def _rank_teardown(rank: int) -> None: + import torch_xla.core.xla_model as xm + + # Make all processes wait for each other before joining + # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 + xm.rendezvous("end-process") + # Ensure that the rank 0 process is the one exiting last + # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 + if rank == 0: + time.sleep(1) diff --git a/src/lightning/fabric/strategies/single_xla.py b/src/lightning/fabric/strategies/single_xla.py index 59b38ff6157f6..ba2fce91f1146 100644 --- a/src/lightning/fabric/strategies/single_xla.py +++ b/src/lightning/fabric/strategies/single_xla.py @@ -13,14 +13,15 @@ # limitations under the License. from typing import Optional +import torch from typing_extensions import override from lightning.fabric.accelerators import Accelerator +from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision from lightning.fabric.plugins.io.xla import XLACheckpointIO from lightning.fabric.strategies import _StrategyRegistry from lightning.fabric.strategies.single_device import SingleDeviceStrategy -from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import _DEVICE @@ -34,16 +35,20 @@ def __init__( checkpoint_io: Optional[XLACheckpointIO] = None, precision: Optional[XLAPrecision] = None, ): - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.strategies.xla.single import validate_xla_strategy + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) + if isinstance(device, torch.device): + # unwrap the `torch.device` in favor of `xla_device` + device = device.index + + import torch_xla.core.xla_model as xm super().__init__( accelerator=accelerator, - device=device, + device=xm.xla_device(device), checkpoint_io=checkpoint_io, precision=precision, ) - validate_xla_strategy(strategy=self, device=device) @property @override diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py index 5a877a5c0dcc8..3a571fef37f00 100644 --- a/src/lightning/fabric/strategies/xla.py +++ b/src/lightning/fabric/strategies/xla.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
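The `_XLALauncher` above funnels only local-rank-0's result back to the parent process through a managed queue. A condensed, CPU-only sketch of that control flow, substituting plain `torch.multiprocessing` for `xmp.spawn` (illustrative, not the real launcher):

    import torch.multiprocessing as mp

    def _add_rank(rank: int, base: int) -> int:
        return base + rank

    def _worker(rank: int, fn, args, return_queue) -> None:
        result = fn(rank, *args)
        if rank == 0:
            # only rank 0 reports back, mirroring `_wrapping_function` above
            return_queue.put(result)

    def launch(fn, *args, nprocs: int = 2):
        return_queue = mp.Manager().Queue()
        mp.spawn(_worker, args=(fn, args, return_queue), nprocs=nprocs)
        return return_queue.get()

    if __name__ == "__main__":
        print(launch(_add_rank, 40))  # prints rank 0's result: 40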
# See the License for the specific language governing permissions and # limitations under the License. +import io from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch @@ -21,13 +22,14 @@ from typing_extensions import override from lightning.fabric.accelerators import Accelerator +from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1 from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision from lightning.fabric.plugins.environments import XLAEnvironment from lightning.fabric.plugins.io.xla import XLACheckpointIO from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry from lightning.fabric.strategies.launchers.xla import _XLALauncher from lightning.fabric.strategies.strategy import TBroadcast -from lightning.fabric.utilities.imports import _raise_enterprise_not_available +from lightning.fabric.utilities.rank_zero import rank_zero_only from lightning.fabric.utilities.types import _PATH, ReduceOp if TYPE_CHECKING: @@ -53,19 +55,22 @@ def __init__( checkpoint_io=checkpoint_io, precision=precision, ) - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.strategies.xla.ddp import XLAStrategyFabric as EnterpriseXLAStrategy - - self.xla_strategy_impl = EnterpriseXLAStrategy(outer_object=self, sync_module_states=sync_module_states) + self._backward_sync_control = None # XLA synchronizes gradients in the optimizer.step() call + self._launched = False + self._sync_module_states = sync_module_states @property @override def root_device(self) -> torch.device: - return self.xla_strategy_impl.root_device + if not self._launched: + raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.") + import torch_xla.core.xla_model as xm + + return xm.xla_device() @property def num_processes(self) -> int: - return self.xla_strategy_impl.num_processes + return len(self.parallel_devices) if self.parallel_devices is not None else 0 @property @override @@ -102,22 +107,22 @@ def precision(self, precision: Optional[Precision]) -> None: @property @override def global_rank(self) -> int: - return self.xla_strategy_impl.global_rank + return super().global_rank if self._launched else 0 @property @override def local_rank(self) -> int: - return self.xla_strategy_impl.local_rank + return super().local_rank if self._launched else 0 @property @override def node_rank(self) -> int: - return self.xla_strategy_impl.node_rank + return super().node_rank if self._launched else 0 @property @override def world_size(self) -> int: - return self.xla_strategy_impl.world_size + return super().world_size if self._launched else 1 @override def _configure_launcher(self) -> None: @@ -125,19 +130,48 @@ def _configure_launcher(self) -> None: @override def setup_environment(self) -> None: - return self.xla_strategy_impl.setup_environment() + assert self.parallel_devices is not None + if len(self.parallel_devices) == 1: + # spawning only 1 device with PjRT is not supported: + # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732 + raise NotImplementedError( + f"The {type(self).__name__} does not support running on a single device with the PjRT runtime." 
+ " Try using all devices or the `SingleDeviceXLAStrategy` strategy" + ) + + self._launched = True + rank_zero_only.rank = self.global_rank + super().setup_environment() @override def setup_module(self, module: Module) -> Module: - return self.xla_strategy_impl.setup_module(module=module) + if self._sync_module_states: + if _XLA_GREATER_EQUAL_2_1: + from torch_xla.core.xla_model import broadcast_master_param + else: + from torch_xla.experimental.pjrt import broadcast_master_param + + broadcast_master_param(module) + + return module @override def module_to_device(self, module: Module) -> None: - return self.xla_strategy_impl.module_to_device(module=module) + module.to(self.root_device) @override def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader": - return self.xla_strategy_impl.process_dataloader(dataloader=dataloader) + from torch_xla.distributed.parallel_loader import MpDeviceLoader + + if isinstance(dataloader, MpDeviceLoader): + # dataloader is already wrapped by MpDeviceLoader + return dataloader + + dataloader = MpDeviceLoader(dataloader, self.root_device) + # Mimic interface to torch.utils.data.DataLoader + dataloader.dataset = dataloader._loader.dataset + dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None) + return dataloader @override def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: @@ -151,21 +185,92 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo A tensor of shape (world_size, ...) """ - return self.xla_strategy_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads) + if not self._launched: + return tensor + if not isinstance(tensor, Tensor): + raise NotImplementedError( + f"`{type(self).__name__}.all_gather` is only implemented for tensors. 
Given {tensor}" + ) + if tensor.dim() == 0: + tensor = tensor.unsqueeze(0) + original_device = tensor.device + tensor = tensor.to(self.root_device) + + import torch_xla.core.functions as xf + import torch_xla.core.xla_model as xm + + tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor) + tensor = tensor.to(original_device) + return tensor @override def all_reduce( self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None ) -> Tensor: - return self.xla_strategy_impl.all_reduce(output=output, group=group, reduce_op=reduce_op) + if not isinstance(output, Tensor): + output = torch.tensor(output, device=self.root_device) + + invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM + invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") + if invalid_reduce_op or invalid_reduce_op_str: + raise ValueError( + "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:" + f" {reduce_op}" + ) + import torch_xla.core.xla_model as xm + + output = xm.mesh_reduce("reduce", output, sum) + + if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): + output = output / self.world_size + + return output @override def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: - return self.xla_strategy_impl.barrier(name=name, *args, **kwargs) + if not self._launched: + return + import torch_xla.core.xla_model as xm + + if name is None: + # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments" + name = "" + xm.rendezvous(name) @override def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - return self.xla_strategy_impl.broadcast(obj=obj, src=src) + if not self._launched: + return obj + + import torch_xla.core.xla_model as xm + + is_tensor = isinstance(obj, Tensor) + if is_tensor: + if obj.dim() == 0: + obj = obj.unsqueeze(0) + original_device = obj.device + # XLA distributed requires that the data is on the XLA device + obj = obj.to(self.root_device) + else: + # support for arbitrary pickle-ables + buffer = io.BytesIO() + torch.save(obj, buffer) + obj = torch.tensor( # type: ignore[assignment] + bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float + ) + + obj = [obj] + xm.collective_broadcast(obj, root_ordinal=src) + obj = obj[0] + + if not is_tensor: + # this will preserve the dtype and device of any tensors + buffer = io.BytesIO(obj.cpu().byte().numpy()) + obj = torch.load(buffer) + else: + obj = obj.to(original_device) + + return obj @override def save_checkpoint( @@ -186,9 +291,12 @@ def save_checkpoint( boolean indicating whether the given parameter should be saved (``True``) or filtered out (``False``). 
""" - return self.xla_strategy_impl.save_checkpoint( - path=path, state=state, storage_options=storage_options, filter=filter - ) + import torch_xla.core.xla_model as xm + + # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs + xm.mark_step() + # save on global rank zero only + super().save_checkpoint(path, state, storage_options=storage_options, filter=filter) @classmethod @override diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index f54267ca59a72..51b528eff26ff 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import AbstractContextManager +import io +from contextlib import AbstractContextManager, ExitStack, nullcontext +from functools import partial +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union import torch @@ -22,16 +25,22 @@ from typing_extensions import override from lightning.fabric.accelerators import Accelerator +from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision from lightning.fabric.plugins.environments import XLAEnvironment from lightning.fabric.plugins.io.xla import XLACheckpointIO from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry +from lightning.fabric.strategies.fsdp import _apply_filter from lightning.fabric.strategies.launchers.xla import _XLALauncher from lightning.fabric.strategies.strategy import ( TBroadcast, + _BackwardSyncControl, _Sharded, + _validate_keys_for_strict_loading, ) -from lightning.fabric.utilities.imports import _raise_enterprise_not_available +from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.fabric.utilities.init import _EmptyInit +from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn from lightning.fabric.utilities.types import _PATH, Optimizable, ReduceOp if TYPE_CHECKING: @@ -84,6 +93,8 @@ def __init__( sequential_save: bool = False, **kwargs: Any, ) -> None: + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, @@ -91,28 +102,27 @@ def __init__( checkpoint_io=checkpoint_io, precision=precision, ) - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.strategies.xla.fsdp import ( - XLAFSDPStrategyFabric as EnterpriseXLAFSDPStrategy, - ) + self._backward_sync_control = _XLAFSDPBackwardSyncControl() - self.xla_fsdp_impl = EnterpriseXLAFSDPStrategy( - outer_object=self, - auto_wrap_policy=auto_wrap_policy, - activation_checkpointing_policy=activation_checkpointing_policy, - state_dict_type=state_dict_type, - sequential_save=sequential_save, - **kwargs, - ) + self._auto_wrap_policy = auto_wrap_policy + self._activation_checkpointing_policy = activation_checkpointing_policy + self._fsdp_kwargs = kwargs + self._state_dict_type = state_dict_type + self._sequential_save = sequential_save + self._launched = False @property @override def root_device(self) -> torch.device: - return self.xla_fsdp_impl.root_device + if not self._launched: + raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.") + import torch_xla.core.xla_model 
as xm
+
+        return xm.xla_device()

     @property
     def num_processes(self) -> int:
-        return self.xla_fsdp_impl.num_processes
+        return len(self.parallel_devices) if self.parallel_devices is not None else 0

     @property
     @override
@@ -149,22 +159,22 @@ def precision(self, precision: Optional[Precision]) -> None:

     @property
     @override
     def global_rank(self) -> int:
-        return self.xla_fsdp_impl.global_rank
+        return super().global_rank if self._launched else 0

     @property
     @override
     def local_rank(self) -> int:
-        return self.xla_fsdp_impl.local_rank
+        return super().local_rank if self._launched else 0

     @property
     @override
     def node_rank(self) -> int:
-        return self.xla_fsdp_impl.node_rank
+        return super().node_rank if self._launched else 0

     @property
     @override
     def world_size(self) -> int:
-        return self.xla_fsdp_impl.world_size
+        return super().world_size if self._launched else 1

     @override
     def _configure_launcher(self) -> None:
@@ -172,40 +182,108 @@

     @override
     def setup_environment(self) -> None:
-        return self.xla_fsdp_impl.setup_environment()
+        assert self.parallel_devices is not None
+        if len(self.parallel_devices) == 1:
+            # spawning only 1 device with PjRT is not supported:
+            # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732
+            raise NotImplementedError(
+                f"The {type(self).__name__} does not support running on a single device with the PjRT runtime."
+                " Try using all devices or the `SingleDeviceXLAStrategy` strategy."
+            )
+
+        self._launched = True
+        rank_zero_only.rank = self.global_rank
+        super().setup_environment()

     @override
     def setup_module_and_optimizers(
         self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
     ) -> tuple[Module, list[Optimizer], Optional["_LRScheduler"]]:
-        return self.xla_fsdp_impl.setup_module_and_optimizers(module=module, optimizers=optimizers, scheduler=scheduler)
+        """Raises a ``NotImplementedError`` since with XLAFSDP the optimizer setup must happen after the module
+        setup."""
+        raise NotImplementedError(
+            f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."
+            " Please do it in this order: Create the model, call `setup_module`, create the optimizer,"
+            " call `setup_optimizer`."
+        )

     @override
     def setup_module(self, module: Module) -> Module:
-        return self.xla_fsdp_impl.setup_module(module=module)
+        from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
+
+        kwargs = self._parse_fsdp_kwargs()
+        if any(isinstance(mod, XLAFSDP) for mod in module.modules()) and "auto_wrap_policy" in kwargs:
+            rank_zero_warn(
+                "An XLAFSDP `auto_wrap_policy` is set, but at least one submodule is already wrapped."
+                " The policy will be ignored."
+            )
+            del kwargs["auto_wrap_policy"]
+        # XLA FSDP requires that the root is wrapped, even if submodules are already wrapped
+        if not isinstance(module, XLAFSDP):
+            module = XLAFSDP(module=module, **kwargs)
+        return module
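Usage sketch for the wrapping logic in `setup_module` above (hypothetical size-based policy; assumes a TPU environment with `torch_xla`'s FSDP support installed, and that the policy callable follows the usual FSDP auto-wrap convention):

    from lightning.fabric import Fabric
    from lightning.fabric.strategies import XLAFSDPStrategy

    def wrap_large_blocks(module, recurse, nonwrapped_numel) -> bool:
        # hypothetical policy: keep recursing into children, and wrap any
        # remaining block that holds at least one million parameters
        return recurse or nonwrapped_numel >= 1_000_000

    strategy = XLAFSDPStrategy(auto_wrap_policy=wrap_large_blocks)
    fabric = Fabric(accelerator="tpu", strategy=strategy)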

     @override
     def module_to_device(self, module: Module) -> None:
-        return self.xla_fsdp_impl.module_to_device(module=module)
+        pass

     def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
-        return self.xla_fsdp_impl.module_init_context(empty_init=empty_init)
+        precision_init_ctx = self.precision.module_init_context()
+        module_sharded_ctx = self.module_sharded_context()
+        stack = ExitStack()
+        stack.enter_context(_EmptyInit(enabled=bool(empty_init)))
+        stack.enter_context(precision_init_ctx)
+        stack.enter_context(module_sharded_ctx)
+        return stack

     @override
     def module_sharded_context(self) -> AbstractContextManager:
-        return self.xla_fsdp_impl.module_sharded_context()
+        return nullcontext()

     @override
     def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
-        return self.xla_fsdp_impl.process_dataloader(dataloader=dataloader)
+        from torch_xla.distributed.parallel_loader import MpDeviceLoader
+
+        if isinstance(dataloader, MpDeviceLoader):
+            # dataloader is already wrapped by MpDeviceLoader
+            return dataloader
+
+        dataloader = MpDeviceLoader(dataloader, self.root_device)
+        # Mimic interface to torch.utils.data.DataLoader
+        dataloader.dataset = dataloader._loader.dataset
+        dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
+        return dataloader

     @override
     def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
-        return self.xla_fsdp_impl.setup_optimizer(optimizer=optimizer)
+        """Set up an optimizer for a model wrapped with XLAFSDP.
+
+        This setup method doesn't modify or wrap the optimizer. The only thing it currently does is verify
+        that the optimizer was created after the model was wrapped with :meth:`setup_module` with a reference
+        to the flattened parameters.
+
+        """
+        if any(getattr(p, "_is_sharded", False) for group in optimizer.param_groups for p in group["params"]):
+            return optimizer
+        raise ValueError(
+            "The optimizer does not seem to reference any XLAFSDP parameters. HINT: Make sure to create the optimizer"
+            " after setting up the model."
+        )

     @override
     def optimizer_step(self, optimizer: Optimizable, **kwargs: Any) -> Any:
-        return self.xla_fsdp_impl.optimizer_step(optimizer=optimizer, **kwargs)
+        """Overrides the default TPU `optimizer_step` since FSDP should not call
+        `torch_xla.core.xla_model.optimizer_step`. Performs the actual optimizer step.
+
+        Args:
+            optimizer: the optimizer performing the step
+            **kwargs: Any extra arguments to ``optimizer.step``
+
+        """
+        loss = optimizer.step(**kwargs)
+        import torch_xla.core.xla_model as xm
+
+        xm.mark_step()
+        return loss
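The `xm.mark_step()` call above is what cuts the lazily traced XLA graph and dispatches it for compilation and execution; until then, `forward`, `backward`, and `optimizer.step` only record operations. A schematic step under those semantics (assumes it runs inside a spawned TPU process):

    import torch_xla.core.xla_model as xm

    def training_step(model, optimizer, batch):
        optimizer.zero_grad()
        loss = model(batch).sum()  # traced lazily, nothing executes yet
        loss.backward()            # still lazy
        optimizer.step()           # still lazy
        xm.mark_step()             # materialize: compile and run the recorded graph
        return loss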
+
+        Args:
+            optimizer: the optimizer performing the step
+            **kwargs: Any extra arguments to ``optimizer.step``
+
+        """
+        loss = optimizer.step(**kwargs)
+        import torch_xla.core.xla_model as xm
+
+        xm.mark_step()
+        return loss

    @override
    def clip_gradients_norm(
@@ -217,18 +295,17 @@ def clip_gradients_norm(
        error_if_nonfinite: bool = True,
    ) -> Tensor:
        """Clip gradients by norm."""
-        return self.xla_fsdp_impl.clip_gradients_norm(
-            module=module,
-            optimizer=optimizer,
-            max_norm=max_norm,
-            norm_type=norm_type,
-            error_if_nonfinite=error_if_nonfinite,
-        )
+        self.precision.unscale_gradients(optimizer)
+        assert callable(module.clip_grad_norm_)
+        return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type)

    @override
    def clip_gradients_value(self, module: Module, optimizer: Optimizer, clip_val: Union[float, int]) -> None:
        """Clip gradients by value."""
-        return self.xla_fsdp_impl.clip_gradients_value(module=module, optimizer=optimizer, clip_val=clip_val)
+        raise NotImplementedError(
+            "XLA's FSDP strategy does not support clipping gradients by value."
+            " Consider clipping by norm instead or choose another strategy!"
+        )

    @override
    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -242,21 +319,92 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
            A tensor of shape (world_size, ...)

        """
-        return self.xla_fsdp_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads)
+        if not self._launched:
+            return tensor
+        if not isinstance(tensor, Tensor):
+            raise NotImplementedError(
+                f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
+            )
+        if tensor.dim() == 0:
+            tensor = tensor.unsqueeze(0)
+        original_device = tensor.device
+        tensor = tensor.to(self.root_device)
+
+        import torch_xla.core.functions as xf
+        import torch_xla.core.xla_model as xm
+
+        tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+        tensor = tensor.to(original_device)
+        return tensor

    @override
    def all_reduce(
        self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
    ) -> Tensor:
-        return self.xla_fsdp_impl.all_reduce(output=output, group=group, reduce_op=reduce_op)
+        if not isinstance(output, Tensor):
+            output = torch.tensor(output, device=self.root_device)
+
+        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
+        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
+        if invalid_reduce_op or invalid_reduce_op_str:
+            raise ValueError(
+                "Currently, the XLAFSDPStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
+                f" {reduce_op}"
+            )
+        import torch_xla.core.xla_model as xm
+
+        output = xm.mesh_reduce("reduce", output, sum)
+
+        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
+            output = output / self.world_size
+
+        return output

    @override
    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
-        return self.xla_fsdp_impl.barrier(name=name, *args, **kwargs)
+        if not self._launched:
+            return
+        import torch_xla.core.xla_model as xm
+
+        if name is None:
+            # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
+            name = ""
+        xm.rendezvous(name)

    @override
    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        return self.xla_fsdp_impl.broadcast(obj=obj, src=src)
+        if not self._launched:
+            return obj
+
+        import torch_xla.core.xla_model as xm
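+        # (added editorial note) two broadcast paths follow: tensors are moved to the
+        # XLA device and broadcast directly, while arbitrary picklable objects are first
+        # serialized into a byte tensor, broadcast, and then deserialized on each rank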
+
+        is_tensor = isinstance(obj, Tensor)
+        if is_tensor:
+            if obj.dim() == 0:
+                obj = obj.unsqueeze(0)
+            original_device = obj.device
+            # XLA distributed requires that the data is on the XLA device
+            obj = obj.to(self.root_device)
+        else:
+            # support for arbitrary pickle-ables
+            buffer = io.BytesIO()
+            torch.save(obj, buffer)
+            obj = torch.tensor(  # type: ignore[assignment]
+                bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
+            )
+
+        obj = [obj]
+        xm.collective_broadcast(obj, root_ordinal=src)
+        obj = obj[0]
+
+        if not is_tensor:
+            # this will preserve the dtype and device of any tensors
+            buffer = io.BytesIO(obj.cpu().byte().numpy())
+            obj = torch.load(buffer)
+        else:
+            obj = obj.to(original_device)
+
+        return obj

    @override
    def save_checkpoint(
@@ -273,8 +421,93 @@ def save_checkpoint(
        consolidated checkpoint combining all of the sharded checkpoints.

        """
-        return self.xla_fsdp_impl.save_checkpoint(
-            path=path, state=state, storage_options=storage_options, filter=filter
+        # broadcast the path from rank 0 to ensure all the states are saved in a common path
+        path = Path(self.broadcast(path))
+        if path.is_dir() and any(path.iterdir()):
+            raise FileExistsError(f"The checkpoint directory already exists and is not empty: {path}")
+        from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
+
+        modules = [module for module in state.values() if isinstance(module, XLAFSDP)]
+        if len(modules) == 0:
+            raise ValueError(
+                "Could not find an XLAFSDP model in the provided checkpoint state. Please provide the model as"
+                " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
+                " you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
+            )
+        if len(modules) > 1:
+            raise ValueError(
+                "Found multiple XLAFSDP modules in the given state. Saving checkpoints with FSDP is"
+                " currently limited to a single model per checkpoint. To save multiple models, call the"
+                " save method for each model separately with a different path."
+            )
+        import torch_xla.core.xla_model as xm
+
+        # ensure model parameters are updated
+        xm.mark_step()
+
+        parallel_devices = self.parallel_devices
+        assert parallel_devices is not None
+        if self._sequential_save:
+            # each host runs this in parallel, but the ranks in the host run it sequentially
+            for rank in range(len(parallel_devices)):
+                if rank == self.local_rank:
+                    self._save_checkpoint_shard(path, state, storage_options, filter)
+                self.barrier(f"wait-for-{rank}-save")
+        else:
+            self._save_checkpoint_shard(path, state, storage_options, filter)
+
+        if self._state_dict_type == "full":
+            ckpt_prefix = str(path / "checkpoint")
+            ckpt_suffix = "_rank-*-of-*.pth"
+            if len(parallel_devices) != self.world_size:  # multihost
+                raise OSError(
+                    "Multihost setups do not have a shared filesystem, so the checkpoint shards cannot be consolidated"
+                    " into a single checkpoint after saving them. Please switch to"
+                    " `XLAFSDPStrategy(state_dict_type='sharded')`. TIP: You can consolidate them manually by getting"
+                    " them together into a single directory and running `python -m"
+                    f" torch_xla.distributed.fsdp.consolidate_sharded_ckpts --ckpt_prefix {ckpt_prefix!r} --ckpt_suffix"
+                    f" {ckpt_suffix!r} --save_path 'path/to/consolidated.ckpt'`."
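+                    # (added editorial note) `consolidate_sharded_ckpts` ships with torch_xla, and the
+                    # manual consolidation requires all shard files to be visible on a single filesystem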
+                )
+
+            from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+
+            self.barrier("before_ckpt_consolidation")
+            if self.is_global_zero:
+                save_path = path.parent / "consolidated.ckpt"
+                # save the consolidated checkpoint separately from the shards
+                consolidate_sharded_model_checkpoints(ckpt_prefix, ckpt_suffix, str(save_path))
+                # remove the shards directory
+                self.checkpoint_io.remove_checkpoint(path)
+                # move the consolidated checkpoint to where the user would expect it
+                get_filesystem(save_path).mv(str(save_path), str(path))
+            self.barrier("after_ckpt_consolidation")
+
+    def _save_checkpoint_shard(
+        self,
+        path: Path,
+        state: dict[str, Union[Module, Optimizer, Any]],
+        storage_options: Optional[Any],
+        filter: Optional[dict[str, Callable[[str, Any], bool]]],
+    ) -> None:
+        from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
+
+        converted_state: dict[str, Any] = {}
+        for key, obj in state.items():
+            # convert the state
+            if isinstance(obj, Module) and isinstance(obj, XLAFSDP):
+                converted = obj.state_dict()
+                # add shard_metadata to state
+                converted_state["shard_metadata"] = obj.get_shard_metadata()
+            elif isinstance(obj, Optimizer):
+                converted = obj.state_dict()
+            else:
+                converted = obj
+            _apply_filter(key, filter or {}, converted, converted_state)
+
+        self.checkpoint_io.save_checkpoint(
+            converted_state,
+            path / f"checkpoint_rank-{self.global_rank:08d}-of-{self.world_size:08d}.pth",
+            storage_options=storage_options,
        )

    @override
@@ -291,9 +524,164 @@ def load_checkpoint(
        directory of multiple files rather than a single file.

        """
-        return self.xla_fsdp_impl.load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only)
+        if not state:
+            raise ValueError(
+                f"Got `XLAFSDPStrategy.load_checkpoint(..., state={state!r})` but a state with at least"
+                " a model instance to reload is required. Pass it in like so:"
+                " `XLAFSDPStrategy.load_checkpoint(..., state={'model': model, ...})`"
+            )
+
+        # broadcast the path from rank 0 to ensure all the states are loaded from a common path
+        path = Path(self.broadcast(path))
+
+        if isinstance(state, (Module, Optimizer)):
+            raise NotImplementedError(
+                "Loading a single module or optimizer object from a checkpoint"
+                " is not supported yet with the XLAFSDP strategy."
+            )
+
+        from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
+
+        modules = {key: module for key, module in state.items() if isinstance(module, XLAFSDP)}
+        optimizers = {key: optim for key, optim in state.items() if isinstance(optim, Optimizer)}
+        if self._state_dict_type == "sharded":
+            file = path / f"checkpoint_rank-{self.global_rank:08d}-of-{self.world_size:08d}.pth"
+            if not file.is_file():
+                raise ValueError(
+                    f"The path {str(file)!r} does not point to valid sharded checkpoints. Make sure the path points to"
+                    " a directory with XLAFSDP checkpoint shards."
+                )
+            if len(modules) == 0:
+                raise ValueError(
+                    "Could not find an XLAFSDP model in the provided checkpoint state. Please provide the model as"
+                    " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
+                    " you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
+                )
+            if len(modules) > 1:
+                raise ValueError(
+                    "Found multiple XLAFSDP modules in the given state. Loading checkpoints with FSDP is"
+                    " currently limited to a single model per checkpoint. To load multiple models, call the"
+                    " load method for each model separately with a different path."
+                )
+
+            _, module = list(modules.items())[0]
+            sharded_ckpt = torch.load(file)
+
+            module.load_state_dict(sharded_ckpt["model"], strict=strict)
+            for opt_key, opt in optimizers.items():
+                opt.load_state_dict(sharded_ckpt[opt_key])
+
+            # Load anything leftover from sharded_ckpt
+            loaded_metadata_keys = sharded_ckpt.keys() - modules.keys() - optimizers.keys()
+            requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
+            _validate_keys_for_strict_loading(requested_metadata_keys, loaded_metadata_keys, strict=strict)
+            for key in requested_metadata_keys:
+                if key in loaded_metadata_keys:
+                    state[key] = sharded_ckpt[key]
+                    loaded_metadata_keys.remove(key)
+
+            metadata = {}
+            if len(loaded_metadata_keys):
+                for key in loaded_metadata_keys:
+                    metadata[key] = sharded_ckpt[key]
+
+            # remove "shard_metadata" that is loaded in
+            if "shard_metadata" in metadata:
+                metadata.pop("shard_metadata")
+
+            return metadata
+
+        if self._state_dict_type == "full":
+            if not path.is_file():
+                raise ValueError(
+                    f"The path {str(path)!r} does not point to a valid full checkpoint. Make sure the path points to a"
+                    " file with a full XLAFSDP checkpoint."
+                )
+            if len(optimizers) > 0 or len(state.keys() - modules.keys() - optimizers.keys()) > 0:
+                rank_zero_warn(
+                    "Loading a full checkpoint will only load the full model."
+                    " The optimizer and any additional metadata are not included."
+                )
+            if len(modules) > 0:
+                raise ValueError(
+                    "Found an XLAFSDP model in the provided checkpoint state."
+                    " Please provide the model without any XLAFSDP wrapper."
+                )
+            if "model" not in state or not isinstance(model := state["model"], torch.nn.Module):
+                raise NotImplementedError("XLAFSDP only supports a single model instance with 'model' as the key.")
+            full_ckpt = torch.load(path, weights_only=weights_only)
+            model.load_state_dict(full_ckpt.pop("model"), strict=strict)
+            return full_ckpt
+
+        raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")

    @classmethod
    @override
    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
        strategy_registry.register("xla_fsdp", cls, description=cls.__name__)
+
+    def _parse_fsdp_kwargs(self) -> dict:
+        # this needs to be delayed because `self.precision` isn't available at init
+        kwargs = self._fsdp_kwargs.copy()
+        precision = self.precision
+        if isinstance(precision, XLAPrecision):
+            # the `compute_dtype` will be passed to the `auto_wrapper_callable` automatically, so we don't need to pass
+            # it when creating it
+            kwargs.setdefault("compute_dtype", precision._desired_dtype)
+        kwargs = _auto_wrap_policy_kwargs(self._auto_wrap_policy, kwargs)
+        return _activation_checkpointing_kwargs(self._activation_checkpointing_policy, kwargs)
+
+
+def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: dict) -> dict:
+    if policy is None:
+        return kwargs
+    if isinstance(policy, set):
+        from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
+        # this is not transformer specific despite the name
+        policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=policy)
+    kwargs["auto_wrap_policy"] = policy
+    return kwargs
+
+
+def _activation_checkpointing_auto_wrapper(policy: _POLICY_SET, module: Module, *args: Any, **kwargs: Any) -> Module:
+    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
+    from torch_xla.distributed.fsdp import checkpoint_module
+
+    module = checkpoint_module(module) if isinstance(module, tuple(policy)) else module
+    return XLAFSDP(module, *args, **kwargs)
+
+
+def
_activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: dict) -> dict: + if not policy: + return kwargs + if "auto_wrapper_callable" in kwargs: + raise ValueError( + "You cannot set both `auto_wrapper_callable` and `activation_checkpointing_policy`. Choose one" + ) + if not isinstance(policy, set): + raise TypeError( + f"`activation_checkpointing_policy` must be a set, found {policy}. You can try defining and" + " passing `auto_wrapper_callable` instead." + ) + auto_wrapper_callable = partial(_activation_checkpointing_auto_wrapper, policy) + kwargs["auto_wrapper_callable"] = auto_wrapper_callable + return kwargs + + +class _XLAFSDPBackwardSyncControl(_BackwardSyncControl): + @override + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: + """Blocks gradient synchronization inside the :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel` + wrapper.""" + if not enabled: + return nullcontext() + from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP + + if not isinstance(module, XLAFSDP): + raise TypeError( + "Blocking backward sync is only possible if the module passed to" + f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `XlaFullyShardedDataParallel`." + f" Got: {module.__class__.__name__}." + ) + return module.no_sync() diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index bac63dd9419c5..ac706cc403a5d 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -40,22 +40,3 @@ _TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0") _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) - -_WANDB_AVAILABLE = RequirementCache("wandb>=0.12.10") -_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4") -_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0") -_MLFLOW_SYNCHRONOUS_AVAILABLE = RequirementCache("mlflow>=2.8.0") -_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0") - -_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") -_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0") -_ENTERPRISE_AVAILABLE = RequirementCache("pytorch_lightning_enterprise") -_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0") - - -def _raise_enterprise_not_available() -> None: - if not _ENTERPRISE_AVAILABLE: - raise ModuleNotFoundError( - "pytorch_lightning_enterprise is required to use this feature. 
" - "Install it with `pip install pytorch-lightning-enterprise`" - ) diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index 92521362a2885..d085e4138d742 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -23,7 +23,8 @@ from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.accelerators.cuda import num_cuda_devices from lightning.fabric.accelerators.mps import MPSAccelerator -from lightning.fabric.utilities.imports import _DEEPSPEED_AVAILABLE, _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 def _runif_reasons( diff --git a/src/lightning/pytorch/accelerators/xla.py b/src/lightning/pytorch/accelerators/xla.py index c8b322f477d4d..700a29c2e3e7b 100644 --- a/src/lightning/pytorch/accelerators/xla.py +++ b/src/lightning/pytorch/accelerators/xla.py @@ -39,7 +39,15 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: A dictionary mapping the metrics (free memory and peak memory) to their values. """ - return self.accelerator_impl.get_device_stats(device) + import torch_xla.core.xla_model as xm + + memory_info = xm.get_memory_info(device) + free_memory = memory_info["kb_free"] + peak_memory = memory_info["kb_total"] - free_memory + return { + "avg. free memory (MB)": free_memory, + "avg. peak memory (MB)": peak_memory, + } @staticmethod @override diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 5e57e38dc4072..b544212e755e2 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -17,15 +17,18 @@ """ import logging +import os from argparse import Namespace from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from lightning_utilities.core.imports import RequirementCache from torch import Tensor from torch.nn import Module from typing_extensions import override -from lightning.fabric.utilities.imports import _raise_enterprise_not_available +from lightning.fabric.utilities.logger import _convert_params +from lightning.fabric.utilities.rank_zero import _get_rank from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment from lightning.pytorch.utilities.rank_zero import rank_zero_only @@ -33,6 +36,7 @@ from comet_ml import ExistingExperiment, Experiment, OfflineExperiment log = logging.getLogger(__name__) +_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4", module="comet_ml") FRAMEWORK_NAME = "pytorch-lightning" comet_experiment = Union["Experiment", "ExistingExperiment", "OfflineExperiment"] @@ -204,23 +208,99 @@ def __init__( prefix: Optional[str] = None, **kwargs: Any, ): - _raise_enterprise_not_available() + if not _COMET_AVAILABLE: + raise ModuleNotFoundError(str(_COMET_AVAILABLE)) super().__init__() - from pytorch_lightning_enterprise.loggers.comet import CometLogger as EnterpriseCometLogger - - self.logger_impl = EnterpriseCometLogger( - api_key=api_key, - workspace=workspace, - project=project, - experiment_key=experiment_key, - mode=mode, - online=online, - prefix=prefix, - **kwargs, + ################################################## + # HANDLE PASSED OLD TYPE PARAMS + + # handle old "experiment_name" param + if "experiment_name" in kwargs: + log.warning("The parameter `experiment_name` is deprecated, please use `name` instead.") + experiment_name = 
kwargs.pop("experiment_name") + + if "name" not in kwargs: + kwargs["name"] = experiment_name + else: + log.warning("You specified both `experiment_name` and `name` parameters, please use `name` only") + + # handle old "project_name" param + if "project_name" in kwargs: + log.warning("The parameter `project_name` is deprecated, please use `project` instead.") + if project is None: + project = kwargs.pop("project_name") + else: + log.warning("You specified both `project_name` and `project` parameters, please use `project` only") + + # handle old "offline" experiment flag + if "offline" in kwargs: + log.warning("The parameter `offline is deprecated, please use `online` instead.") + if online is None: + online = kwargs.pop("offline") + else: + log.warning("You specified both `offline` and `online` parameters, please use `online` only") + + # handle old "save_dir" param + if "save_dir" in kwargs: + log.warning("The parameter `save_dir` is deprecated, please use `offline_directory` instead.") + if "offline_directory" not in kwargs: + kwargs["offline_directory"] = kwargs.pop("save_dir") + else: + log.warning( + "You specified both `save_dir` and `offline_directory` parameters, " + "please use `offline_directory` only" + ) + ################################################## + + self._api_key: Optional[str] = api_key + self._experiment: Optional[comet_experiment] = None + self._workspace: Optional[str] = workspace + self._mode: Optional[Literal["get_or_create", "get", "create"]] = mode + self._online: Optional[bool] = online + self._project_name: Optional[str] = project + self._experiment_key: Optional[str] = experiment_key + self._prefix: Optional[str] = prefix + self._kwargs: dict[str, Any] = kwargs + + # needs to be set before the first `comet_ml` import + # because comet_ml imported after another machine learning libraries (Torch) + os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1" + + import comet_ml + + config_kwargs = self._kwargs.copy() + if online is False: + config_kwargs["disabled"] = True + self._comet_config = comet_ml.ExperimentConfig(**config_kwargs) + + # create real experiment only on main node/process (when strategy=auto/ddp) + if _get_rank() is not None and _get_rank() != 0: + return + + self._create_experiment() + + def _create_experiment(self) -> None: + import comet_ml + + self._experiment = comet_ml.start( + api_key=self._api_key, + workspace=self._workspace, + project=self._project_name, + experiment_key=self._experiment_key, + mode=self._mode, + online=self._online, + experiment_config=self._comet_config, ) + if self._experiment is None: + raise comet_ml.exceptions.ExperimentNotFound("Failed to create Comet experiment.") + + self._experiment_key = self._experiment.get_key() + self._project_name = self._experiment.project_name + self._experiment.log_other("Created from", FRAMEWORK_NAME) + @property @rank_zero_experiment def experiment(self) -> comet_experiment: @@ -233,24 +313,55 @@ def experiment(self) -> comet_experiment: """ - return self.logger_impl.experiment + # if by some chance there is no experiment created yet (for example, when strategy=ddp_spawn) + # then we will create a new one + if not self._experiment: + self._create_experiment() + + return self._experiment @override @rank_zero_only def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: - return self.logger_impl.log_hyperparams(params) + params = _convert_params(params) + self.experiment.__internal_api__log_parameters__( + parameters=params, + framework=FRAMEWORK_NAME, + 
+            flatten_nested=True,
+            source="manual",
+        )

    @override
    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
-        return self.logger_impl.log_metrics(metrics, step)
+        assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
+        # Comet.com expects metrics to be a dictionary of detached tensors on CPU
+        metrics_without_epoch = metrics.copy()
+        for key, val in metrics_without_epoch.items():
+            if isinstance(val, Tensor):
+                metrics_without_epoch[key] = val.cpu().detach()
+
+        epoch = metrics_without_epoch.pop("epoch", None)
+        self.experiment.__internal_api__log_metrics__(
+            metrics_without_epoch,
+            step=step,
+            epoch=epoch,
+            prefix=self._prefix,
+            framework=FRAMEWORK_NAME,
+        )

    @override
    @rank_zero_only
    def finalize(self, status: str) -> None:
        """We do not end the experiment (do not call self._experiment.end()) here, so that it can continue to be
        used after training completes. Instead of ending it, we upload/save all the data."""
-        return self.logger_impl.finalize(status)
+        if self._experiment is None:
+            # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been
+            # initialized there
+            return
+
+        # just save the data
+        self.experiment.flush()

    @property
    @override
@@ -261,7 +372,7 @@ def save_dir(self) -> Optional[str]:
            The path to the save directory.

        """
-        return self.logger_impl.save_dir
+        return self._comet_config.offline_directory

    @property
    @override
@@ -272,7 +383,7 @@ def name(self) -> Optional[str]:
            The project name if it is specified.

        """
-        return self.logger_impl.name
+        return self._project_name

    @property
    @override
@@ -284,8 +395,27 @@ def version(self) -> Optional[str]:

        """
        # Don't create an experiment if we don't have one
-        return self.logger_impl.version
+        if self._experiment is not None:
+            return self._experiment.get_key()
+
+    def __getstate__(self) -> dict[str, Any]:
+        state = self.__dict__.copy()
+
+        # Save the experiment id in case an experiment object already exists,
+        # this way we could create an ExistingExperiment pointing to the same
+        # experiment
+        state["_experiment_key"] = self._experiment.get_key() if self._experiment is not None else None
+
+        # Remove the experiment object as it contains hard to pickle objects
+        # (like network connections), the experiment object will be recreated if
+        # needed later
+        state["_experiment"] = None
+        return state

    @override
    def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
-        return self.logger_impl.log_graph(model, input_array)
+        if self._experiment is not None:
+            self._experiment.__internal_api__set_model_graph__(
+                graph=model,
+                framework=FRAMEWORK_NAME,
+            )
diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index 725fd2cb8bfbb..ff9b2b0d7e542 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -16,21 +16,35 @@
-------------
"""
+import logging
 import os
+import re
+import tempfile
 from argparse import Namespace
 from collections.abc import Mapping
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from pathlib import Path
+from time import time
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union

+import yaml
+from lightning_utilities.core.imports import RequirementCache
+from torch import Tensor
 from typing_extensions import override

-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from
lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment -from lightning.pytorch.utilities.rank_zero import rank_zero_only +from lightning.pytorch.loggers.utilities import _scan_checkpoints +from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn if TYPE_CHECKING: from mlflow.tracking import MlflowClient +log = logging.getLogger(__name__) +LOCAL_FILE_URI_PREFIX = "file:" +_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow") +_MLFLOW_SYNCHRONOUS_AVAILABLE = RequirementCache("mlflow>=2.8.0", "mlflow") + class MLFlowLogger(Logger): """Log using `MLflow `_. @@ -112,22 +126,31 @@ def __init__( run_id: Optional[str] = None, synchronous: Optional[bool] = None, ): - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.loggers.mlflow import MLFlowLogger as EnterpriseMLFlowLogger - + if not _MLFLOW_AVAILABLE: + raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE)) + if synchronous is not None and not _MLFLOW_SYNCHRONOUS_AVAILABLE: + raise ModuleNotFoundError("`synchronous` requires mlflow>=2.8.0") super().__init__() - self.logger_impl = EnterpriseMLFlowLogger( - experiment_name=experiment_name, - run_name=run_name, - tracking_uri=tracking_uri, - tags=tags, - save_dir=save_dir, - log_model=log_model, - prefix=prefix, - artifact_location=artifact_location, - run_id=run_id, - synchronous=synchronous, - ) + if not tracking_uri: + tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}" + + self._experiment_name = experiment_name + self._experiment_id: Optional[str] = None + self._tracking_uri = tracking_uri + self._run_name = run_name + self._run_id = run_id + self.tags = tags + self._log_model = log_model + self._logged_model_time: dict[str, float] = {} + self._checkpoint_callback: Optional[ModelCheckpoint] = None + self._prefix = prefix + self._artifact_location = artifact_location + self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous} + self._initialized = False + + from mlflow.tracking import MlflowClient + + self._mlflow_client = MlflowClient(tracking_uri) @property @rank_zero_experiment @@ -140,7 +163,46 @@ def experiment(self) -> "MlflowClient": self.logger.experiment.some_mlflow_function() """ - return self.logger_impl.experiment + import mlflow + + if self._initialized: + return self._mlflow_client + + mlflow.set_tracking_uri(self._tracking_uri) + + if self._run_id is not None: + run = self._mlflow_client.get_run(self._run_id) + self._experiment_id = run.info.experiment_id + self._initialized = True + return self._mlflow_client + + if self._experiment_id is None: + expt = self._mlflow_client.get_experiment_by_name(self._experiment_name) + if expt is not None and expt.lifecycle_stage != "deleted": + self._experiment_id = expt.experiment_id + else: + log.warning(f"Experiment with name {self._experiment_name} not found. Creating it.") + self._experiment_id = self._mlflow_client.create_experiment( + name=self._experiment_name, artifact_location=self._artifact_location + ) + + if self._run_id is None: + if self._run_name is not None: + self.tags = self.tags or {} + + from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME + + if MLFLOW_RUN_NAME in self.tags: + log.warning( + f"The tag {MLFLOW_RUN_NAME} is found in tags. The value will be overridden by {self._run_name}." 
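+                        # (added editorial note) MLFLOW_RUN_NAME is the reserved `mlflow.runName`
+                        # tag, so an explicitly passed `run_name` deliberately takes precedence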
+ ) + self.tags[MLFLOW_RUN_NAME] = self._run_name + + resolve_tags = _get_resolve_tags() + run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags)) + self._run_id = run.info.run_id + self._initialized = True + return self._mlflow_client @property def run_id(self) -> Optional[str]: @@ -150,7 +212,8 @@ def run_id(self) -> Optional[str]: The run id. """ - return self.logger_impl.run_id + _ = self.experiment + return self._run_id @property def experiment_id(self) -> Optional[str]: @@ -160,22 +223,71 @@ def experiment_id(self) -> Optional[str]: The experiment id. """ - return self.logger_impl.experiment_id + _ = self.experiment + return self._experiment_id @override @rank_zero_only def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: - return self.logger_impl.log_hyperparams(params) + params = _convert_params(params) + params = _flatten_dict(params) + + from mlflow.entities import Param + + # Truncate parameter values to 250 characters. + # TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 + params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()] + + # Log in chunks of 100 parameters (the maximum allowed by MLflow). + for idx in range(0, len(params_list), 100): + self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100], **self._log_batch_kwargs) @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: - return self.logger_impl.log_metrics(metrics, step) + assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" + + from mlflow.entities import Metric + + metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) + metrics_list: list[Metric] = [] + + timestamp_ms = int(time() * 1000) + for k, v in metrics.items(): + if isinstance(v, str): + log.warning(f"Discarding metric with string value {k}={v}.") + continue + + new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k) + if k != new_k: + rank_zero_warn( + "MLFlow only allows '_', '/', '.' and ' ' special characters in metric name." + f" Replacing {k} with {new_k}.", + category=RuntimeWarning, + ) + k = new_k + metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0)) + + self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list, **self._log_batch_kwargs) @override @rank_zero_only def finalize(self, status: str = "success") -> None: - return self.logger_impl.finalize(status) + if not self._initialized: + return + if status == "success": + status = "FINISHED" + elif status == "failed": + status = "FAILED" + elif status == "finished": + status = "FINISHED" + + # log checkpoints as artifacts + if self._checkpoint_callback: + self._scan_and_log_checkpoints(self._checkpoint_callback) + + if self.experiment.get_run(self.run_id): + self.experiment.set_terminated(self.run_id, status) @property @override @@ -187,7 +299,9 @@ def save_dir(self) -> Optional[str]: Otherwise returns `None`. """ - return self.logger_impl.save_dir + if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX): + return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :] + return None @property @override @@ -198,7 +312,7 @@ def name(self) -> Optional[str]: The experiment id. """ - return self.logger_impl.name + return self.experiment_id @property @override @@ -209,8 +323,76 @@ def version(self) -> Optional[str]: The run id. 
""" - return self.logger_impl.version + return self.run_id @override def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: - return self.logger_impl.after_save_checkpoint(checkpoint_callback) + # log checkpoints as artifacts + if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: + self._scan_and_log_checkpoints(checkpoint_callback) + elif self._log_model is True: + self._checkpoint_callback = checkpoint_callback + + def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: + # get checkpoints to be saved with associated score + checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time) + + # log iteratively all new checkpoints + for t, p, s, tag in checkpoints: + metadata = { + # Ensure .item() is called to store Tensor contents + "score": s.item() if isinstance(s, Tensor) else s, + "original_filename": Path(p).name, + "Checkpoint": { + k: getattr(checkpoint_callback, k) + for k in [ + "monitor", + "mode", + "save_last", + "save_top_k", + "save_weights_only", + "_every_n_train_steps", + "_every_n_val_epochs", + ] + # ensure it does not break if `Checkpoint` args change + if hasattr(checkpoint_callback, k) + }, + } + aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] + + # Artifact path on mlflow + artifact_path = Path(p).stem + + # Log the checkpoint + self.experiment.log_artifact(self._run_id, p, artifact_path) + + # Create a temporary directory to log on mlflow + with tempfile.TemporaryDirectory() as tmp_dir: + # Log the metadata + with open(f"{tmp_dir}/metadata.yaml", "w") as tmp_file_metadata: + yaml.dump(metadata, tmp_file_metadata, default_flow_style=False) + + # Log the aliases + with open(f"{tmp_dir}/aliases.txt", "w") as tmp_file_aliases: + tmp_file_aliases.write(str(aliases)) + + # Log the metadata and aliases + self.experiment.log_artifacts(self._run_id, tmp_dir, artifact_path) + + # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) + self._logged_model_time[p] = t + + +def _get_resolve_tags() -> Callable: + from mlflow.tracking import context + + # before v1.1.0 + if hasattr(context, "resolve_tags"): + from mlflow.tracking.context import resolve_tags + # since v1.1.0 + elif hasattr(context, "registry"): + from mlflow.tracking.context.registry import resolve_tags + else: + resolve_tags = lambda tags: tags + + return resolve_tags diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index d6ec7d482fdc0..bf9669c824784 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -16,17 +16,23 @@ -------------- """ +import contextlib import logging +import os from argparse import Namespace -from typing import TYPE_CHECKING, Any, Optional, Union +from collections.abc import Generator +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from lightning_utilities.core.imports import RequirementCache from torch import Tensor from typing_extensions import override import lightning.pytorch as pl -from lightning.fabric.utilities.imports import _raise_enterprise_not_available +from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params from lightning.pytorch.callbacks import Checkpoint from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment +from lightning.pytorch.utilities.model_summary import ModelSummary 
 from lightning.pytorch.utilities.rank_zero import rank_zero_only

 if TYPE_CHECKING:
@@ -35,6 +41,27 @@

 log = logging.getLogger(__name__)

+# Neptune is available under two names on PyPI: `neptune` and `neptune-client`.
+# `neptune` was introduced as part of the breaking changes in the neptune-client==1.0 release, as a transition
+# away from the `neptune-client` name; the long-term goal is to retire the neptune-client package entirely.
+# neptune-client>=1.0 is just an alias of the `neptune` package and has breaking changes compared to
+# neptune-client<1.0.0.
+_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0")
+_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning"
+
+
+# Neptune client throws `InactiveRunException` when trying to log to an inactive run.
+# This may happen when the run was stopped through the UI and the logger is still trying to log to it.
+def _catch_inactive(func: Callable) -> Callable:
+    @wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        from neptune.exceptions import InactiveRunException
+
+        with contextlib.suppress(InactiveRunException):
+            return func(*args, **kwargs)
+
+    return wrapper
+

 class NeptuneLogger(Logger):
     r"""Log using `Neptune `_.
@@ -215,19 +242,113 @@ def __init__(
        prefix: str = "training",
        **neptune_run_kwargs: Any,
    ):
-        _raise_enterprise_not_available()
+        if not _NEPTUNE_AVAILABLE:
+            raise ModuleNotFoundError(str(_NEPTUNE_AVAILABLE))
+
+        # verify if user passed proper init arguments
+        self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs)
        super().__init__()
-        from pytorch_lightning_enterprise.loggers.neptune import NeptuneLogger as EnterpriseNeptuneLogger
-
-        self.logger_impl = EnterpriseNeptuneLogger(
-            api_key=api_key,
-            project=project,
-            name=name,
-            run=run,
-            log_model_checkpoints=log_model_checkpoints,
-            prefix=prefix,
-            **neptune_run_kwargs,
-        )
+        self._log_model_checkpoints = log_model_checkpoints
+        self._prefix = prefix
+        self._run_name = name
+        self._project_name = project
+        self._api_key = api_key
+        self._run_instance = run
+        self._neptune_run_kwargs = neptune_run_kwargs
+        self._run_short_id: Optional[str] = None
+
+        if self._run_instance is not None:
+            self._retrieve_run_data()
+
+            from neptune.handler import Handler
+
+            # make sure that we've logged the integration version for externally provided `Run` instances
+            root_obj = self._run_instance
+            if isinstance(root_obj, Handler):
+                root_obj = root_obj.get_root_object()
+
+            root_obj[_INTEGRATION_VERSION_KEY] = pl.__version__
+
+    def _retrieve_run_data(self) -> None:
+        from neptune.handler import Handler
+
+        assert self._run_instance is not None
+        root_obj = self._run_instance
+        if isinstance(root_obj, Handler):
+            root_obj = root_obj.get_root_object()
+
+        root_obj.wait()
+
+        if root_obj.exists("sys/id"):
+            self._run_short_id = root_obj["sys/id"].fetch()
+            self._run_name = root_obj["sys/name"].fetch()
+        else:
+            self._run_short_id = "OFFLINE"
+            self._run_name = "offline-name"
+
+    @property
+    def _neptune_init_args(self) -> dict:
+        args: dict = {}
+        # Backward compatibility in case of previous version retrieval
+        with contextlib.suppress(AttributeError):
+            args = self._neptune_run_kwargs
+
+        if self._project_name is not None:
+            args["project"] = self._project_name
+
+        if self._api_key is not None:
+            args["api_token"] = self._api_key
+
+        if self._run_short_id is not None:
+            args["run"] = self._run_short_id
+
+        # Backward compatibility in case of previous version retrieval
+        with contextlib.suppress(AttributeError):
+            if self._run_name is not None:
+                args["name"] = self._run_name
+
+        return args
+
+    def _construct_path_with_prefix(self, *keys: str) -> str:
+        """Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
+        if self._prefix:
+            return self.LOGGER_JOIN_CHAR.join([self._prefix, *keys])
+        return self.LOGGER_JOIN_CHAR.join(keys)
+
+    @staticmethod
+    def _verify_input_arguments(
+        api_key: Optional[str],
+        project: Optional[str],
+        name: Optional[str],
+        run: Optional[Union["Run", "Handler"]],
+        neptune_run_kwargs: dict,
+    ) -> None:
+        from neptune import Run
+        from neptune.handler import Handler
+
+        # check if user passed the client `Run`/`Handler` object
+        if run is not None and not isinstance(run, (Run, Handler)):
+            raise ValueError("Run parameter expected to be of type `neptune.Run`, or `neptune.handler.Handler`.")
+
+        # check if user passed redundant neptune.init_run arguments when a run was passed
+        any_neptune_init_arg_passed = any(arg is not None for arg in [api_key, project, name]) or neptune_run_kwargs
+        if run is not None and any_neptune_init_arg_passed:
+            raise ValueError(
+                "When an already initialized run object is provided, you can't provide other `neptune.init_run()`"
+                " parameters."
+            )
+
+    def __getstate__(self) -> dict[str, Any]:
+        state = self.__dict__.copy()
+        # Run instance can't be pickled
+        state["_run_instance"] = None
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        import neptune
+
+        self.__dict__ = state
+        self._run_instance = neptune.init_run(**self._neptune_init_args)

    @property
    @rank_zero_experiment
@@ -257,23 +378,73 @@ def training_step(self, batch, batch_idx):
        with NeptuneLogger.

        """
-        return self.logger_impl.experiment
+        return self.run

    @property
    @rank_zero_experiment
    def run(self) -> "Run":
-        return self.logger_impl.run
+        import neptune
+
+        if not self._run_instance:
+            self._run_instance = neptune.init_run(**self._neptune_init_args)
+            self._retrieve_run_data()
+            # make sure that we've logged the integration version for the newly created run
+            self._run_instance[_INTEGRATION_VERSION_KEY] = pl.__version__
+
+        return self._run_instance

    @override
    @rank_zero_only
+    @_catch_inactive
    def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
-        return self.logger_impl.log_hyperparams(params)
+        r"""Log hyperparameters to the run.
+
+        Hyperparameters will be logged under the "/hyperparams" namespace.
+
+        Note:
+
+            You can also log parameters by directly using the logger instance:
+            ``neptune_logger.experiment["model/hyper-parameters"] = params_dict``.
+
+            In this way you can keep the hierarchical structure of the parameters.
+
+        Args:
+            params: `dict`.
+                Python dictionary structure with parameters.
+ + Example:: + + from lightning.pytorch.loggers import NeptuneLogger + import neptune + + PARAMS = { + "batch_size": 64, + "lr": 0.07, + "decay_factor": 0.97, + } + + neptune_logger = NeptuneLogger( + api_key=neptune.ANONYMOUS_API_TOKEN, + project="common/pytorch-lightning-integration" + ) + + neptune_logger.log_hyperparams(PARAMS) + + """ + from neptune.utils import stringify_unsupported + + params = _convert_params(params) + params = _sanitize_callable_params(params) + + parameters_key = self.PARAMETERS_KEY + parameters_key = self._construct_path_with_prefix(parameters_key) + + self.run[parameters_key] = stringify_unsupported(params) @override @rank_zero_only - def log_metrics( # type: ignore[override] - self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None - ) -> None: + @_catch_inactive + def log_metrics(self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None: """Log metrics (numeric values) in Neptune runs. Args: @@ -281,12 +452,26 @@ def log_metrics( # type: ignore[override] step: Step number at which the metrics should be recorded """ - return self.logger_impl.log_metrics(metrics, step) + if rank_zero_only.rank != 0: + raise ValueError("run tried to log from global_rank != 0") + + metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) + + for key, val in metrics.items(): + self.run[key].append(val, step=step) @override @rank_zero_only + @_catch_inactive def finalize(self, status: str) -> None: - return self.logger_impl.finalize(status) + if not self._run_instance: + # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been + # initialized there + return + if status: + self.run[self._construct_path_with_prefix("status")] = status + + super().finalize(status) @property @override @@ -298,14 +483,21 @@ def save_dir(self) -> Optional[str]: the root directory where experiment logs get saved """ - return self.logger_impl.save_dir + return os.path.join(os.getcwd(), ".neptune") @rank_zero_only + @_catch_inactive def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None: - return self.logger_impl.log_model_summary(model, max_depth) + from neptune.types import File + + model_str = str(ModelSummary(model=model, max_depth=max_depth)) + self.run[self._construct_path_with_prefix("model/summary")] = File.from_content( + content=model_str, extension="txt" + ) @override @rank_zero_only + @_catch_inactive def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint. 
@@ -313,13 +505,83 @@
            checkpoint_callback: the model checkpoint callback instance

        """
-        return self.logger_impl.after_save_checkpoint(checkpoint_callback)
+        if not self._log_model_checkpoints:
+            return
+
+        file_names = set()
+        checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")
+
+        # save last model
+        if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
+            model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
+            file_names.add(model_last_name)
+            self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
+
+        # save best k models
+        if hasattr(checkpoint_callback, "best_k_models"):
+            for key in checkpoint_callback.best_k_models:
+                model_name = self._get_full_model_name(key, checkpoint_callback)
+                file_names.add(model_name)
+                self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)
+
+        # log best model path and checkpoint
+        if hasattr(checkpoint_callback, "best_model_path") and checkpoint_callback.best_model_path:
+            self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path
+
+            model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
+            file_names.add(model_name)
+            self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
+
+        # remove old models logged to experiment if they are not part of best k models at this point
+        if self.run.exists(checkpoints_namespace):
+            exp_structure = self.run.get_structure()
+            uploaded_model_names = self._get_full_model_names_from_exp_structure(exp_structure, checkpoints_namespace)
+
+            for file_to_drop in list(uploaded_model_names - file_names):
+                del self.run[f"{checkpoints_namespace}/{file_to_drop}"]
+
+        # log best model score
+        if hasattr(checkpoint_callback, "best_model_score") and checkpoint_callback.best_model_score:
+            self.run[self._construct_path_with_prefix("model/best_model_score")] = (
+                checkpoint_callback.best_model_score.cpu().detach().numpy()
+            )
+
+    @staticmethod
+    def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> str:
+        """Returns the model name: the path of `model_path` relative to `checkpoint_callback.dirpath`, without the
+        file extension."""
+        if hasattr(checkpoint_callback, "dirpath"):
+            model_path = os.path.normpath(model_path)
+            expected_model_path = os.path.normpath(checkpoint_callback.dirpath)
+            if not model_path.startswith(expected_model_path):
+                raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
+            # Remove extension from filepath
+            filepath, _ = os.path.splitext(model_path[len(expected_model_path) + 1 :])
+            return filepath.replace(os.sep, "/")
+        return model_path.replace(os.sep, "/")
+
+    @classmethod
+    def _get_full_model_names_from_exp_structure(cls, exp_structure: dict[str, Any], namespace: str) -> set[str]:
+        """Returns all paths to properties which were already logged in `namespace`."""
+        structure_keys: list[str] = namespace.split(cls.LOGGER_JOIN_CHAR)
+        for key in structure_keys:
+            exp_structure = exp_structure[key]
+        uploaded_models_dict = exp_structure
+        return set(cls._dict_paths(uploaded_models_dict))
+
+    @classmethod
+    def _dict_paths(cls, d: dict[str, Any], path_in_build: Optional[str] = None) -> Generator:
+        for k, v in d.items():
+            path = f"{path_in_build}/{k}" if path_in_build is not None else k
+            if not isinstance(v, dict):
+                yield path
+            else:
+
yield from cls._dict_paths(v, path) @property @override def name(self) -> Optional[str]: """Return the experiment name or 'offline-name' when exp is run in offline mode.""" - return self.logger_impl.name + return self._run_name @property @override @@ -329,4 +591,4 @@ def version(self) -> Optional[str]: It's Neptune Run's short_id """ - return self.logger_impl.version + return self._run_short_id diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 41d38b460d0f0..37ca362fa40c1 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -16,24 +16,37 @@ ------------------------- """ +import os from argparse import Namespace from collections.abc import Mapping +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch.nn as nn +from lightning_utilities.core.imports import RequirementCache +from torch import Tensor from typing_extensions import override -from lightning.fabric.utilities.imports import _raise_enterprise_not_available +from lightning.fabric.utilities.logger import ( + _add_prefix, + _convert_json_serializable, + _convert_params, + _sanitize_callable_params, +) from lightning.fabric.utilities.types import _PATH from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment -from lightning.pytorch.utilities.rank_zero import rank_zero_only +from lightning.pytorch.loggers.utilities import _scan_checkpoints +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn if TYPE_CHECKING: from wandb import Artifact from wandb.sdk.lib import RunDisabled from wandb.wandb_run import Run +_WANDB_AVAILABLE = RequirementCache("wandb>=0.12.10") + class WandbLogger(Logger): r"""Log using `Weights and Biases `_. @@ -295,27 +308,70 @@ def __init__( add_file_policy: Literal["mutable", "immutable"] = "mutable", **kwargs: Any, ) -> None: - _raise_enterprise_not_available() + if not _WANDB_AVAILABLE: + raise ModuleNotFoundError(str(_WANDB_AVAILABLE)) + + if offline and log_model: + raise MisconfigurationException( + f"Providing log_model={log_model} and offline={offline} is an invalid configuration" + " since model checkpoints cannot be uploaded in offline mode.\n" + "Hint: Set `offline=False` to log your model." 
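+                # (added editorial note) offline mode only writes to the local wandb directory,
+                # so checkpoint artifacts could never reach the W&B server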
+ ) super().__init__() - from pytorch_lightning_enterprise.loggers.wandb import WandbLogger as EnterpriseWandbLogger - - self.logger_impl = EnterpriseWandbLogger( - name=name, - save_dir=save_dir, - version=version, - offline=offline, - dir=dir, - id=id, - anonymous=anonymous, - project=project, - log_model=log_model, - experiment=experiment, - prefix=prefix, - checkpoint_name=checkpoint_name, - add_file_policy=add_file_policy, - **kwargs, - ) + self._offline = offline + self._log_model = log_model + self._prefix = prefix + self._experiment = experiment + self._logged_model_time: dict[str, float] = {} + self._checkpoint_callbacks: dict[int, ModelCheckpoint] = {} + self.add_file_policy = add_file_policy + + # paths are processed as strings + if save_dir is not None: + save_dir = os.fspath(save_dir) + elif dir is not None: + dir = os.fspath(dir) + + project = project or os.environ.get("WANDB_PROJECT", "lightning_logs") + + # set wandb init arguments + self._wandb_init: dict[str, Any] = { + "name": name, + "project": project, + "dir": save_dir or dir, + "id": version or id, + "resume": "allow", + "anonymous": ("allow" if anonymous else None), + } + self._wandb_init.update(**kwargs) + # extract parameters + self._project = self._wandb_init.get("project") + self._save_dir = self._wandb_init.get("dir") + self._name = self._wandb_init.get("name") + self._id = self._wandb_init.get("id") + self._checkpoint_name = checkpoint_name + + def __getstate__(self) -> dict[str, Any]: + import wandb + + # Hack: If the 'spawn' launch method is used, the logger will get pickled and this `__getstate__` gets called. + # We create an experiment here in the main process, and attach to it in the worker process. + # Using wandb-service, we persist the same experiment even if multiple `Trainer.fit/test/validate` calls + # are made. + wandb.require("service") + _ = self.experiment + + state = self.__dict__.copy() + # args needed to reload correct experiment + if self._experiment is not None: + state["_id"] = getattr(self._experiment, "id", None) + state["_attach_id"] = getattr(self._experiment, "_attach_id", None) + state["_name"] = self._experiment.name + + # cannot be pickled + state["_experiment"] = None + return state @property @rank_zero_experiment @@ -330,7 +386,40 @@ def experiment(self) -> Union["Run", "RunDisabled"]: self.logger.experiment.some_wandb_function() """ - return self.logger_impl.experiment + import wandb + from wandb.sdk.lib import RunDisabled + from wandb.wandb_run import Run + + if self._experiment is None: + if self._offline: + os.environ["WANDB_MODE"] = "dryrun" + + attach_id = getattr(self, "_attach_id", None) + if wandb.run is not None: + # wandb process already created in this instance + rank_zero_warn( + "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse" + " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`." 
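+                    # (added editorial note) `wandb.run` is a process-wide singleton, which is why
+                    # the logger attaches to it rather than starting a second run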
+ ) + self._experiment = wandb.run + elif attach_id is not None and hasattr(wandb, "_attach"): + # attach to wandb process referenced + self._experiment = wandb._attach(attach_id) + else: + # create new wandb process + self._experiment = wandb.init(**self._wandb_init) + + # define default x-axis + if isinstance(self._experiment, (Run, RunDisabled)) and getattr( + self._experiment, "define_metric", None + ): + if self._wandb_init.get("sync_tensorboard"): + self._experiment.define_metric("*", step_metric="global_step") + else: + self._experiment.define_metric("trainer/global_step") + self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True) + + return self._experiment def watch( self, model: nn.Module, log: Optional[str] = "gradients", log_freq: int = 100, log_graph: bool = True @@ -340,12 +429,21 @@ def watch( @override @rank_zero_only def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: - return self.logger_impl.log_hyperparams(params) + params = _convert_params(params) + params = _sanitize_callable_params(params) + params = _convert_json_serializable(params) + self.experiment.config.update(params, allow_val_change=True) @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: - return self.logger_impl.log_metrics(metrics, step) + assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" + + metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) + if step is not None and not self._wandb_init.get("sync_tensorboard"): + self.experiment.log(dict(metrics, **{"trainer/global_step": step})) + else: + self.experiment.log(metrics) @rank_zero_only def log_table( @@ -361,7 +459,10 @@ def log_table( Can be defined either with `columns` and `data` or with `dataframe`. """ - return self.logger_impl.log_table(key, columns, data, dataframe, step) + import wandb + + metrics = {key: wandb.Table(columns=columns, data=data, dataframe=dataframe)} + self.log_metrics(metrics, step) @rank_zero_only def log_text( @@ -377,7 +478,8 @@ def log_text( Can be defined either with `columns` and `data` or with `dataframe`. """ - return self.logger_impl.log_text(key, columns, data, dataframe, step) + + self.log_table(key, columns, data, dataframe, step) @rank_zero_only def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: @@ -386,7 +488,18 @@ def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **k Optional kwargs are lists passed to each image (ex: caption, masks, boxes). """ - return self.logger_impl.log_image(key, images, step, **kwargs) + if not isinstance(images, list): + raise TypeError(f'Expected a list as "images", found {type(images)}') + n = len(images) + for k, v in kwargs.items(): + if len(v) != n: + raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") + kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)] + + import wandb + + metrics = {key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]} + self.log_metrics(metrics, step) # type: ignore[arg-type] @rank_zero_only def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: @@ -401,7 +514,18 @@ def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **k Optional kwargs are lists passed to each audio (ex: caption, sample_rate). 
""" - return self.logger_impl.log_audio(key, audios, step, **kwargs) + if not isinstance(audios, list): + raise TypeError(f'Expected a list as "audios", found {type(audios)}') + n = len(audios) + for k, v in kwargs.items(): + if len(v) != n: + raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") + kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)] + + import wandb + + metrics = {key: [wandb.Audio(audio, **kwarg) for audio, kwarg in zip(audios, kwarg_list)]} + self.log_metrics(metrics, step) # type: ignore[arg-type] @rank_zero_only def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: @@ -416,7 +540,18 @@ def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **k Optional kwargs are lists passed to each video (ex: caption, fps, format). """ - return self.logger_impl.log_video(key, videos, step, **kwargs) + if not isinstance(videos, list): + raise TypeError(f'Expected a list as "videos", found {type(videos)}') + n = len(videos) + for k, v in kwargs.items(): + if len(v) != n: + raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") + kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)] + + import wandb + + metrics = {key: [wandb.Video(video, **kwarg) for video, kwarg in zip(videos, kwarg_list)]} + self.log_metrics(metrics, step) # type: ignore[arg-type] @property @override @@ -427,7 +562,7 @@ def save_dir(self) -> Optional[str]: The path to the save directory. """ - return self.logger_impl.save_dir + return self._save_dir @property @override @@ -439,7 +574,7 @@ def name(self) -> Optional[str]: name. To access wandb's internal experiment name, use ``logger.experiment.name`` instead. """ - return self.logger_impl.name + return self._project @property @override @@ -451,12 +586,15 @@ def version(self) -> Optional[str]: """ # don't create an experiment if we don't have one - return self.logger_impl.version + return self._experiment.id if self._experiment else self._id @override def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: # log checkpoints as artifacts - return self.logger_impl.after_save_checkpoint(checkpoint_callback) + if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: + self._scan_and_log_checkpoints(checkpoint_callback) + elif self._log_model is True: + self._checkpoint_callbacks[id(checkpoint_callback)] = checkpoint_callback @staticmethod @rank_zero_only @@ -478,10 +616,16 @@ def download_artifact( The path to the downloaded artifact. """ - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.loggers.wandb import WandbLogger as EnterpriseWandbLogger + import wandb + + if wandb.run is not None and use_artifact: + artifact = wandb.run.use_artifact(artifact) + else: + api = wandb.Api() + artifact = api.artifact(artifact, type=artifact_type) - return EnterpriseWandbLogger.download_artifact(artifact, save_dir, artifact_type, use_artifact) + save_dir = None if save_dir is None else os.fspath(save_dir) + return artifact.download(root=save_dir) def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "Artifact": """Logs to the wandb dashboard that the mentioned artifact is used by the run. @@ -494,9 +638,49 @@ def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "A wandb Artifact object for the artifact. 
""" - return self.logger_impl.use_artifact(artifact, artifact_type) + return self.experiment.use_artifact(artifact, type=artifact_type) @override @rank_zero_only def finalize(self, status: str) -> None: - return self.logger_impl.finalize(status) + if status != "success": + # Currently, checkpoints only get logged on success + return + # log checkpoints as artifacts + if self._experiment is not None: + for checkpoint_callback in self._checkpoint_callbacks.values(): + self._scan_and_log_checkpoints(checkpoint_callback) + + def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: + import wandb + + # get checkpoints to be saved with associated score + checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time) + + # log iteratively all new checkpoints + for t, p, s, tag in checkpoints: + metadata = { + "score": s.item() if isinstance(s, Tensor) else s, + "original_filename": Path(p).name, + checkpoint_callback.__class__.__name__: { + k: getattr(checkpoint_callback, k) + for k in [ + "monitor", + "mode", + "save_last", + "save_top_k", + "save_weights_only", + "_every_n_train_steps", + ] + # ensure it does not break if `ModelCheckpoint` args change + if hasattr(checkpoint_callback, k) + }, + } + if not self._checkpoint_name: + self._checkpoint_name = f"model-{self.experiment.id}" + artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata) + artifact.add_file(p, name="model.ckpt", policy=self.add_file_policy) + aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] + self.experiment.log_artifact(artifact, aliases=aliases) + # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) + self._logged_model_time[p] = t diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index 29834f5d519d0..9225e3bb9e7be 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -11,21 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from contextlib import AbstractContextManager -from typing import Any, Callable, Optional, Union +from contextlib import AbstractContextManager, nullcontext +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch +from lightning_utilities import apply_to_collection from torch import Tensor from torch.nn import Module -from torch.optim import Optimizer -from typing_extensions import override +from torch.optim import LBFGS, Optimizer +from typing_extensions import get_args, override import lightning.pytorch as pl -from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT, _PRECISION_INPUT_STR -from lightning.fabric.utilities.imports import _raise_enterprise_not_available +from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT +from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager from lightning.fabric.utilities.types import Steppable from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.utilities import GradClipAlgorithmType +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.model_helpers import is_overridden +from lightning.pytorch.utilities.rank_zero import WarningCache + +if TYPE_CHECKING: + import deepspeed + +warning_cache = WarningCache() class DeepSpeedPrecision(Precision): @@ -44,29 +53,41 @@ class DeepSpeedPrecision(Precision): """ def __init__(self, precision: _PRECISION_INPUT) -> None: - super().__init__() - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.plugins.precision.deepspeed import ( - DeepSpeedPrecisionTrainer as EnterpriseDeepSpeedPrecision, - ) - - self.deepspeed_precision_impl = EnterpriseDeepSpeedPrecision(outer_object=self, precision=precision) + supported_precision = get_args(_PRECISION_INPUT) + if precision not in supported_precision: + raise ValueError( + f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported." + f" `precision` must be one of: {supported_precision}." + ) + self.precision = precision + precision_to_type = { + "bf16-mixed": torch.bfloat16, + "16-mixed": torch.float16, + "bf16-true": torch.bfloat16, + "16-true": torch.float16, + "32-true": torch.float32, + } + self._desired_dtype = precision_to_type[self.precision] @override def convert_module(self, module: Module) -> Module: - return self.deepspeed_precision_impl.convert_module(module=module) + if "true" in self.precision: + return module.to(dtype=self._desired_dtype) + return module @override def convert_input(self, data: Any) -> Any: - return self.deepspeed_precision_impl.convert_input(data=data) + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype) @override def tensor_init_context(self) -> AbstractContextManager: - return self.deepspeed_precision_impl.tensor_init_context() + if "true" not in self.precision: + return nullcontext() + return _DtypeContextManager(self._desired_dtype) @override def module_init_context(self) -> AbstractContextManager: - return self.deepspeed_precision_impl.module_init_context() + return self.tensor_init_context() @override def backward( # type: ignore[override] @@ -77,7 +98,7 @@ def backward( # type: ignore[override] *args: Any, **kwargs: Any, ) -> None: - r"""Performs back-propagation. + r"""Performs back-propagation using DeepSpeed's engine. 
Args: tensor: the loss tensor @@ -87,7 +108,13 @@ def backward( # type: ignore[override] \**kwargs: additional keyword arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call """ - return self.deepspeed_precision_impl.backward(tensor=tensor, model=model, optimizer=optimizer, *args, **kwargs) + if is_overridden("backward", model): + warning_cache.warn( + "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles" + " the backward logic internally." + ) + deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model + deepspeed_engine.backward(tensor, *args, **kwargs) @override def optimizer_step( # type: ignore[override] @@ -97,7 +124,19 @@ def optimizer_step( # type: ignore[override] closure: Callable[[], Any], **kwargs: Any, ) -> Any: - return self.deepspeed_precision_impl.optimizer_step(optimizer=optimizer, model=model, closure=closure, **kwargs) + if isinstance(optimizer, LBFGS): + raise MisconfigurationException("DeepSpeed and the LBFGS optimizer are not compatible.") + closure_result = closure() + self._after_closure(model, optimizer) + skipped_backward = closure_result is None + # in manual optimization, the closure does not return a value + if model.automatic_optimization and skipped_backward: + raise MisconfigurationException( + "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`" + ) + # DeepSpeed handles the optimizer step internally + deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model + return deepspeed_engine.step(**kwargs) @override def clip_gradients( @@ -106,22 +145,4 @@ def clip_gradients( clip_val: Union[int, float] = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: - return self.deepspeed_precision_impl.clip_gradients( - optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm - ) - - @property - def precision(self) -> _PRECISION_INPUT_STR: - return self.deepspeed_precision_impl.precision - - @precision.setter - def precision(self, value: _PRECISION_INPUT_STR) -> None: - self.deepspeed_precision_impl.precision = value - - @property - def _desired_dtype(self) -> torch.dtype: - return self.deepspeed_precision_impl._desired_dtype - - @_desired_dtype.setter - def _desired_dtype(self, value: torch.dtype) -> None: - self.deepspeed_precision_impl._desired_dtype = value + """DeepSpeed handles gradient clipping internally.""" diff --git a/src/lightning/pytorch/plugins/precision/xla.py b/src/lightning/pytorch/plugins/precision/xla.py index 5d9a818aecf85..6890cc4c1d825 100644 --- a/src/lightning/pytorch/plugins/precision/xla.py +++ b/src/lightning/pytorch/plugins/precision/xla.py @@ -11,16 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
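For the `XLAPrecision` plugin below, the compute dtype is selected by exporting XLA's process-wide environment switches rather than by casting the module, and `teardown` removes them again. A condensed sketch of that pairing, assuming only the flags handled in the diff (`16-true`, `bf16-true`, `32-true`); function names are illustrative:

import os
import torch

def apply_xla_precision(precision: str) -> torch.dtype:
    # torch_xla reads these env vars when compiling graphs, so they must be set up front
    if precision == "16-true":
        os.environ["XLA_USE_F16"] = "1"
        return torch.float16
    if precision == "bf16-true":
        os.environ["XLA_USE_BF16"] = "1"
        return torch.bfloat16
    return torch.float32

def teardown_xla_precision() -> None:
    # drop the switches so a later run in the same process starts from a clean state
    os.environ.pop("XLA_USE_BF16", None)
    os.environ.pop("XLA_USE_F16", None)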
+import os
+from functools import partial
 from typing import Any, Callable

 import torch
-from typing_extensions import override
+from typing_extensions import get_args, override

 import lightning.pytorch as pl
-from lightning.fabric.plugins.precision.xla import _PRECISION_INPUT, _PRECISION_INPUT_STR
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
+from lightning.fabric.plugins.precision.xla import _PRECISION_INPUT
 from lightning.fabric.utilities.types import Optimizable
 from lightning.pytorch.plugins.precision.precision import Precision
+from lightning.pytorch.utilities.exceptions import MisconfigurationException


 class XLAPrecision(Precision):
@@ -36,12 +39,25 @@ class XLAPrecision(Precision):
     """

     def __init__(self, precision: _PRECISION_INPUT = "32-true") -> None:
-        super().__init__()
-
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.precision.xla import XLAPrecision as EnterpriseXLAPrecision
-
-        self.xla_impl = EnterpriseXLAPrecision(precision)
+        if not _XLA_AVAILABLE:
+            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+
+        supported_precision = get_args(_PRECISION_INPUT)
+        if precision not in supported_precision:
+            raise ValueError(
+                f"`precision={precision!r}` is not supported in XLA."
+                f" `precision` must be one of: {supported_precision}."
+            )
+        self.precision = precision
+
+        if precision == "16-true":
+            os.environ["XLA_USE_F16"] = "1"
+            self._desired_dtype = torch.float16
+        elif precision == "bf16-true":
+            os.environ["XLA_USE_BF16"] = "1"
+            self._desired_dtype = torch.bfloat16
+        else:
+            self._desired_dtype = torch.float32

     @override
     def optimizer_step(  # type: ignore[override]
@@ -51,24 +67,31 @@ def optimizer_step(  # type: ignore[override]
         closure: Callable[[], Any],
         **kwargs: Any,
     ) -> Any:
-        return self.xla_impl.optimizer_step(optimizer, model, closure, **kwargs)
-
-    @property
-    def precision(self) -> _PRECISION_INPUT_STR:
-        return self.xla_impl.precision
-
-    @precision.setter
-    def precision(self, precision: _PRECISION_INPUT_STR) -> None:
-        self.xla_impl.precision = precision
-
-    @property
-    def _desired_dtype(self) -> torch.dtype:
-        return self.xla_impl._desired_dtype
-
-    @_desired_dtype.setter
-    def _desired_dtype(self, dtype: torch.dtype) -> None:
-        self.xla_impl._desired_dtype = dtype
+        import torch_xla.core.xla_model as xm
+
+        closure = partial(self._xla_wrap_closure, optimizer, closure)
+        closure = partial(self._wrap_closure, model, optimizer, closure)
+        closure_result = optimizer.step(closure=closure, **kwargs)
+        xm.mark_step()
+        skipped_backward = closure_result is None
+        # in manual optimization, the closure does not return a value
+        if model.automatic_optimization and skipped_backward:
+            # we lack coverage here so disable this - something to explore if there's demand
+            raise MisconfigurationException(
+                "Skipping backward by returning `None` from your `training_step` is not implemented with XLA."
+                " Please, open an issue in `https://github.com/Lightning-AI/pytorch-lightning/issues`"
+                " requesting this feature."
+ ) + return closure_result @override def teardown(self) -> None: - return self.xla_impl.teardown() + os.environ.pop("XLA_USE_BF16", None) + os.environ.pop("XLA_USE_F16", None) + + def _xla_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any: + import torch_xla.core.xla_model as xm + + closure_result = closure() + xm.reduce_gradients(optimizer) + return closure_result diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 8bad6a4c6d40b..1fce7b06887cd 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -11,27 +11,52 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import argparse +import json import logging +import os +import platform from collections import OrderedDict from collections.abc import Generator, Mapping from contextlib import contextmanager from datetime import timedelta +from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union import torch from torch.nn import Module from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau from typing_extensions import override import lightning.pytorch as pl from lightning.fabric.plugins import ClusterEnvironment from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout from lightning.fabric.strategies import _StrategyRegistry -from lightning.fabric.utilities.imports import _raise_enterprise_not_available +from lightning.fabric.strategies.deepspeed import ( + _DEEPSPEED_AVAILABLE, + _DEEPSPEED_GREATER_EQUAL_0_16, + _format_precision_config, + _validate_checkpoint_directory, + _validate_device_index_selection, +) +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 +from lightning.fabric.utilities.optimizer import _optimizers_to_device +from lightning.fabric.utilities.seed import reset_seed from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.accelerators.cuda import CUDAAccelerator +from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers from lightning.pytorch.plugins.precision import Precision from lightning.pytorch.strategies.ddp import DDPStrategy +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities import GradClipAlgorithmType +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.model_helpers import is_overridden +from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn +from lightning.pytorch.utilities.types import LRSchedulerConfig + +log = logging.getLogger(__name__) +warning_cache = WarningCache() if TYPE_CHECKING: import deepspeed @@ -233,84 +258,150 @@ def __init__( exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints. """ - super().__init__( - accelerator=accelerator, - parallel_devices=parallel_devices, - cluster_environment=cluster_environment, - precision_plugin=precision_plugin, - process_group_backend=process_group_backend, - ) - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.strategies.deepspeed import ( - DeepSpeedStrategyTrainer as EnterpriseDeepSpeedStrategy, - ) + if not _DEEPSPEED_AVAILABLE: + raise MisconfigurationException( + "To use the `DeepSpeedStrategy`, you must have DeepSpeed installed." 
+ " Install it by running `pip install -U deepspeed`." + ) - self.deepspeed_strategy_impl = EnterpriseDeepSpeedStrategy( - outer_object=self, + if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16: + # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints. + # DeepSpeed added support for this behavior in version 0.16.0. + import deepspeed + + deepspeed_version = deepspeed.__version__ + + raise ImportError( + f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. " + f"Detected DeepSpeed version: {deepspeed_version}. " + "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`." + ) + + super().__init__( accelerator=accelerator, - zero_optimization=zero_optimization, - stage=stage, - remote_device=remote_device, - offload_optimizer=offload_optimizer, - offload_parameters=offload_parameters, - offload_params_device=offload_params_device, - nvme_path=nvme_path, - params_buffer_count=params_buffer_count, - params_buffer_size=params_buffer_size, - max_in_cpu=max_in_cpu, - offload_optimizer_device=offload_optimizer_device, - optimizer_buffer_count=optimizer_buffer_count, - block_size=block_size, - queue_depth=queue_depth, - single_submit=single_submit, - overlap_events=overlap_events, - thread_count=thread_count, - pin_memory=pin_memory, - sub_group_size=sub_group_size, - contiguous_gradients=contiguous_gradients, - overlap_comm=overlap_comm, - allgather_partitions=allgather_partitions, - reduce_scatter=reduce_scatter, - allgather_bucket_size=allgather_bucket_size, - reduce_bucket_size=reduce_bucket_size, - zero_allow_untested_optimizer=zero_allow_untested_optimizer, - logging_batch_size_per_gpu=logging_batch_size_per_gpu, - config=config, - logging_level=logging_level, parallel_devices=parallel_devices, cluster_environment=cluster_environment, - loss_scale=loss_scale, - initial_scale_power=initial_scale_power, - loss_scale_window=loss_scale_window, - hysteresis=hysteresis, - min_loss_scale=min_loss_scale, - partition_activations=partition_activations, - cpu_checkpointing=cpu_checkpointing, - contiguous_memory_optimization=contiguous_memory_optimization, - synchronize_checkpoint_boundary=synchronize_checkpoint_boundary, - load_full_weights=load_full_weights, precision_plugin=precision_plugin, process_group_backend=process_group_backend, - timeout=timeout, - exclude_frozen_parameters=exclude_frozen_parameters, ) + self._timeout: Optional[timedelta] = timeout + + self.config = self._load_config(config) + if self.config is None: + # User has not overridden config, set defaults + self.config = self._create_default_config( + zero_optimization, + zero_allow_untested_optimizer, + logging_batch_size_per_gpu, + offload_optimizer=offload_optimizer, + offload_parameters=offload_parameters, + nvme_path=nvme_path, + offload_params_device=offload_params_device, + params_buffer_count=params_buffer_count, + params_buffer_size=params_buffer_size, + max_in_cpu=max_in_cpu, + pin_memory=pin_memory, + offload_optimizer_device=offload_optimizer_device, + optimizer_buffer_count=optimizer_buffer_count, + block_size=block_size, + queue_depth=queue_depth, + single_submit=single_submit, + overlap_events=overlap_events, + thread_count=thread_count, + partition_activations=partition_activations, + cpu_checkpointing=cpu_checkpointing, + contiguous_memory_optimization=contiguous_memory_optimization, + synchronize_checkpoint_boundary=synchronize_checkpoint_boundary, + stage=stage, + contiguous_gradients=contiguous_gradients, + overlap_comm=overlap_comm, + 
allgather_partitions=allgather_partitions, + reduce_scatter=reduce_scatter, + allgather_bucket_size=allgather_bucket_size, + reduce_bucket_size=reduce_bucket_size, + sub_group_size=sub_group_size, + ) + import deepspeed + + self._config_initialized = False + deepspeed.utils.logging.logger.setLevel(logging_level) + + self.remote_device = remote_device + self.load_full_weights = load_full_weights + self.exclude_frozen_parameters = exclude_frozen_parameters + + # default FP16 parameters. + self.loss_scale = loss_scale + self.initial_scale_power = initial_scale_power + self.loss_scale_window = loss_scale_window + self.hysteresis = hysteresis + self.min_loss_scale = min_loss_scale @override def setup_environment(self) -> None: - return self.deepspeed_strategy_impl.setup_environment() + if not isinstance(self.accelerator, CUDAAccelerator): + raise RuntimeError( + f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`" + " is used." + ) + super().setup_environment() @override def setup_distributed(self) -> None: - return self.deepspeed_strategy_impl.setup_distributed() + assert self.parallel_devices is not None + _validate_device_index_selection(self.parallel_devices) + reset_seed() + self.set_world_ranks() + self._init_deepspeed_distributed() @override def setup(self, trainer: "pl.Trainer") -> None: - return self.deepspeed_strategy_impl.setup(trainer=trainer) + self._init_config_if_needed() + assert self.accelerator is not None + self.accelerator.setup(trainer) + + assert self.model is not None + self.model = self.precision_plugin.convert_module(self.model) + self.model = self._setup_model(self.model) + + if trainer.state.fn == TrainerFn.FITTING: + self.setup_optimizers(trainer) + self.setup_precision_plugin() + if trainer.state.fn == TrainerFn.FITTING: + _optimizers_to_device(self.optimizers, self.root_device) + + self.init_deepspeed() + self.barrier() + + def _init_deepspeed_distributed(self) -> None: + import deepspeed + + assert self.cluster_environment is not None + if platform.system() != "Windows": + # do not set env variables on windows, allow deepspeed to control setup + self._set_node_environment_variables() + log.info( + "initializing deepspeed distributed: " + f"GLOBAL_RANK: {self.global_rank}, " + f"MEMBER: {self.global_rank + 1}/{self.world_size}" + ) + self._process_group_backend = self._get_process_group_backend() + deepspeed.init_distributed( + self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout + ) + + def _set_node_environment_variables(self) -> None: + assert self.cluster_environment is not None + os.environ["MASTER_ADDR"] = self.cluster_environment.main_address + os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) + os.environ["RANK"] = str(self.global_rank) + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["LOCAL_RANK"] = str(self.local_rank) @property @override def restore_checkpoint_after_setup(self) -> bool: - return self.deepspeed_strategy_impl.restore_checkpoint_after_setup + return True @override def _setup_model_and_optimizers( @@ -325,28 +416,188 @@ def _setup_model_and_optimizers( deepspeed optimizer. """ - return self.deepspeed_strategy_impl._setup_model_and_optimizers(model=model, optimizers=optimizers) + if len(optimizers) != 1: + raise ValueError( + f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead." 
+ ) + + # train_micro_batch_size_per_gpu is used for throughput logging purposes + # normally we set this to the batch size, but it is not available here unless the user provides it + # as part of the config + assert self.config is not None + self.config.setdefault("train_micro_batch_size_per_gpu", 1) + self.model, optimizer = self._setup_model_and_optimizer(model, optimizers[0]) + self._set_deepspeed_activation_checkpointing() + return self.model, [optimizer] + + def _setup_model_and_optimizer( + self, + model: Module, + optimizer: Optional[Optimizer], + lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None, + ) -> tuple["deepspeed.DeepSpeedEngine", Optimizer]: + """Initialize one model and one optimizer with an optional learning rate scheduler. + + This calls ``deepspeed.initialize`` internally. + + """ + import deepspeed + + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize( + args=argparse.Namespace(device_rank=self.root_device.index), + config=self.config, + model=model, + model_parameters=model_parameters, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + dist_init_required=False, + ) + return deepspeed_engine, deepspeed_optimizer + + def init_deepspeed(self) -> None: + assert self.lightning_module is not None + # deepspeed handles gradient clipping internally + if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule): + rank_zero_warn( + "Since DeepSpeed handles gradient clipping internally, the default" + " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients." + " The hook will still be called. Consider setting" + " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`" + " which will use the internal mechanism." + ) + + if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE: + raise MisconfigurationException("DeepSpeed does not support clipping gradients by value.") + + assert isinstance(self.model, pl.LightningModule) + if self.lightning_module.trainer and self.lightning_module.trainer.training: + self._initialize_deepspeed_train(self.model) + else: + self._initialize_deepspeed_inference(self.model) + + def _init_optimizers(self) -> tuple[Optimizer, Optional[LRSchedulerConfig]]: + assert self.lightning_module is not None + optimizers, lr_schedulers = _init_optimizers_and_lr_schedulers(self.lightning_module) + if len(optimizers) > 1 or len(lr_schedulers) > 1: + raise MisconfigurationException( + "DeepSpeed currently only supports single optimizer, single optional scheduler." + ) + return optimizers[0], lr_schedulers[0] if lr_schedulers else None @property def zero_stage_3(self) -> bool: - return self.deepspeed_strategy_impl.zero_stage_3 + assert isinstance(self.config, dict) + zero_optimization = self.config.get("zero_optimization") + return zero_optimization is not None and zero_optimization.get("stage") == 3 + + def _initialize_deepspeed_train(self, model: Module) -> None: + optimizer, scheduler = None, None + assert isinstance(self.config, dict) + if "optimizer" in self.config: + rank_zero_info( + "You have specified an optimizer and/or scheduler within the DeepSpeed config." + " It is recommended to define it in `LightningModule.configure_optimizers`." 
+ ) + lr_scheduler = None + else: + ( + optimizer, + lr_scheduler, + ) = self._init_optimizers() + if lr_scheduler is not None: + scheduler = lr_scheduler.scheduler + + model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler) + self._set_deepspeed_activation_checkpointing() + + # although we set these here, deepspeed manages the specific optimizer logic + self.optimizers = [deepspeed_optimizer] + + deepspeed_scheduler = model.lr_scheduler + if deepspeed_scheduler is not None: + # disable deepspeed lr scheduling as lightning manages scheduling + model.lr_scheduler = None + if lr_scheduler is None: + lr_scheduler = LRSchedulerConfig(deepspeed_scheduler, interval="step") + else: + lr_scheduler.scheduler = deepspeed_scheduler + self.lr_scheduler_configs = [lr_scheduler] + self.model = model @contextmanager @override def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]: - with self.deepspeed_strategy_impl.tensor_init_context(empty_init=empty_init): + if self.zero_stage_3: + if empty_init is False: + raise NotImplementedError( + f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled." + ) + yield + return + with super().tensor_init_context(empty_init=empty_init): yield @contextmanager @override def model_sharded_context(self) -> Generator[None, None, None]: - with self.deepspeed_strategy_impl.model_sharded_context(): + import deepspeed + + self._init_config_if_needed() + with deepspeed.zero.Init( + enabled=self.zero_stage_3, + remote_device=self.remote_device, + config_dict_or_path=self.config, + ): yield + def _set_deepspeed_activation_checkpointing(self) -> None: + import deepspeed + + assert isinstance(self.config, dict) + if self.config.get("activation_checkpointing"): + checkpoint_config = self.config["activation_checkpointing"] + deepspeed.checkpointing.configure( + mpu_=None, + partition_activations=checkpoint_config.get("partition_activations"), + contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"), + checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"), + profile=checkpoint_config.get("profile"), + ) + + def _initialize_deepspeed_inference(self, model: Module) -> None: + import deepspeed + + assert isinstance(self.config, dict) + + # todo: this is required for DeepSpeed throughput timers + inference_config = {"train_micro_batch_size_per_gpu": 1} + if "fp16" in self.config: + inference_config.update({"fp16": self.config["fp16"]}) + if "bf16" in self.config: + inference_config.update({"bf16": self.config["bf16"]}) + if self.zero_stage_3: + inference_config.update({ + "zero_allow_untested_optimizer": self.config["zero_allow_untested_optimizer"], + "zero_optimization": self.config["zero_optimization"], + }) + # Remove all module hooks before initializing new model + remove_module_hooks(model) + model, _, _, _ = deepspeed.initialize( + args=argparse.Namespace(device_rank=self.root_device.index), + config=inference_config, + model=model, + optimizer=None, + lr_scheduler=None, + model_parameters=[], + dist_init_required=False, + ) + self.model = model + @property @override def distributed_sampler_kwargs(self) -> dict[str, int]: - return self.deepspeed_strategy_impl.distributed_sampler_kwargs + return {"num_replicas": self.world_size, "rank": self.global_rank} @override def setup_optimizers(self, trainer: "pl.Trainer") -> None: @@ -356,21 +607,29 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None: trainer: the Trainer, these optimizers 
should be connected to """ - return self.deepspeed_strategy_impl.setup_optimizers(trainer=trainer) + # Skip initializing optimizers here as DeepSpeed handles optimizers via config. + # User may have specified config options instead in configure_optimizers, but this is handled + # via `_initialize_deepspeed_train` + # empty optimizers, schedulers + self.optimizers = [] + self.lr_scheduler_configs = [] + + def _setup_model(self, model: Module) -> Module: # type: ignore[override] + return model @property @override def handles_gradient_accumulation(self) -> bool: """Whether the strategy handles gradient accumulation internally.""" - return self.deepspeed_strategy_impl.handles_gradient_accumulation + return True @property def deepspeed_engine(self) -> "deepspeed.DeepSpeedEngine": - return self.deepspeed_strategy_impl.deepspeed_engine + return self.model @property def _multi_device(self) -> bool: - return self.deepspeed_strategy_impl._multi_device + return self.num_processes > 1 or self.num_nodes > 1 @override def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None: @@ -386,27 +645,135 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op If ``storage_options`` arg is passed in """ - return self.deepspeed_strategy_impl.save_checkpoint( - checkpoint=checkpoint, filepath=filepath, storage_options=storage_options + # broadcast the filepath from rank 0 to ensure all the states are saved in a common filepath + filepath = self.broadcast(filepath) + + if storage_options is not None: + raise TypeError( + "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg" + f" is not supported for `{self.__class__.__name__}` as `CheckpointIO` is not used." + ) + + if self.zero_stage_3 and self._multi_device and self.is_global_zero: + warning_cache.warn( + "When saving the DeepSpeed Stage 3 checkpoint, " + "each worker will save a shard of the checkpoint within a directory. " + "If a single file is required after training, " + "see https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#" + "deepspeed-zero-stage-3-single-file for instructions." 
+ ) + # Use deepspeed's internal checkpointing function to handle partitioned weights across processes + # dump states as a checkpoint dictionary object + _exclude_keys = ["state_dict", "optimizer_states"] + checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} + self.deepspeed_engine.save_checkpoint( + filepath, + client_state=checkpoint, + tag="checkpoint", + exclude_frozen_parameters=self.exclude_frozen_parameters, ) @override def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: - return self.deepspeed_strategy_impl.load_checkpoint(checkpoint_path=checkpoint_path, weights_only=weights_only) + if self.load_full_weights and self.zero_stage_3: + # Broadcast to ensure we load from the rank 0 checkpoint + # This doesn't have to be the case when using deepspeed sharded checkpointing + checkpoint_path = self.broadcast(checkpoint_path) + return super().load_checkpoint(checkpoint_path, weights_only) + + _validate_checkpoint_directory(checkpoint_path) + + # Rely on deepspeed to load the checkpoint and necessary information + assert self.lightning_module is not None + + from lightning.pytorch.trainer.states import TrainerFn + + is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING + + _, client_state = self.deepspeed_engine.load_checkpoint( + checkpoint_path, + load_optimizer_states=is_fitting, + load_lr_scheduler_states=False, + load_module_strict=self.lightning_module.strict_loading, + ) + if client_state is None: + raise MisconfigurationException( + "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint " + "or a single checkpoint file with `Trainer(strategy=DeepSpeedStrategy(load_full_weights=True))`." + ) + return client_state @property @override def lightning_restore_optimizer(self) -> bool: - return self.deepspeed_strategy_impl.lightning_restore_optimizer + assert self.lightning_module is not None + # managed by DeepSpeed + if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: + rank_zero_warn( + "A single checkpoint file has been given. This means optimizer states cannot be restored." + " If you'd like to restore these states, you must provide a path to the originally saved DeepSpeed" + " checkpoint. When using ZeRO 3, the original path should be a directory." + ) + return False @override def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()` - return self.deepspeed_strategy_impl.load_model_state_dict(checkpoint=checkpoint, strict=strict) + if self.load_full_weights and self.zero_stage_3: + self.model_to_device() + self._restore_zero_state(checkpoint, strict=strict) + + def _restore_zero_state(self, ckpt: Mapping[str, Any], strict: bool) -> None: + """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be sharded + across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced + across processes. + + Args: + ckpt: The ckpt file. 
+ + """ + import deepspeed + + assert self.lightning_module is not None + + def load(module: torch.nn.Module, prefix: str = "") -> None: + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + error_msgs: list[str] = [] + state_dict = ckpt["state_dict"] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if self.is_global_zero: + module._load_from_state_dict( + state_dict=state_dict, + prefix=prefix, + local_metadata=local_metadata, + strict=strict, + missing_keys=missing_keys, + unexpected_keys=unexpected_keys, + error_msgs=error_msgs, + ) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(self.lightning_module, prefix="") @override def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - return self.deepspeed_strategy_impl.load_optimizer_state_dict(checkpoint=checkpoint) + # Override to do nothing, the deepspeed engine already loaded the states in `load_checkpoint()` + pass @classmethod @override @@ -442,10 +809,137 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: offload_optimizer_device="nvme", ) - @property - def config(self) -> dict[str, Any]: - return self.deepspeed_strategy_impl.config + def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]: + if config is None and self.DEEPSPEED_ENV_VAR in os.environ: + rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") + config = os.environ[self.DEEPSPEED_ENV_VAR] + if isinstance(config, (str, Path)): + if not os.path.isfile(config): + raise MisconfigurationException( + f"You passed in a path to a DeepSpeed config but the path does not exist: {config}" + ) + with open(config) as f: + config = json.load(f) + assert isinstance(config, dict) or config is None + return config + + def _init_config_if_needed(self) -> None: + if not self._config_initialized: + self._format_config() + self._config_initialized = True + + def _format_config(self) -> None: + if self.config is None: + raise MisconfigurationException( + "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config." 
+ " See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed" + ) + self._format_batch_size_and_grad_accum_config() + _format_precision_config( + config=self.config, + precision=self.precision_plugin.precision, + loss_scale=self.loss_scale, + loss_scale_window=self.loss_scale_window, + min_loss_scale=self.min_loss_scale, + initial_scale_power=self.initial_scale_power, + hysteresis=self.hysteresis, + ) - @property - def load_full_weights(self) -> bool: - return self.deepspeed_strategy_impl.load_full_weights + def _create_default_config( + self, + zero_optimization: bool, + zero_allow_untested_optimizer: bool, + logging_batch_size_per_gpu: Union[str, int], + partition_activations: bool, + cpu_checkpointing: bool, + contiguous_memory_optimization: bool, + synchronize_checkpoint_boundary: bool, + offload_optimizer: bool, + offload_parameters: bool, + nvme_path: str, + offload_params_device: str, + params_buffer_count: int, + params_buffer_size: int, + max_in_cpu: int, + offload_optimizer_device: str, + optimizer_buffer_count: int, + pin_memory: bool, + block_size: int, + queue_depth: int, + single_submit: bool, + overlap_events: bool, + thread_count: int, + **zero_kwargs: Any, + ) -> dict: + cfg = { + "activation_checkpointing": { + "partition_activations": partition_activations, + "cpu_checkpointing": cpu_checkpointing, + "contiguous_memory_optimization": contiguous_memory_optimization, + "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary, + }, + "aio": { + "block_size": block_size, + "queue_depth": queue_depth, + "single_submit": single_submit, + "overlap_events": overlap_events, + "thread_count": thread_count, + }, + } + if zero_optimization: + zero_config = zero_kwargs + + if offload_optimizer: + zero_config["offload_optimizer"] = { + "device": offload_optimizer_device, + "nvme_path": nvme_path, + "buffer_count": optimizer_buffer_count, + "pin_memory": pin_memory, + } + if offload_parameters: + zero_config["offload_param"] = { + "device": offload_params_device, + "nvme_path": nvme_path, + "buffer_count": params_buffer_count, + "buffer_size": params_buffer_size, + "max_in_cpu": max_in_cpu, + "pin_memory": pin_memory, + } + cfg = { + "zero_allow_untested_optimizer": zero_allow_untested_optimizer, + "zero_optimization": zero_config, + **cfg, + } + if logging_batch_size_per_gpu != "auto": + cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg} + return cfg + + def _format_batch_size_and_grad_accum_config(self) -> None: + # TODO: Using Fabric, we do not support these variables within the config + assert isinstance(self.config, dict) + if self.lightning_module is None: + return + + if "gradient_accumulation_steps" in self.config: + raise MisconfigurationException( + "Do not set `gradient_accumulation_steps` in the DeepSpeed config" + " as this will be set with the `accumulate_grad_batches` argument passed via the Lightning Trainer." 
+ ) + self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches + if "train_micro_batch_size_per_gpu" not in self.config: + batch_size = self._auto_select_batch_size() + self.config["train_micro_batch_size_per_gpu"] = batch_size + if "gradient_clipping" not in self.config: + self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0 + + def _auto_select_batch_size(self) -> int: + # train_micro_batch_size_per_gpu is used for throughput logging purposes + # by default we try to use the batch size of the loader + assert self.lightning_module is not None + batch_size = 1 + data_source = self.lightning_module.trainer.fit_loop._data_source + if data_source.is_defined(): + train_dataloader = data_source.dataloader() + if hasattr(train_dataloader, "batch_sampler"): + batch_size = train_dataloader.batch_sampler.batch_size + return batch_size diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py index 63842d5557d7e..066fecc79f208 100644 --- a/src/lightning/pytorch/strategies/launchers/xla.py +++ b/src/lightning/pytorch/strategies/launchers/xla.py @@ -11,18 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import queue from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch.multiprocessing as mp from typing_extensions import override -from lightning.fabric.utilities.imports import _raise_enterprise_not_available +from lightning.fabric.accelerators.xla import _XLA_AVAILABLE +from lightning.fabric.strategies.launchers.xla import _rank_teardown +from lightning.fabric.utilities import move_data_to_device from lightning.pytorch.strategies.launchers.multiprocessing import ( _GlobalStateSnapshot, _MultiProcessingLauncher, _WorkerOutput, ) +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities.rank_zero import rank_zero_debug if TYPE_CHECKING: import lightning.pytorch as pl @@ -46,16 +51,14 @@ class _XLALauncher(_MultiProcessingLauncher): """ def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None: - super().__init__(strategy) - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.strategies.xla.launcher import _XLALauncherTrainer as EnterpriseXLALauncher - - self.xla_launcher_impl = EnterpriseXLALauncher(strategy) + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) + super().__init__(strategy=strategy, start_method="fork") @property @override def is_interactive_compatible(self) -> bool: - return self.xla_launcher_impl.is_interactive_compatible + return True @override def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any: @@ -72,7 +75,46 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] **kwargs: Optional keyword arguments to be passed to the given function. """ - return self.xla_launcher_impl.launch(function, *args, trainer=trainer, **kwargs) + if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING: + # resolving https://github.com/Lightning-AI/pytorch-lightning/issues/18775 will lift this restriction + raise NotImplementedError( + "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not" + " supported. 
You can work around this by creating a new Trainer instance and passing the" + " `fit(ckpt_path=...)` argument." + ) + + # pjrt requires that the queue is serializable + return_queue = mp.Manager().Queue() + + import torch_xla.distributed.xla_multiprocessing as xmp + + spawn_kwargs = {} + nprocs = self._strategy.num_processes + if nprocs == 1: + # avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly. + # otherwise it will use all devices + spawn_kwargs["nprocs"] = nprocs + + process_context = xmp.spawn( + self._wrapping_function, + args=(trainer, function, args, kwargs, return_queue), + start_method=self._start_method, + join=False, # we will join ourselves to get the process references + **spawn_kwargs, + ) + # xla will not actually create processes if only 1 device + if process_context is not None: + self.procs = process_context.processes + while not process_context.join(): + pass + + worker_output = return_queue.get() + if trainer is None: + return worker_output + + self._already_fit |= trainer.state.fn == TrainerFn.FITTING + self._recover_results_in_main_process(worker_output, trainer) + return worker_output.trainer_results @override def _wrapping_function( @@ -87,10 +129,48 @@ def _wrapping_function( return_queue: Union[mp.SimpleQueue, queue.Queue], global_states: Optional[_GlobalStateSnapshot] = None, ) -> None: - return self.xla_launcher_impl._wrapping_function( - process_idx, trainer, function, args, kwargs, return_queue, global_states - ) + import torch_xla.core.xla_model as xm + + if len(xm.get_xla_supported_devices()) > 1: + # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for v4) + # so when there's more than one (multithreading), objects need to be deep-copied + import copy + + trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs)) + + results = function(*args, **kwargs) + + if trainer is not None: + results = self._collect_rank_zero_results(trainer, results) + + if self._strategy.local_rank == 0: + return_queue.put(move_data_to_device(results, "cpu")) + + _rank_teardown(self._strategy.local_rank) @override def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]: - return self.xla_launcher_impl._collect_rank_zero_results(trainer, results) + rank_zero_debug("Collecting results from rank 0 process.") + checkpoint_callback = trainer.checkpoint_callback + best_model_path = ( + checkpoint_callback.best_model_path + if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path") + else None + ) + + # save the last weights + weights_path = None + if trainer.state.fn == TrainerFn.FITTING: + # requires to compute the state_dict on all processes in case Metrics are present + state_dict = self._strategy.lightning_module_state_dict() + weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") + self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) + + # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training + if self._strategy.local_rank != 0: + return None + + # add extra result data from trainer to send to main process + extra = self.get_extra_results(trainer) + + return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra) diff --git a/src/lightning/pytorch/strategies/single_xla.py b/src/lightning/pytorch/strategies/single_xla.py index 6e687ac815639..2a5e2f3a85b96 100644 --- a/src/lightning/pytorch/strategies/single_xla.py +++ 
b/src/lightning/pytorch/strategies/single_xla.py @@ -11,18 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Optional, Union +import torch from typing_extensions import override import lightning.pytorch as pl +from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO from lightning.fabric.strategies import _StrategyRegistry -from lightning.fabric.utilities.imports import _raise_enterprise_not_available +from lightning.fabric.utilities.optimizer import _optimizers_to_device from lightning.fabric.utilities.types import _DEVICE from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO from lightning.pytorch.plugins.precision.xla import XLAPrecision from lightning.pytorch.strategies.single_device import SingleDeviceStrategy +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters class SingleDeviceXLAStrategy(SingleDeviceStrategy): @@ -36,18 +41,20 @@ def __init__( precision_plugin: Optional[XLAPrecision] = None, debug: bool = False, ): + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) + if isinstance(device, torch.device): + # unwrap the `torch.device` in favor of `xla_device` + device = device.index + import torch_xla.core.xla_model as xm + super().__init__( accelerator=accelerator, - device=device, + device=xm.xla_device(device), checkpoint_io=checkpoint_io, precision_plugin=precision_plugin, ) - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.strategies.xla.single import ( - SingleDeviceXLAStrategyTrainer as EnterpriseSingleDeviceXLAStrategy, - ) - - self.single_xla_strategy_impl = EnterpriseSingleDeviceXLAStrategy(outer_object=self, device=device, debug=debug) + self.debug = debug @property @override @@ -83,7 +90,26 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: @override def setup(self, trainer: "pl.Trainer") -> None: - return self.single_xla_strategy_impl.setup(trainer=trainer) + if self.debug: + os.environ["PT_XLA_DEBUG"] = str(1) + + assert self.accelerator is not None + self.accelerator.setup(trainer) + + assert self.model is not None + self.precision_plugin.convert_module(self.model) + + shared_params = find_shared_parameters(self.model) + self.model_to_device() + set_shared_parameters(self.model, shared_params) + + self.model = self._setup_model(self.model) + + if trainer.state.fn == TrainerFn.FITTING: + self.setup_optimizers(trainer) + self.setup_precision_plugin() + if trainer.state.fn == TrainerFn.FITTING: + _optimizers_to_device(self.optimizers, self.root_device) @classmethod @override @@ -92,4 +118,5 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: @override def teardown(self) -> None: - return self.single_xla_strategy_impl.teardown() + super().teardown() + os.environ.pop("PT_XLA_DEBUG", None) diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py index e57c627f374d7..cbdc890a1ca32 100644 --- a/src/lightning/pytorch/strategies/xla.py +++ b/src/lightning/pytorch/strategies/xla.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
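A point worth calling out in the `XLAStrategy` below: `broadcast` moves arbitrary picklable objects through the tensor-only XLA collective by serializing them into a byte-valued float tensor and reversing the encoding on the receiving side. A CPU-only sketch of that round-trip (no XLA involved; purely illustrative):

import io
import torch

def obj_to_tensor(obj: object) -> torch.Tensor:
    # pickle the object, then expose its raw bytes as tensor values
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return torch.tensor(bytearray(buffer.getbuffer()), dtype=torch.float)

def tensor_to_obj(tensor: torch.Tensor) -> object:
    # invert the encoding: float values back to bytes, then unpickle
    buffer = io.BytesIO(tensor.cpu().byte().numpy())
    return torch.load(buffer)

payload = {"epoch": 3, "best_score": 0.91}
assert tensor_to_obj(obj_to_tensor(payload)) == payload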
+import io +import os from typing import TYPE_CHECKING, Any, Optional, Union import torch @@ -19,16 +21,20 @@ from typing_extensions import override import lightning.pytorch as pl +from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1 from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO from lightning.fabric.plugins.environments import XLAEnvironment from lightning.fabric.strategies import _StrategyRegistry -from lightning.fabric.utilities.imports import _raise_enterprise_not_available +from lightning.fabric.utilities.optimizer import _optimizers_to_device from lightning.fabric.utilities.types import _PATH, ReduceOp from lightning.pytorch.plugins import XLAPrecision from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO from lightning.pytorch.strategies.ddp import DDPStrategy from lightning.pytorch.strategies.launchers.xla import _XLALauncher from lightning.pytorch.strategies.strategy import TBroadcast +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters +from lightning.pytorch.utilities.rank_zero import rank_zero_only if TYPE_CHECKING: from torch_xla.distributed.parallel_loader import MpDeviceLoader @@ -50,6 +56,8 @@ def __init__( sync_module_states: bool = True, **_: Any, ) -> None: + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, @@ -58,12 +66,9 @@ def __init__( precision_plugin=precision_plugin, start_method="fork", ) - _raise_enterprise_not_available() - from pytorch_lightning_enterprise.strategies.xla.ddp import XLAStrategyTrainer as EnterpriseXLAStrategy - - self.xla_strategy_impl = EnterpriseXLAStrategy( - outer_object=self, debug=debug, sync_module_states=sync_module_states - ) + self.debug = debug + self._launched = False + self._sync_module_states = sync_module_states @property @override @@ -100,27 +105,31 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: @property @override def root_device(self) -> torch.device: - return self.xla_strategy_impl.root_device + if not self._launched: + raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.") + import torch_xla.core.xla_model as xm + + return xm.xla_device() @property @override def global_rank(self) -> int: - return self.xla_strategy_impl.global_rank + return super().global_rank if self._launched else 0 @property @override def local_rank(self) -> int: - return self.xla_strategy_impl.local_rank + return super().local_rank if self._launched else 0 @property @override def node_rank(self) -> int: - return self.xla_strategy_impl.node_rank + return super().node_rank if self._launched else 0 @property @override def world_size(self) -> int: - return self.xla_strategy_impl.world_size + return super().world_size if self._launched else 1 @override def _configure_launcher(self) -> None: @@ -128,36 +137,113 @@ def _configure_launcher(self) -> None: @override def setup(self, trainer: "pl.Trainer") -> None: - return self.xla_strategy_impl.setup(trainer=trainer) + assert self.accelerator is not None + self.accelerator.setup(trainer) + + if self.debug: + os.environ["PT_XLA_DEBUG"] = "1" + + assert self.model is not None + self.precision_plugin.convert_module(self.model) + + shared_params = find_shared_parameters(self.model) + self.model_to_device() + set_shared_parameters(self.model, shared_params) + + self.model = 
self._setup_model(self.model) + + if self._sync_module_states: + if _XLA_GREATER_EQUAL_2_1: + from torch_xla.core.xla_model import broadcast_master_param + else: + from torch_xla.experimental.pjrt import broadcast_master_param + + broadcast_master_param(self.model) + + if trainer.state.fn == TrainerFn.FITTING: + self.setup_optimizers(trainer) + self.setup_precision_plugin() + if trainer.state.fn == TrainerFn.FITTING: + _optimizers_to_device(self.optimizers, self.root_device) @override def _setup_model(self, model: Module) -> Module: # type: ignore - return self.xla_strategy_impl._setup_model(model=model) + return model @property @override def distributed_sampler_kwargs(self) -> dict[str, int]: - return self.xla_strategy_impl.distributed_sampler_kwargs + return {"num_replicas": self.world_size, "rank": self.global_rank} @override def process_dataloader(self, dataloader: object) -> "MpDeviceLoader": - return self.xla_strategy_impl.process_dataloader(dataloader=dataloader) + from torch_xla.distributed.parallel_loader import MpDeviceLoader + + if isinstance(dataloader, MpDeviceLoader): + # dataloader is already wrapped by MpDeviceLoader + return dataloader + + dataloader = MpDeviceLoader(dataloader, self.root_device) + # Mimic interface to torch.utils.data.DataLoader + dataloader.dataset = dataloader._loader.dataset + dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None) + return dataloader @override def configure_ddp(self) -> None: - return self.xla_strategy_impl.configure_ddp() + pass @override def model_to_device(self) -> None: - return self.xla_strategy_impl.model_to_device() + assert self.model is not None + self.model = self.model.to(self.root_device) @override def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: - return self.xla_strategy_impl.barrier(name=name, *args, **kwargs) + if not self._launched: + return + + import torch_xla.core.xla_model as xm + + if name is None: + # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments" + name = "" + xm.rendezvous(name) @override def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - return self.xla_strategy_impl.broadcast(obj=obj, src=src) + if not self._launched: + return obj + + import torch_xla.core.xla_model as xm + + is_tensor = isinstance(obj, Tensor) + if is_tensor: + if obj.dim() == 0: + obj = obj.unsqueeze(0) + original_device = obj.device + # XLA distributed requires that the data is on the XLA device + obj = obj.to(self.root_device) + else: + # support for arbitrary pickle-ables + buffer = io.BytesIO() + torch.save(obj, buffer) + obj = torch.tensor( # type: ignore[assignment] + bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float + ) + + obj = [obj] + xm.collective_broadcast(obj, root_ordinal=src) + obj = obj[0] + + if not is_tensor: + # this will preserve the dtype and device of any tensors + buffer = io.BytesIO(obj.cpu().byte().numpy()) + obj = torch.load(buffer) + else: + obj = obj.to(original_device) + + return obj @override def reduce( @@ -166,27 +252,60 @@ def reduce( group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean", ) -> Tensor: - return self.xla_strategy_impl.reduce(output=output, group=group, reduce_op=reduce_op) + if not isinstance(output, Tensor): + output = torch.tensor(output, device=self.root_device) + + invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM + invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in 
("sum", "mean", "avg") + if invalid_reduce_op or invalid_reduce_op_str: + raise ValueError( + "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:" + f" {reduce_op}" + ) + + import torch_xla.core.xla_model as xm + + output = xm.mesh_reduce("reduce", output, sum) + + if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): + output = output / self.world_size + + return output @override def setup_environment(self) -> None: - return self.xla_strategy_impl.setup_environment() + self._launched = True + super().setup_environment() @override def setup_distributed(self) -> None: - return self.xla_strategy_impl.setup_distributed() + assert self.parallel_devices is not None + if len(self.parallel_devices) == 1: + # spawning only 1 device with PjRT is not supported: + # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732 + raise NotImplementedError( + "The `XLAStrategy` does not support running on a single device with the PjRT runtime." + " Try using all devices or the `SingleDeviceXLAStrategy` strategy" + ) + rank_zero_only.rank = self.global_rank @override def set_world_ranks(self) -> None: - return self.xla_strategy_impl.set_world_ranks() + # accessing global_rank will initialize the XLA computation client. since this is called outside of the spawned + # processes (by the accelerator connector), we cannot run the code that would normally be here. + # instead it's done in `setup_distributed` + pass @override def save_checkpoint( self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: - return self.xla_strategy_impl.save_checkpoint( - checkpoint=checkpoint, filepath=filepath, storage_options=storage_options - ) + import torch_xla.core.xla_model as xm + + # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs + xm.mark_step() + # save on global rank zero only + super().save_checkpoint(checkpoint, filepath, storage_options=storage_options) @override def remove_checkpoint(self, filepath: _PATH) -> None: @@ -196,7 +315,8 @@ def remove_checkpoint(self, filepath: _PATH) -> None: filepath: Path to checkpoint """ - return self.xla_strategy_impl.remove_checkpoint(filepath=filepath) + if self.local_rank == 0: + self.checkpoint_io.remove_checkpoint(filepath) @override def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: @@ -210,11 +330,29 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo A tensor of shape (world_size, ...) """ - return self.xla_strategy_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads) + if not self._launched: + return tensor + if not isinstance(tensor, Tensor): + raise NotImplementedError( + f"`{type(self).__name__}.all_gather` is only implemented for tensors. 
Given {tensor}" + ) + if tensor.dim() == 0: + tensor = tensor.unsqueeze(0) + original_device = tensor.device + tensor = tensor.to(self.root_device) + + import torch_xla.core.functions as xf + import torch_xla.core.xla_model as xm + + tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor) + tensor = tensor.to(original_device) + return tensor @override def teardown(self) -> None: - return self.xla_strategy_impl.teardown() + super().teardown() + self._launched = False # after the Trainer finishes, we aren't inside the spawned region + os.environ.pop("PT_XLA_DEBUG", None) @classmethod @override diff --git a/src/lightning/pytorch/utilities/deepspeed.py b/src/lightning/pytorch/utilities/deepspeed.py index 21a9190552f43..20b418437c681 100644 --- a/src/lightning/pytorch/utilities/deepspeed.py +++ b/src/lightning/pytorch/utilities/deepspeed.py @@ -20,8 +20,8 @@ import torch -from lightning.fabric.utilities.imports import _DEEPSPEED_AVAILABLE from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE CPU_DEVICE = torch.device("cpu") diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 0564632b315a7..9d4a0b9462f2e 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -19,13 +19,9 @@ from unittest.mock import Mock import pytest -import pytorch_lightning_enterprise.utils.imports import torch.distributed import lightning.fabric -import lightning.fabric.plugins.environments.xla -import lightning.fabric.plugins.io.xla -import lightning.fabric.plugins.precision.xla from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver from lightning.fabric.utilities.distributed import _destroy_dist_connection @@ -73,8 +69,6 @@ def restore_env_variables(): # set by torchdynamo "TRITON_CACHE_DIR", "TORCHINDUCTOR_CACHE_DIR", - "TQDM_MININTERVAL", # set by our platform - "TQDM_POSITION", # set by our platform } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" @@ -150,33 +144,17 @@ def reset_cudnn_benchmark(): def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None: - # First, mock torch_xla modules in sys.modules so imports succeed + monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.plugins.precision.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.strategies.single_xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.strategies.xla_fsdp, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value) monkeypatch.setitem(sys.modules, "torch_xla", Mock()) monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock()) monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock()) monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla._internal", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla._internal.tpu", Mock()) - - # Then patch the _XLA_AVAILABLE flags in various modules - monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_AVAILABLE", value) - 
monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_1", value) - monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_5", value) - monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value) - monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", value) - monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_5", value) - # Patch in the modules where they're used after import - monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", value) - monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_1", value) - monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_5", value) - monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_AVAILABLE", value) - monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", value) - monkeypatch.setattr("pytorch_lightning_enterprise.plugins.precision.xla._XLA_AVAILABLE", value) - monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.single._XLA_AVAILABLE", value) - monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.ddp._XLA_AVAILABLE", value) - monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.ddp._XLA_GREATER_EQUAL_2_1", value) - monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.fsdp._XLA_AVAILABLE", value) - monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.launcher._XLA_AVAILABLE", value) @pytest.fixture @@ -188,14 +166,6 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N mock_xla_available(monkeypatch, value) monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: value) monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8) - # Also mock the enterprise XLAAccelerator methods - import pytorch_lightning_enterprise.accelerators.xla - - monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "is_available", lambda: value) - monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8) - monkeypatch.setitem(sys.modules, "torch_xla", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock()) @pytest.fixture diff --git a/tests/tests_fabric/graveyard/test_tpu.py b/tests/tests_fabric/graveyard/test_tpu.py index 089720a381cca..5b72d60491df7 100644 --- a/tests/tests_fabric/graveyard/test_tpu.py +++ b/tests/tests_fabric/graveyard/test_tpu.py @@ -11,11 +11,11 @@ ("lightning.fabric.strategies.single_tpu", "SingleTPUStrategy"), ], ) -def test_graveyard_single_tpu(import_path, name, tpu_available): +def test_graveyard_single_tpu(import_path, name): module = import_module(import_path) cls = getattr(module, name) device = torch.device("cpu") - with pytest.deprecated_call(match="is deprecated"): + with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"): cls(device) @@ -34,12 +34,8 @@ def test_graveyard_single_tpu(import_path, name, tpu_available): ("lightning.fabric.plugins.precision.xlabf16", "XLABf16Precision"), ], ) -def test_graveyard_no_device(import_path, name, tpu_available): +def test_graveyard_no_device(import_path, name): module = import_module(import_path) cls = 
getattr(module, name) - with pytest.deprecated_call(match="is deprecated"): - _instance = cls() - - # required to prevent env-var leakage - if hasattr(cls, "teardown"): - cls.teardown(_instance) + with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"): + cls() diff --git a/tests/tests_fabric/plugins/environments/test_kubeflow.py b/tests/tests_fabric/plugins/environments/test_kubeflow.py index 72fb527d3facb..3436adc9ce2aa 100644 --- a/tests/tests_fabric/plugins/environments/test_kubeflow.py +++ b/tests/tests_fabric/plugins/environments/test_kubeflow.py @@ -61,14 +61,14 @@ def test_attributes_from_environment_variables(caplog): assert env.local_rank() == 0 assert env.node_rank() == 1 # setter should be no-op - with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.kubeflow"): + with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): env.set_global_rank(100) assert env.global_rank() == 1 assert "setting global rank is not allowed" in caplog.text caplog.clear() - with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.kubeflow"): + with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): env.set_world_size(100) assert env.world_size() == 20 assert "setting world size is not allowed" in caplog.text diff --git a/tests/tests_fabric/plugins/environments/test_slurm.py b/tests/tests_fabric/plugins/environments/test_slurm.py index 742fc44855083..b907c287faa5f 100644 --- a/tests/tests_fabric/plugins/environments/test_slurm.py +++ b/tests/tests_fabric/plugins/environments/test_slurm.py @@ -21,6 +21,7 @@ from lightning_utilities.test.warning import no_warning_call from lightning.fabric.plugins.environments import SLURMEnvironment +from lightning.fabric.utilities.warnings import PossibleUserWarning from tests_fabric.helpers.runif import RunIf @@ -71,14 +72,14 @@ def test_attributes_from_environment_variables(caplog): assert env.node_rank() == 3 assert env.job_name() == "JOB" # setter should be no-op - with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.slurm"): + with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): env.set_global_rank(100) assert env.global_rank() == 1 assert "setting global rank is not allowed" in caplog.text caplog.clear() - with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.slurm"): + with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): env.set_world_size(100) assert env.world_size() == 20 assert "setting world size is not allowed" in caplog.text @@ -134,18 +135,18 @@ def test_detect(): def test_srun_available_and_not_used(monkeypatch): """Test that a warning is emitted if Lightning suspects the user forgot to run their script with `srun`.""" monkeypatch.setattr(sys, "argv", ["train.py", "--lr", "0.01"]) - expected = r"`srun` .* available .* but is not used. HINT: .* srun python\d* train.py --lr 0.01" + expected = "`srun` .* available .* but is not used. 
HINT: .* srun python train.py --lr 0.01" # pretend `srun` is available with mock.patch("lightning.fabric.plugins.environments.slurm.shutil.which", return_value="/usr/bin/srun"): - with pytest.warns(UserWarning, match=expected): + with pytest.warns(PossibleUserWarning, match=expected): SLURMEnvironment() - with pytest.warns(UserWarning, match=expected): + with pytest.warns(PossibleUserWarning, match=expected): SLURMEnvironment.detect() # no warning if `srun` is unavailable - with no_warning_call(UserWarning, match=expected): + with no_warning_call(PossibleUserWarning, match=expected): SLURMEnvironment() assert not SLURMEnvironment.detect() @@ -175,7 +176,7 @@ def test_validate_user_settings(): # in interactive mode, validation is skipped because processes get launched by Fabric/Trainer, not SLURM with mock.patch( - "pytorch_lightning_enterprise.plugins.environments.slurm.SLURMEnvironment.job_name", return_value="interactive" + "lightning.fabric.plugins.environments.slurm.SLURMEnvironment.job_name", return_value="interactive" ): env = SLURMEnvironment() env.validate_settings(num_devices=4, num_nodes=1) # no error diff --git a/tests/tests_fabric/plugins/environments/test_torchelastic.py b/tests/tests_fabric/plugins/environments/test_torchelastic.py index cfd196c231036..161d42894df30 100644 --- a/tests/tests_fabric/plugins/environments/test_torchelastic.py +++ b/tests/tests_fabric/plugins/environments/test_torchelastic.py @@ -58,14 +58,14 @@ def test_attributes_from_environment_variables(caplog): assert env.local_rank() == 2 assert env.node_rank() == 3 # setter should be no-op - with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.torchelastic"): + with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): env.set_global_rank(100) assert env.global_rank() == 1 assert "setting global rank is not allowed" in caplog.text caplog.clear() - with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.torchelastic"): + with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): env.set_world_size(100) assert env.world_size() == 20 assert "setting world size is not allowed" in caplog.text diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py index 0557f5620e8e1..76fd18012ee94 100644 --- a/tests/tests_fabric/plugins/environments/test_xla.py +++ b/tests/tests_fabric/plugins/environments/test_xla.py @@ -100,9 +100,8 @@ def test_detect(monkeypatch): @mock.patch.dict(os.environ, {}, clear=True) -@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True) -@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True) @mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) +@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) def test_world_size_from_xla_runtime_greater_2_1(xla_available): """Test that world_size uses torch_xla.runtime when XLA >= 2.1.""" env = XLAEnvironment() @@ -114,9 +113,8 @@ def test_world_size_from_xla_runtime_greater_2_1(xla_available): @mock.patch.dict(os.environ, {}, clear=True) -@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True) -@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True) @mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) 
+@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) def test_global_rank_from_xla_runtime_greater_2_1(xla_available): """Test that global_rank uses torch_xla.runtime when XLA >= 2.1.""" env = XLAEnvironment() @@ -128,9 +126,8 @@ def test_global_rank_from_xla_runtime_greater_2_1(xla_available): @mock.patch.dict(os.environ, {}, clear=True) -@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True) -@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True) @mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) +@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) def test_local_rank_from_xla_runtime_greater_2_1(xla_available): """Test that local_rank uses torch_xla.runtime when XLA >= 2.1.""" env = XLAEnvironment() @@ -142,9 +139,8 @@ def test_local_rank_from_xla_runtime_greater_2_1(xla_available): @mock.patch.dict(os.environ, {}, clear=True) -@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True) -@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True) @mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) +@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) def test_setters_readonly_when_xla_runtime_greater_2_1(xla_available): """Test that set_world_size and set_global_rank don't affect values when using XLA runtime >= 2.1.""" env = XLAEnvironment() diff --git a/tests/tests_fabric/plugins/precision/test_transformer_engine.py b/tests/tests_fabric/plugins/precision/test_transformer_engine.py index 22908e107131b..ed7c984b1ae64 100644 --- a/tests/tests_fabric/plugins/precision/test_transformer_engine.py +++ b/tests/tests_fabric/plugins/precision/test_transformer_engine.py @@ -15,22 +15,19 @@ from unittest.mock import Mock import pytest -import pytorch_lightning_enterprise.utils.imports import torch import torch.distributed +import lightning.fabric from lightning.fabric.connector import _Connector from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision def test_transformer_engine_plugin(monkeypatch): - module = pytorch_lightning_enterprise.utils.imports + module = lightning.fabric.plugins.precision.transformer_engine if module._TRANSFORMER_ENGINE_AVAILABLE: pytest.skip("Assumes transformer_engine is unavailable") - monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", True) - monkeypatch.setattr( - "pytorch_lightning_enterprise.plugins.precision.transformer_engine._TRANSFORMER_ENGINE_AVAILABLE", True - ) + monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True) transformer_engine_mock = Mock() monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock) monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock()) @@ -121,11 +118,8 @@ class TELayerNormMock(Mock): ... 
def test_convert_module_handles_linear_without_bias(monkeypatch): - module = pytorch_lightning_enterprise.utils.imports # Set up mock transformer_engine - monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", True) - monkeypatch.setattr( - "pytorch_lightning_enterprise.plugins.precision.transformer_engine._TRANSFORMER_ENGINE_AVAILABLE", True - ) + module = lightning.fabric.plugins.precision.transformer_engine # Set up mock transformer_engine + monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True) transformer_engine_mock = Mock() monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock) diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py index 022e1320068ff..0194c7b87820a 100644 --- a/tests/tests_fabric/strategies/test_deepspeed.py +++ b/tests/tests_fabric/strategies/test_deepspeed.py @@ -75,7 +75,7 @@ def test_deepspeed_defaults(): strategy = DeepSpeedStrategy() assert strategy.config is not None assert isinstance(strategy.config["zero_optimization"], dict) - assert strategy.deepspeed_impl._backward_sync_control is None + assert strategy._backward_sync_control is None @RunIf(deepspeed=True) @@ -230,7 +230,7 @@ def test_deepspeed_save_checkpoint_exclude_frozen_parameters(exclude_frozen_para from deepspeed import DeepSpeedEngine strategy = DeepSpeedStrategy(exclude_frozen_parameters=exclude_frozen_parameters) - assert strategy.deepspeed_impl.exclude_frozen_parameters is exclude_frozen_parameters + assert strategy.exclude_frozen_parameters is exclude_frozen_parameters model = Mock(spec=DeepSpeedEngine, optimizer=None) model.modules.return_value = [model] @@ -275,7 +275,7 @@ def test_deepspeed_load_checkpoint_no_state(tmp_path): @RunIf(deepspeed=True) -@mock.patch("pytorch_lightning_enterprise.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True) +@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True) def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(_, tmp_path): """Test that the DeepSpeed strategy can only load one DeepSpeedEngine per checkpoint.""" from deepspeed import DeepSpeedEngine @@ -317,7 +317,7 @@ def test_deepspeed_load_checkpoint_client_state_missing(tmp_path): @RunIf(deepspeed=True) -@mock.patch("pytorch_lightning_enterprise.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True) +@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True) def test_deepspeed_load_checkpoint_state_updated_with_client_state(_, tmp_path): """Test that the DeepSpeed strategy properly updates the state variables and returns additional metadata.""" from deepspeed import DeepSpeedEngine @@ -342,7 +342,7 @@ def test_deepspeed_load_checkpoint_state_updated_with_client_state(_, tmp_path): @RunIf(deepspeed=True) @pytest.mark.parametrize("optimzer_state_requested", [True, False]) -@mock.patch("pytorch_lightning_enterprise.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True) +@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True) def test_deepspeed_load_checkpoint_optimzer_state_requested(_, optimzer_state_requested, tmp_path): """Test that the DeepSpeed strategy loads the optimizer state only when requested.""" from deepspeed import DeepSpeedEngine @@ -427,7 +427,6 @@ def test_validate_parallel_devices_indices(device_indices): """ accelerator = Mock(spec=CUDAAccelerator) - accelerator.name.return_value = "cuda" strategy = 
DeepSpeedStrategy( accelerator=accelerator, parallel_devices=[torch.device("cuda", i) for i in device_indices] ) diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 8f2e0034f5dda..2a04c6b27dded 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -164,7 +164,7 @@ def test_deepspeed_custom_precision_params(): devices=1, ) fabric.launch() - assert fabric._strategy.deepspeed_impl._config_initialized + assert fabric._strategy._config_initialized assert fabric._strategy.config["fp16"]["loss_scale"] == 10 assert fabric._strategy.config["fp16"]["initial_scale_power"] == 11 assert fabric._strategy.config["fp16"]["loss_scale_window"] == 12 @@ -251,7 +251,7 @@ def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform): strategy = fabric._strategy assert isinstance(strategy, DeepSpeedStrategy) with mock.patch("platform.system", return_value=platform) as platform_mock: - strategy.deepspeed_impl._init_deepspeed_distributed() + strategy._init_deepspeed_distributed() deepspeed_dist_mock.assert_called() platform_mock.assert_called() if platform == "Windows": diff --git a/tests/tests_fabric/strategies/test_xla.py b/tests/tests_fabric/strategies/test_xla.py index 1884208dc2126..9ca21d8d8b894 100644 --- a/tests/tests_fabric/strategies/test_xla.py +++ b/tests/tests_fabric/strategies/test_xla.py @@ -194,14 +194,14 @@ def test_rank_properties_access(xla_available): strategy.cluster_environment = Mock() # we're in the main process, no processes have been launched yet - assert not strategy.xla_strategy_impl._launched + assert not strategy._launched assert strategy.global_rank == 0 assert strategy.local_rank == 0 assert strategy.node_rank == 0 assert strategy.world_size == 1 # simulate we're in a worker process - strategy.xla_strategy_impl._launched = True + strategy._launched = True assert strategy.global_rank == strategy.cluster_environment.global_rank() assert strategy.local_rank == strategy.cluster_environment.local_rank() assert strategy.node_rank == strategy.cluster_environment.node_rank() diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py index fa1021402590e..c2634283ad110 100644 --- a/tests/tests_fabric/strategies/test_xla_fsdp.py +++ b/tests/tests_fabric/strategies/test_xla_fsdp.py @@ -23,6 +23,7 @@ from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.plugins import XLAPrecision from lightning.fabric.strategies import XLAFSDPStrategy +from lightning.fabric.strategies.xla_fsdp import _activation_checkpointing_auto_wrapper, _XLAFSDPBackwardSyncControl from tests_fabric.helpers.runif import RunIf @@ -42,7 +43,6 @@ def test_xla_fsdp_setup_optimizer_validation(): def test_xla_fsdp_no_backward_sync(): """Test that the backward sync control calls `.no_sync()`, and only on a module wrapped in XlaFullyShardedDataParallel.""" - from pytorch_lightning_enterprise.strategies.xla.fsdp import _XLAFSDPBackwardSyncControl from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel strategy = XLAFSDPStrategy() @@ -75,45 +75,41 @@ def test_xla_fsdp_grad_clipping_value_error(): @mock.patch.dict(os.environ, os.environ.copy(), clear=True) -@mock.patch("pytorch_lightning_enterprise.strategies.xla.fsdp._XLA_AVAILABLE", True) def test_rank_properties_access(xla_available): """Test that the strategy returns the expected values depending on 
whether we're in the main process or not.""" strategy = XLAFSDPStrategy() strategy.cluster_environment = Mock() # we're in the main process, no processes have been launched yet - assert not strategy.xla_fsdp_impl._launched + assert not strategy._launched assert strategy.global_rank == 0 assert strategy.local_rank == 0 assert strategy.node_rank == 0 assert strategy.world_size == 1 # simulate we're in a worker process - strategy.xla_fsdp_impl._launched = True + strategy._launched = True assert strategy.global_rank == strategy.cluster_environment.global_rank() assert strategy.local_rank == strategy.cluster_environment.local_rank() assert strategy.node_rank == strategy.cluster_environment.node_rank() assert strategy.world_size == strategy.cluster_environment.world_size() -@mock.patch("pytorch_lightning_enterprise.strategies.xla.fsdp._XLA_AVAILABLE", True) def test_xla_fsdp_policy(xla_available): - from pytorch_lightning_enterprise.strategies.xla import fsdp as fsdp_module - strategy = XLAFSDPStrategy(foo=1) - assert strategy.xla_fsdp_impl._fsdp_kwargs == {"foo": 1} + assert strategy._fsdp_kwargs == {"foo": 1} strategy = XLAFSDPStrategy(auto_wrap_policy={torch.nn.Linear}) - kwargs = strategy.xla_fsdp_impl._parse_fsdp_kwargs() + kwargs = strategy._parse_fsdp_kwargs() assert set(kwargs) == {"auto_wrap_policy", "compute_dtype"} assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy" assert kwargs["compute_dtype"] is torch.float32 strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}) - _ = strategy.xla_fsdp_impl._parse_fsdp_kwargs() - kwargs = strategy.xla_fsdp_impl._parse_fsdp_kwargs() # ensure it's idempotent + _ = strategy._parse_fsdp_kwargs() + kwargs = strategy._parse_fsdp_kwargs() # ensure it's idempotent assert set(kwargs) == {"auto_wrapper_callable", "compute_dtype"} - assert kwargs["auto_wrapper_callable"].func is fsdp_module._activation_checkpointing_auto_wrapper + assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper assert kwargs["compute_dtype"] is torch.float32 strategy = XLAFSDPStrategy( @@ -122,17 +118,17 @@ def test_xla_fsdp_policy(xla_available): activation_checkpointing_policy={torch.nn.Linear}, precision=XLAPrecision("bf16-true"), ) - kwargs = strategy.xla_fsdp_impl._parse_fsdp_kwargs() + kwargs = strategy._parse_fsdp_kwargs() assert set(kwargs) == {"auto_wrap_policy", "auto_wrapper_callable", "compute_dtype"} assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy" - assert kwargs["auto_wrapper_callable"].func is fsdp_module._activation_checkpointing_auto_wrapper + assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper assert kwargs["compute_dtype"] is torch.bfloat16 strategy.teardown() strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}, auto_wrapper_callable="foo") with pytest.raises(ValueError, match="cannot set both"): - strategy.xla_fsdp_impl._parse_fsdp_kwargs() + strategy._parse_fsdp_kwargs() strategy = XLAFSDPStrategy(activation_checkpointing_policy="foo") with pytest.raises(TypeError, match="must be a set"): - strategy.xla_fsdp_impl._parse_fsdp_kwargs() + strategy._parse_fsdp_kwargs() diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index de9f7b2d16395..1074789e71055 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -242,10 +242,8 @@ class TestStrategy(DDPStrategy): ), ], ) -@mock.patch( - 
"pytorch_lightning_enterprise.plugins.environments.lsf.LSFEnvironment._read_hosts", return_value=["node0", "node1"] -) -@mock.patch("pytorch_lightning_enterprise.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0) +@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._read_hosts", return_value=["node0", "node1"]) +@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0) def test_fallback_from_ddp_spawn_to_ddp_on_cluster(_, __, env_vars, expected_environment): with mock.patch.dict(os.environ, env_vars, clear=True): connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2) @@ -343,8 +341,8 @@ def test_cuda_accelerator_can_not_run_on_system(_): @pytest.mark.skipif(XLAAccelerator.is_available(), reason="test requires missing TPU") -@mock.patch("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", True) -@mock.patch("pytorch_lightning_enterprise.accelerators.xla._using_pjrt", return_value=True) +@mock.patch("lightning.fabric.accelerators.xla._XLA_AVAILABLE", True) +@mock.patch("lightning.fabric.accelerators.xla._using_pjrt", return_value=True) def test_tpu_accelerator_can_not_run_on_system(_): with pytest.raises(RuntimeError, match="XLAAccelerator` can not run on your system"): _Connector(accelerator="tpu", devices=8) @@ -890,7 +888,7 @@ def test_precision_selection_model_parallel(_, precision, raises): def test_bitsandbytes_precision_cuda_required(monkeypatch): - monkeypatch.setattr("pytorch_lightning_enterprise.plugins.precision.bitsandbytes._BITSANDBYTES_AVAILABLE", True) + monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True) monkeypatch.setitem(sys.modules, "bitsandbytes", Mock()) with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"): _Connector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8")) diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index 785ec0a3ed2e9..a175fa97fd444 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -45,13 +45,7 @@ def test_get_available_flops(xla_available): ): assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None - # Import from the right module based on _XLA_GREATER_EQUAL_2_1 - from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1 - - if _XLA_GREATER_EQUAL_2_1: - from torch_xla._internal import tpu - else: - from torch_xla.experimental import tpu + from torch_xla.experimental import tpu assert isinstance(tpu, Mock) diff --git a/tests/tests_pytorch/accelerators/test_xla.py b/tests/tests_pytorch/accelerators/test_xla.py index 52434ee3a19d5..5e56d5c585c88 100644 --- a/tests/tests_pytorch/accelerators/test_xla.py +++ b/tests/tests_pytorch/accelerators/test_xla.py @@ -22,6 +22,7 @@ from torch import nn from torch.utils.data import DataLoader +import lightning.fabric from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CPUAccelerator, XLAAccelerator @@ -311,7 +312,7 @@ def test_warning_if_tpus_not_used(tpu_available): ) @RunIf(min_python="3.9") # mocking issue def test_trainer_config_device_ids(devices, expected_device_ids, tpu_available, monkeypatch): - monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._using_pjrt", lambda: True) + monkeypatch.setattr(lightning.fabric.accelerators.xla, "_using_pjrt", 
lambda: True) mock = DeviceMock() monkeypatch.setattr(torch, "device", mock) diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index e764765ca75f7..0bd29b998c598 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -480,9 +480,8 @@ def predict_step(self, *args, **kwargs): return super().predict_step(*args, **kwargs) -@mock.patch("builtins.print") -@mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm.write") -def test_tqdm_progress_bar_print(tqdm_write, mock_print, tmp_path): +@mock.patch("tqdm.tqdm.write") +def test_tqdm_progress_bar_print(tqdm_write, tmp_path): """Test that printing in the LightningModule redirects arguments to the progress bar.""" model = PrintModel() bar = TQDMProgressBar() @@ -507,9 +506,8 @@ def test_tqdm_progress_bar_print(tqdm_write, mock_print, tmp_path): ] -@mock.patch("builtins.print") -@mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm.write") -def test_tqdm_progress_bar_print_no_train(tqdm_write, mock_print, tmp_path): +@mock.patch("tqdm.tqdm.write") +def test_tqdm_progress_bar_print_no_train(tqdm_write, tmp_path): """Test that printing in the LightningModule redirects arguments to the progress bar without training.""" model = PrintModel() bar = TQDMProgressBar() diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index 53dd20e7e4399..878298c6bfd94 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -22,7 +22,6 @@ from unittest.mock import Mock import pytest -import pytorch_lightning_enterprise.utils.imports import torch.distributed from tqdm import TMonitor @@ -102,8 +101,6 @@ def restore_env_variables(): "TPU_ML_PLATFORM_VERSION", "LD_LIBRARY_PATH", "ENABLE_RUNTIME_UPTIME_TELEMETRY", - "TQDM_MININTERVAL", # set by our platform - "TQDM_POSITION", # set by our platform } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" @@ -223,47 +220,14 @@ def mps_count_1(monkeypatch): def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None: - # First, mock torch_xla modules in sys.modules so imports succeed - monkeypatch.setitem(sys.modules, "torch_xla", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.core", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.core.functions", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.core.xla_env_vars", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.experimental.pjrt", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.experimental.tpu", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.distributed", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.distributed.parallel_loader", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.distributed.xla_multiprocessing", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.runtime", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.utils", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.utils.utils", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla.debug", Mock()) - 
monkeypatch.setitem(sys.modules, "torch_xla.debug.profiler", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla._internal", Mock()) - monkeypatch.setitem(sys.modules, "torch_xla._internal.tpu", Mock()) - - # Then patch the _XLA_AVAILABLE flags in various modules - monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_AVAILABLE", value) - monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_1", value) - monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_5", value) + monkeypatch.setattr(lightning.pytorch.strategies.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.pytorch.strategies.single_xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.pytorch.plugins.precision.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.pytorch.strategies.launchers.xla, "_XLA_AVAILABLE", value) monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value) - monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", value) - monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_5", value) - # Patch in the modules where they're used after import - monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", value) - monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_1", value) - monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_5", value) - monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_AVAILABLE", value) - monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", value) - monkeypatch.setattr("pytorch_lightning_enterprise.plugins.precision.xla._XLA_AVAILABLE", value) - monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.single._XLA_AVAILABLE", value) - monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.ddp._XLA_AVAILABLE", value) - monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.ddp._XLA_GREATER_EQUAL_2_1", value) - monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.fsdp._XLA_AVAILABLE", value) - monkeypatch.setattr("pytorch_lightning_enterprise.strategies.xla.launcher._XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value) @pytest.fixture @@ -273,14 +237,10 @@ def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None: mock_xla_available(monkeypatch, value) + monkeypatch.setattr(lightning.pytorch.accelerators.xla.XLAAccelerator, "is_available", lambda: value) monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: value) + monkeypatch.setattr(lightning.pytorch.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8) monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8) - # Also mock the enterprise XLAAccelerator methods - import pytorch_lightning_enterprise.accelerators.xla - - monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "is_available", lambda: value) - monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 
8) - monkeypatch.setitem(sys.modules, "torch_xla", Mock()) monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock()) monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock()) diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py index 7f47ad0be54bc..0b79b638534fa 100644 --- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py +++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py @@ -2,7 +2,6 @@ from unittest.mock import Mock import pytest -import pytorch_lightning_enterprise.plugins.precision.bitsandbytes import torch.nn import lightning.fabric @@ -63,7 +62,6 @@ def test_fsdp_precision_plugin(): def test_bitsandbytes_precision_plugin(monkeypatch): monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True) - monkeypatch.setattr(pytorch_lightning_enterprise.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True) bitsandbytes_mock = Mock() monkeypatch.setitem(sys.modules, "bitsandbytes", bitsandbytes_mock) @@ -109,9 +107,7 @@ def test_precision_plugin(): def test_transformer_engine_precision_plugin(monkeypatch): - monkeypatch.setattr( - pytorch_lightning_enterprise.plugins.precision.transformer_engine, "_TRANSFORMER_ENGINE_AVAILABLE", True - ) + monkeypatch.setattr(lightning.fabric.plugins.precision.transformer_engine, "_TRANSFORMER_ENGINE_AVAILABLE", True) transformer_engine_mock = Mock() monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock) monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock()) diff --git a/tests/tests_pytorch/graveyard/test_tpu.py b/tests/tests_pytorch/graveyard/test_tpu.py index eb8a7bfce07d4..0e010ca53ddcb 100644 --- a/tests/tests_pytorch/graveyard/test_tpu.py +++ b/tests/tests_pytorch/graveyard/test_tpu.py @@ -1,25 +1,9 @@ -import os from importlib import import_module import pytest import torch -# mimics `lightning_utilites.RequirementCache` -class MockXLAAvailable: - def __init__(self, available: bool, pkg_name: str = "torch_xla"): - self.available = available - self.pkg_name = pkg_name - - def __bool__(self): - return self.available - - def __str__(self): - if self.available: - return f"Requirement '{self.pkg_name}' met" - return f"Module not found: {self.pkg_name!r}. 
HINT: Try running `pip install -U {self.pkg_name}`" - - @pytest.mark.parametrize( ("import_path", "name"), [ @@ -50,17 +34,8 @@ def test_graveyard_single_tpu(import_path, name): ("lightning.pytorch.plugins.precision.xlabf16", "XLABf16PrecisionPlugin"), ], ) -def test_graveyard_no_device(import_path, name, monkeypatch): - monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", MockXLAAvailable(False)) - monkeypatch.setattr("pytorch_lightning_enterprise.plugins.precision.xla._XLA_AVAILABLE", MockXLAAvailable(False)) - +def test_graveyard_no_device(import_path, name): module = import_module(import_path) cls = getattr(module, name) with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"): cls() - - # teardown - # ideally, we should call the plugin's teardown method, but since the class - # instantiation itself fails, we directly manipulate the env vars here - os.environ.pop("XLA_USE_BF16", None) - os.environ.pop("XLA_USE_F16", None) diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index 79878b980900b..98819eb080eb8 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -38,10 +38,8 @@ def mlflow_mock(monkeypatch): mlflow.tracking = mlflow_tracking mlflow.entities = mlflow_entities - monkeypatch.setattr("pytorch_lightning_enterprise.loggers.mlflow._MLFLOW_AVAILABLE", True) - monkeypatch.setattr("pytorch_lightning_enterprise.loggers.mlflow._MLFLOW_SYNCHRONOUS_AVAILABLE", True) - monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._MLFLOW_AVAILABLE", True) - monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._MLFLOW_SYNCHRONOUS_AVAILABLE", True) + monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", True) + monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_SYNCHRONOUS_AVAILABLE", True) return mlflow @@ -88,8 +86,7 @@ class RunType: # to make isinstance checks pass wandb.sdk.lib = wandb_sdk_lib wandb.wandb_run = wandb_wandb_run - monkeypatch.setattr("pytorch_lightning_enterprise.loggers.wandb._WANDB_AVAILABLE", True) - monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._WANDB_AVAILABLE", True) + monkeypatch.setattr("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True) return wandb @@ -112,8 +109,7 @@ def comet_mock(monkeypatch): comet.start = Mock(name="comet_ml.start", return_value=comet.Experiment()) comet.config = Mock() - monkeypatch.setattr("pytorch_lightning_enterprise.loggers.comet._COMET_AVAILABLE", True) - monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._COMET_AVAILABLE", True) + monkeypatch.setattr("lightning.pytorch.loggers.comet._COMET_AVAILABLE", True) return comet @@ -129,9 +125,6 @@ def __getitem__(self, item): def __setitem__(self, key, value): pass - def wait(self): - pass - run_mock = MagicMock(spec=RunType, exists=Mock(return_value=False), wait=Mock(), get_structure=MagicMock()) run_mock.get_root_object.return_value = run_mock @@ -160,6 +153,5 @@ def wait(self): neptune.types = neptune_types neptune.utils = neptune_utils - monkeypatch.setattr("pytorch_lightning_enterprise.loggers.neptune._NEPTUNE_AVAILABLE", True) - monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._NEPTUNE_AVAILABLE", True) + monkeypatch.setattr("lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE", True) return neptune diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 
416676175cd63..a9cf9af79185b 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -70,7 +70,7 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs): @mock.patch.dict(os.environ, {}) -@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock()) +@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) @pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES) def test_loggers_fit_test_all(logger_class, mlflow_mock, wandb_mock, comet_mock, neptune_mock, tmp_path, monkeypatch): """Verify the basic functionality of all loggers.""" @@ -107,8 +107,8 @@ def log_metrics(self, metrics, step): if logger_class == CometLogger: logger.experiment.id = "foo" - logger.logger_impl._comet_config.offline_directory = None - logger.logger_impl._project_name = "bar" + logger._comet_config.offline_directory = None + logger._project_name = "bar" logger.experiment.get_key.return_value = "SOME_KEY" if logger_class == NeptuneLogger: @@ -287,7 +287,7 @@ def _test_logger_initialization(tmp_path, logger_class): @mock.patch.dict(os.environ, {}) -@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock()) +@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, neptune_mock, monkeypatch, tmp_path): """Test that the prefix is added at the beginning of the metric keys.""" prefix = "tmp" @@ -335,7 +335,7 @@ def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, neptune_moc logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0}) -@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock()) +@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) def test_logger_default_name(mlflow_mock, monkeypatch, tmp_path): """Test that the default logger name is lightning_logs.""" # CSV @@ -358,6 +358,6 @@ def test_logger_default_name(mlflow_mock, monkeypatch, tmp_path): logger = _instantiate_logger(MLFlowLogger, save_dir=tmp_path) _ = logger.experiment - logger.logger_impl._mlflow_client.create_experiment.assert_called_with(name="lightning_logs", artifact_location=ANY) + logger._mlflow_client.create_experiment.assert_called_with(name="lightning_logs", artifact_location=ANY) # on MLFlowLogger, `name` refers to the experiment id # assert logger.experiment.get_experiment(logger.name).name == "lightning_logs" diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index 4108d3c3b641c..dae8f617b873e 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -89,10 +89,10 @@ def test_comet_experiment_is_still_alive_after_training_complete(comet_mock): logger.finalize("ended") # Assert that data was saved to comet.com - logger.logger_impl._experiment.flush.assert_called_once() + logger._experiment.flush.assert_called_once() # Assert that it was not ended - logger.logger_impl._experiment.end.assert_not_called() + logger._experiment.end.assert_not_called() @mock.patch.dict(os.environ, {}) @@ -115,8 +115,8 @@ def test_comet_logger_experiment_name(comet_mock): experiment_config=comet_mock.ExperimentConfig(), ) # check that we saved "experiment name" in kwargs as the new "name" arg - assert logger.logger_impl._kwargs["name"] == experiment_name - assert "experiment_name" not in logger.logger_impl._kwargs + assert logger._kwargs["name"] == experiment_name + assert 
"experiment_name" not in logger._kwargs # check that "experiment name" was passed to experiment config correctly assert call(experiment_name=experiment_name) not in comet_mock.ExperimentConfig.call_args_list @@ -130,10 +130,10 @@ def test_comet_version(comet_mock): experiment_name = "My Name" logger = CometLogger(api_key=api_key, name=experiment_name) - assert logger.logger_impl._experiment is not None + assert logger._experiment is not None _ = logger.version - logger.logger_impl._experiment.get_key.assert_called() + logger._experiment.get_key.assert_called() @mock.patch.dict(os.environ, {}) @@ -146,7 +146,7 @@ def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch): {"test": 1}, epoch=1, step=123, - prefix=logger.logger_impl._prefix, + prefix=logger._prefix, framework="pytorch-lightning", ) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index b82800fe22f27..c7f9dbe1fe2c6 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -16,13 +16,13 @@ from unittest.mock import MagicMock, Mock import pytest -import pytorch_lightning_enterprise.loggers # noqa: F401 -from pytorch_lightning_enterprise.utils.imports import _MLFLOW_AVAILABLE from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers.mlflow import ( + _MLFLOW_AVAILABLE, MLFlowLogger, + _get_resolve_tags, ) @@ -30,13 +30,13 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r """Helper function to simulate mlflow client creating a new (or existing) experiment.""" run = MagicMock() run.info.run_id = run_id - logger.logger_impl._mlflow_client.get_experiment_by_name = MagicMock(return_value=experiment_name) - logger.logger_impl._mlflow_client.create_experiment = MagicMock(return_value=experiment_id) - logger.logger_impl._mlflow_client.create_run = MagicMock(return_value=run) + logger._mlflow_client.get_experiment_by_name = MagicMock(return_value=experiment_name) + logger._mlflow_client.create_experiment = MagicMock(return_value=experiment_id) + logger._mlflow_client.create_run = MagicMock(return_value=run) return logger -@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock()) +@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) def test_mlflow_logger_exists(mlflow_mock, tmp_path): """Test launching three independent loggers with either same or different experiment name.""" client = mlflow_mock.tracking.MlflowClient @@ -57,8 +57,8 @@ def test_mlflow_logger_exists(mlflow_mock, tmp_path): client.return_value.create_run = MagicMock(return_value=run1) logger = MLFlowLogger("test", save_dir=str(tmp_path)) - assert logger.logger_impl._experiment_id is None - assert logger.logger_impl._run_id is None + assert logger._experiment_id is None + assert logger._run_id is None _ = logger.experiment assert logger.experiment_id == "exp-id-1" assert logger.run_id == "run-id-1" @@ -96,14 +96,13 @@ def test_mlflow_run_name_setting(tmp_path): pytest.skip("test for explicit file creation requires mlflow dependency to be installed.") from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME - from pytorch_lightning_enterprise.loggers.mlflow import _get_resolve_tags resolve_tags = _get_resolve_tags() tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"}) # run_name is appended to tags logger = MLFlowLogger("test", run_name="run-name-1", save_dir=str(tmp_path)) - logger.logger_impl._mlflow_client = client = 
diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py
index e50bcf967ef9a..572b85715da86 100644
--- a/tests/tests_pytorch/loggers/test_neptune.py
+++ b/tests/tests_pytorch/loggers/test_neptune.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import pickle
+from collections import namedtuple
 from unittest import mock
 from unittest.mock import MagicMock, call
@@ -44,10 +45,10 @@ def _fit_and_test(logger, model, tmp_path):
 def _get_logger_with_mocks(**kwargs):
     logger = NeptuneLogger(**kwargs)
     run_instance_mock = MagicMock()
-    logger.logger_impl._run_instance = run_instance_mock
-    logger.logger_impl._run_instance.__getitem__.return_value.fetch.return_value = "exp-name"
+    logger._run_instance = run_instance_mock
+    logger._run_instance.__getitem__.return_value.fetch.return_value = "exp-name"
     run_attr_mock = MagicMock()
-    logger.logger_impl._run_instance.__getitem__.return_value = run_attr_mock
+    logger._run_instance.__getitem__.return_value = run_attr_mock
     return logger, run_instance_mock, run_attr_mock
@@ -59,7 +60,7 @@ def test_neptune_online(neptune_mock):
     logger = NeptuneLogger(api_key="test", project="project")
     created_run_mock = logger.run

-    assert logger.logger_impl._run_instance == created_run_mock
+    assert logger._run_instance == created_run_mock
     created_run_mock.exists.assert_called_once_with("sys/id")
     assert logger.name == "Run test name"
     assert logger.version == "TEST-1"
@@ -78,8 +79,8 @@ def test_neptune_offline(neptune_mock):
     logger.experiment["foo"] = "bar"

     created_run_mock.exists.assert_called_once_with("sys/id")
-    assert logger.logger_impl._run_short_id == "OFFLINE"
-    assert logger.logger_impl._run_name == "offline-name"
+    assert logger._run_short_id == "OFFLINE"
+    assert logger._run_name == "offline-name"


 def test_online_with_custom_run(neptune_mock):
@@ -90,8 +91,8 @@ def test_online_with_custom_run(neptune_mock):
     neptune_mock.init_run.reset_mock()

     logger = NeptuneLogger(run=created_run)
-    assert logger.logger_impl._run_instance == created_run
-    assert logger.logger_impl._run_instance == created_run
+    assert logger._run_instance == created_run
+    assert logger._run_instance == created_run
     assert logger.version == "TEST-1"
     assert neptune_mock.init_run.call_count == 0
@@ -274,6 +275,38 @@ def test_save_dir(neptune_mock):
     assert logger.save_dir == os.path.join(os.getcwd(), ".neptune")


+def test_get_full_model_name():
+    SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
+    test_input_data = [
+        ("key", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
+        (
+            "key/in/parts",
+            os.path.join("foo", "bar", "key/in/parts.ext"),
+            SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
+        ),
+        ("key", os.path.join("../foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("../foo", "bar"))),
+        ("key", os.path.join("foo", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("./foo", "bar/../"))),
+    ]
+
+    for expected_model_name, model_path, checkpoint in test_input_data:
+        assert NeptuneLogger._get_full_model_name(model_path, checkpoint) == expected_model_name
+
+
+def test_get_full_model_names_from_exp_structure():
+    input_dict = {
+        "foo": {
+            "bar": {
+                "lvl1_1": {"lvl2": {"lvl3_1": "some non important value", "lvl3_2": "some non important value"}},
+                "lvl1_2": "some non important value",
+            },
+            "other_non_important": {"val100": 100},
+        },
+        "other_non_important": {"val42": 42},
+    }
+    expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"}
+    assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys
+
+
 def test_inactive_run(neptune_mock, tmp_path, monkeypatch):
     from neptune.exceptions import InactiveRunException
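The two new Neptune tests pin down the contract of a pair of `NeptuneLogger` helpers. The actual implementations live in the logger; the following is only an illustrative re-implementation that satisfies the expectations encoded in the tests above:

import os

def get_full_model_name(model_path, checkpoint_callback):
    # The model name is the checkpoint path taken relative to the callback's
    # dirpath, minus the file extension. os.path.relpath normalizes both
    # arguments, so a dirpath like "./foo/bar/../" still resolves correctly.
    relative = os.path.relpath(model_path, checkpoint_callback.dirpath)
    return os.path.splitext(relative)[0]

def get_full_model_names_from_exp_structure(exp_structure, namespace):
    # Descend to the namespace (e.g. "foo/bar"), then collect the slash-joined
    # key paths of every leaf value underneath it.
    node = exp_structure
    for part in namespace.split("/"):
        node = node[part]
    names, stack = set(), [((), node)]
    while stack:
        path, current = stack.pop()
        for key, value in current.items():
            if isinstance(value, dict):
                stack.append((path + (key,), value))
            else:
                names.add("/".join(path + (key,)))
    return names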
diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py
index 0409cbd7b8f2c..e9b9e9a8090b0 100644
--- a/tests/tests_pytorch/loggers/test_wandb.py
+++ b/tests/tests_pytorch/loggers/test_wandb.py
@@ -25,6 +25,7 @@
 from lightning.pytorch.cli import LightningCLI
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+from lightning.pytorch.utilities.exceptions import MisconfigurationException


 def test_wandb_project_name(wandb_mock):
@@ -100,7 +101,7 @@ def test_wandb_logger_init(wandb_mock):
     _ = logger.experiment

     # verify default resume value
-    assert logger.logger_impl._wandb_init["resume"] == "allow"
+    assert logger._wandb_init["resume"] == "allow"

     logger.log_metrics({"acc": 1.0}, step=3)
     wandb_mock.init.assert_called_once()
@@ -144,9 +145,9 @@ def test_wandb_logger_sync_tensorboard_log_metrics(wandb_mock):
 def test_wandb_logger_init_before_spawn(wandb_mock):
     logger = WandbLogger()
-    assert logger.logger_impl._experiment is None
-    logger.logger_impl.__getstate__()
-    assert logger.logger_impl._experiment is not None
+    assert logger._experiment is None
+    logger.__getstate__()
+    assert logger._experiment is not None


 def test_wandb_logger_experiment_called_first(wandb_mock, tmp_path):
@@ -625,7 +626,7 @@ def test_wandb_log_media(wandb_mock, tmp_path):
 def test_wandb_logger_offline_log_model(wandb_mock, tmp_path):
     """Test that log_model=True raises an error in offline mode."""
-    with pytest.raises(ValueError, match="checkpoints cannot be uploaded in offline mode"):
+    with pytest.raises(MisconfigurationException, match="checkpoints cannot be uploaded in offline mode"):
         _ = WandbLogger(save_dir=tmp_path, offline=True, log_model=True)
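`test_wandb_logger_offline_log_model` now expects `MisconfigurationException` rather than `ValueError`. The guard it exercises is roughly of this shape (a hypothetical sketch, not the logger's actual code; only the message fragment is taken from the test's `match=` pattern):

from lightning.pytorch.utilities.exceptions import MisconfigurationException

def _validate_log_model(offline: bool, log_model) -> None:
    # Uploading checkpoints requires an online run, so the combination of
    # offline=True and a truthy log_model is rejected up front.
    if offline and log_model:
        raise MisconfigurationException(
            "Providing log_model=True and offline=True is an invalid configuration "
            "since model checkpoints cannot be uploaded in offline mode."
        )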
["all", True, False]) -@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock()) +@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) def test_mlflow_log_model(mlflow_mock, log_model, tmp_path): """Test that the logger creates the folders and files in the right place.""" client = mlflow_mock.tracking.MlflowClient @@ -421,7 +420,7 @@ def test_mlflow_log_model(mlflow_mock, log_model, tmp_path): assert not client.return_value.log_artifacts.called -@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock()) +@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) def test_set_tracking_uri(mlflow_mock): """Test that the tracking uri is set for logging artifacts to MLFlow server.""" logger = MLFlowLogger(tracking_uri="the_tracking_uri") diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py index e50bcf967ef9a..572b85715da86 100644 --- a/tests/tests_pytorch/loggers/test_neptune.py +++ b/tests/tests_pytorch/loggers/test_neptune.py @@ -13,6 +13,7 @@ # limitations under the License. import os import pickle +from collections import namedtuple from unittest import mock from unittest.mock import MagicMock, call @@ -44,10 +45,10 @@ def _fit_and_test(logger, model, tmp_path): def _get_logger_with_mocks(**kwargs): logger = NeptuneLogger(**kwargs) run_instance_mock = MagicMock() - logger.logger_impl._run_instance = run_instance_mock - logger.logger_impl._run_instance.__getitem__.return_value.fetch.return_value = "exp-name" + logger._run_instance = run_instance_mock + logger._run_instance.__getitem__.return_value.fetch.return_value = "exp-name" run_attr_mock = MagicMock() - logger.logger_impl._run_instance.__getitem__.return_value = run_attr_mock + logger._run_instance.__getitem__.return_value = run_attr_mock return logger, run_instance_mock, run_attr_mock @@ -59,7 +60,7 @@ def test_neptune_online(neptune_mock): logger = NeptuneLogger(api_key="test", project="project") created_run_mock = logger.run - assert logger.logger_impl._run_instance == created_run_mock + assert logger._run_instance == created_run_mock created_run_mock.exists.assert_called_once_with("sys/id") assert logger.name == "Run test name" assert logger.version == "TEST-1" @@ -78,8 +79,8 @@ def test_neptune_offline(neptune_mock): logger.experiment["foo"] = "bar" created_run_mock.exists.assert_called_once_with("sys/id") - assert logger.logger_impl._run_short_id == "OFFLINE" - assert logger.logger_impl._run_name == "offline-name" + assert logger._run_short_id == "OFFLINE" + assert logger._run_name == "offline-name" def test_online_with_custom_run(neptune_mock): @@ -90,8 +91,8 @@ def test_online_with_custom_run(neptune_mock): neptune_mock.init_run.reset_mock() logger = NeptuneLogger(run=created_run) - assert logger.logger_impl._run_instance == created_run - assert logger.logger_impl._run_instance == created_run + assert logger._run_instance == created_run + assert logger._run_instance == created_run assert logger.version == "TEST-1" assert neptune_mock.init_run.call_count == 0 @@ -274,6 +275,38 @@ def test_save_dir(neptune_mock): assert logger.save_dir == os.path.join(os.getcwd(), ".neptune") +def test_get_full_model_name(): + SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"]) + test_input_data = [ + ("key", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))), + ( + "key/in/parts", + os.path.join("foo", "bar", "key/in/parts.ext"), + 
diff --git a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
index 6055f9ecd6e60..a9967280e3f23 100644
--- a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
+++ b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
@@ -16,16 +16,16 @@
 from unittest.mock import ANY, Mock

 import pytest
-import pytorch_lightning_enterprise.utils.imports
 import torch

+import lightning.fabric
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.plugins import TransformerEnginePrecision
 from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector


 def test_transformer_engine_precision_plugin(monkeypatch):
-    module = pytorch_lightning_enterprise.plugins.precision.transformer_engine
+    module = lightning.fabric.plugins.precision.transformer_engine
     if module._TRANSFORMER_ENGINE_AVAILABLE:
         pytest.skip("Assumes transformer_engine is unavailable")
     monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
@@ -44,7 +44,7 @@ def test_transformer_engine_precision_plugin(monkeypatch):


 def test_configure_model(monkeypatch):
-    module = pytorch_lightning_enterprise.plugins.precision.transformer_engine
+    module = lightning.fabric.plugins.precision.transformer_engine
     if module._TRANSFORMER_ENGINE_AVAILABLE:
         pytest.skip("Assumes transformer_engine is unavailable")
     monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
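The `monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)` lines work because these availability flags are `RequirementCache` objects that guarded code only evaluates for truthiness, so any truthy stand-in (even a bare lambda) takes the "available" branch. A minimal sketch of the guard shape, with an invented flag object:

class FakeRequirementCache:
    # Stand-in for lightning_utilities' RequirementCache: guarded code relies
    # only on truthiness, which is why `lambda: True` works as a replacement.
    def __bool__(self):
        return False  # pretend the dependency is missing

_TRANSFORMER_ENGINE_AVAILABLE = FakeRequirementCache()

def build_plugin():
    if not _TRANSFORMER_ENGINE_AVAILABLE:
        raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
    return "plugin ready"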
diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py
index c4a1800743dd0..047e2d24f77a3 100644
--- a/tests/tests_pytorch/strategies/test_deepspeed.py
+++ b/tests/tests_pytorch/strategies/test_deepspeed.py
@@ -33,6 +33,7 @@
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.plugins import DeepSpeedPrecision
 from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
@@ -152,7 +153,7 @@ def test_deepspeed_precision_choice(cuda_count_1, tmp_path):
 def test_deepspeed_with_invalid_config_path():
     """Test to ensure if we pass an invalid config path we throw an exception."""
     with pytest.raises(
-        FileNotFoundError, match="You passed in a path to a DeepSpeed config but the path does not exist"
+        MisconfigurationException, match="You passed in a path to a DeepSpeed config but the path does not exist"
     ):
         DeepSpeedStrategy(config="invalid_path.json")
@@ -1003,7 +1004,7 @@ def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmp_path,
     strategy = trainer.strategy
     assert isinstance(strategy, DeepSpeedStrategy)
     with mock.patch("platform.system", return_value=platform) as mock_platform:
-        strategy.deepspeed_strategy_impl._init_deepspeed_distributed()
+        strategy._init_deepspeed_distributed()
     mock_deepspeed_distributed.assert_called()
     mock_platform.assert_called()
     if platform == "Windows":
@@ -1078,7 +1079,7 @@ def training_step(self, batch, batch_idx):
         enable_progress_bar=False,
         enable_model_summary=False,
     )
-    with pytest.raises(ValueError, match="returning `None` .* is not supported"):
+    with pytest.raises(MisconfigurationException, match="returning `None` .* is not supported"):
         trainer.fit(model)
@@ -1157,7 +1158,7 @@ def test_deepspeed_gradient_clip_by_value(tmp_path):
         enable_progress_bar=False,
         enable_model_summary=False,
     )
-    with pytest.raises(ValueError, match="does not support clipping gradients by value"):
+    with pytest.raises(MisconfigurationException, match="does not support clipping gradients by value"):
         trainer.fit(model)
@@ -1266,7 +1267,6 @@ def test_validate_parallel_devices_indices(device_indices):
     """
     accelerator = Mock(spec=CUDAAccelerator)
-    accelerator.name.return_value = "cuda"
     strategy = DeepSpeedStrategy(
         accelerator=accelerator, parallel_devices=[torch.device("cuda", i) for i in device_indices]
     )
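Several DeepSpeed expectations above move from builtin exceptions to `MisconfigurationException`. For the config-path case, the guard the test implies looks roughly like this (a hypothetical sketch, not the strategy's actual code; only the message fragment comes from the test's `match=` pattern):

import os

from lightning.pytorch.utilities.exceptions import MisconfigurationException

def _load_config(config):
    # A string config is treated as a path to a DeepSpeed JSON config and
    # must point at an existing file.
    if isinstance(config, str) and not os.path.isfile(config):
        raise MisconfigurationException(
            f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
        )
    return config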
diff --git a/tests/tests_pytorch/strategies/test_xla.py b/tests/tests_pytorch/strategies/test_xla.py
index 0fd2afa98690b..3fde2600c9483 100644
--- a/tests/tests_pytorch/strategies/test_xla.py
+++ b/tests/tests_pytorch/strategies/test_xla.py
@@ -50,14 +50,14 @@ def test_rank_properties_access(xla_available):
     strategy.cluster_environment = Mock()

     # we're in the main process, no processes have been launched yet
-    assert not strategy.xla_strategy_impl._launched
+    assert not strategy._launched
     assert strategy.global_rank == 0
     assert strategy.local_rank == 0
     assert strategy.node_rank == 0
     assert strategy.world_size == 1

     # simulate we're in a worker process
-    strategy.xla_strategy_impl._launched = True
+    strategy._launched = True
     assert strategy.global_rank == strategy.cluster_environment.global_rank()
     assert strategy.local_rank == strategy.cluster_environment.local_rank()
     assert strategy.node_rank == strategy.cluster_environment.node_rank()
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
index 114fb02364a09..99ce4f5f975c8 100644
--- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -968,7 +968,6 @@ def test_precision_selection(precision_str, strategy_str, expected_precision_cls
 def test_bitsandbytes_precision_cuda_required(monkeypatch):
     monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
-    monkeypatch.setattr("pytorch_lightning_enterprise.plugins.precision.bitsandbytes._BITSANDBYTES_AVAILABLE", True)
     monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
     with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"):
         _AcceleratorConnector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))
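The removal of `accelerator.name.return_value = "cuda"` in the DeepSpeed hunk above also illustrates how `Mock(spec=CUDAAccelerator)` behaves: a spec'd mock only permits attributes that exist on the spec class, so stubbing an attribute the class does not expose would fail. A standalone illustration with a toy class:

from unittest.mock import Mock

class ToyAccelerator:
    def setup_device(self, device):
        ...

accel = Mock(spec=ToyAccelerator)
accel.setup_device("cuda:0")  # allowed: defined on the spec class

try:
    accel.name  # not on the spec, so the mock rejects the attribute
except AttributeError as err:
    print(f"rejected: {err}")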
diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py
index b8b13af7b4710..41ca1a779f8a5 100644
--- a/tests/tests_pytorch/utilities/migration/test_utils.py
+++ b/tests/tests_pytorch/utilities/migration/test_utils.py
@@ -101,10 +101,7 @@ def test_patch_legacy_imports_unified(pl_version):
     assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), (
         f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}"
     )
-    assert not any(
-        key.startswith("pytorch_lightning") and not key.startswith("pytorch_lightning_enterprise")
-        for key in sys.modules
-    ), (
+    assert not any(key.startswith("pytorch_lightning") for key in sys.modules), (
         "Should not import standalone package, all imports should be redirected to the unified package;\n"
         f" environment: {_list_sys_modules('pytorch_lightning')}"
     )
@@ -125,10 +122,7 @@ def test_patch_legacy_imports_unified(pl_version):
     assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), (
         f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}"
     )
-    assert not any(
-        key.startswith("pytorch_lightning") and not key.startswith("pytorch_lightning_enterprise")
-        for key in sys.modules
-    ), (
+    assert not any(key.startswith("pytorch_lightning") for key in sys.modules), (
        "Should not import standalone package, all imports should be redirected to the unified package;\n"
        f" environment: {_list_sys_modules('pytorch_lightning')}"
     )
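With the enterprise package gone, the assertion simplifies to a plain prefix check over `sys.modules`. The `_list_sys_modules` helper is defined elsewhere in the test module; a plausible sketch consistent with how it is used in the assertion messages above:

import sys

def _list_sys_modules(prefix: str) -> str:
    # Render the currently imported module names under a prefix, for use in
    # assertion failure messages.
    return ", ".join(sorted(name for name in sys.modules if name.startswith(prefix)))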