diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py index db3f5334ed162..67b1c8bf7bfd6 100644 --- a/docs/source-pytorch/conf.py +++ b/docs/source-pytorch/conf.py @@ -604,10 +604,7 @@ def package_list_from_file(file): from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE -from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE -from lightning.pytorch.loggers.comet import _COMET_AVAILABLE -from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE -from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE +from lightning.fabric.utilities.imports import _COMET_AVAILABLE, _MLFLOW_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE """ coverage_skip_undoc_in_source = True diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index fdc814c7f7fa1..0c784b94864ab 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -6,3 +6,4 @@ fsspec[http] >=2022.5.0, <2025.11.0 packaging >=20.0, <=25.0 typing-extensions >4.5.0, <4.16.0 lightning-utilities >=0.10.0, <0.16.0 +pytorch-lightning-enterprise >=0.0.1dev4 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index 014b223b1f012..96bb0ace3e115 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -9,3 +9,4 @@ torchmetrics >0.7.0, <1.9.0 packaging >=20.0, <=25.0 typing-extensions >4.5.0, <4.16.0 lightning-utilities >=0.10.0, <0.16.0 +pytorch-lightning-enterprise >=0.0.1dev4 diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py index db2cf2586e1ba..0cf42821424ee 100644 --- a/src/lightning/fabric/accelerators/xla.py +++ b/src/lightning/fabric/accelerators/xla.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import functools +import warnings from typing import Any, Union import torch @@ -20,7 +21,11 @@ from lightning.fabric.accelerators.accelerator import Accelerator from lightning.fabric.accelerators.registry import _AcceleratorRegistry -from lightning.fabric.utilities.device_parser import _check_data_type +from lightning.fabric.utilities.imports import _raise_enterprise_not_available + +_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla") +_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1") +_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5") class XLAAccelerator(Accelerator): @@ -31,38 +36,38 @@ class XLAAccelerator(Accelerator): """ def __init__(self, *args: Any, **kwargs: Any) -> None: - if not _XLA_AVAILABLE: - raise ModuleNotFoundError(str(_XLA_AVAILABLE)) - if not _using_pjrt(): - raise RuntimeError("The XLA XRT runtime is not supported anymore.") + _raise_enterprise_not_available() super().__init__(*args, **kwargs) + from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator + + self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs) + @override def setup_device(self, device: torch.device) -> None: - pass + return self.accelerator_impl.setup_device(device) @override def teardown(self) -> None: - pass + return self.accelerator_impl.teardown() @staticmethod @override def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]: """Accelerator device parsing logic.""" - return _parse_tpu_devices(devices) + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator + + return EnterpriseXLAAccelerator.parse_devices(devices) @staticmethod @override def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" - devices = _parse_tpu_devices(devices) - if isinstance(devices, int): - return [torch.device("xla", i) for i in range(devices)] - # list of devices is not supported, just a specific index, fine to access [0] - return [torch.device("xla", devices[0])] - # we cannot create `xla_device` here because processes have not been spawned yet (this is called in the - # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`. - # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator + + return EnterpriseXLAAccelerator.get_parallel_devices(devices) @staticmethod @override @@ -71,16 +76,10 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]: @functools.lru_cache(maxsize=1) def auto_device_count() -> int: """Get the devices when set to auto.""" - if not _XLA_AVAILABLE: - return 0 - if _XLA_GREATER_EQUAL_2_1: - from torch_xla._internal import tpu - - return tpu.num_available_devices() - from torch_xla.experimental import tpu + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator - device_count_on_version = {2: 8, 3: 8, 4: 4} - return device_count_on_version.get(tpu.version(), 8) + return EnterpriseXLAAccelerator.auto_device_count() @staticmethod @override @@ -92,6 +91,9 @@ def is_available() -> bool: # XLA may raise these exceptions if it's not properly configured. 
This needs to be avoided for the cases # when `torch_xla` is imported but not used return False + except ModuleNotFoundError as e: + warnings.warn(str(e)) + return False @staticmethod @override @@ -106,74 +108,3 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No cls, description=cls.__name__, ) - - -# PJRT support requires this minimum version -_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla") -_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1") -_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5") - - -def _using_pjrt() -> bool: - # `using_pjrt` is removed in torch_xla 2.5 - if _XLA_GREATER_EQUAL_2_5: - from torch_xla import runtime as xr - - return xr.device_type() is not None - # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped. - if _XLA_GREATER_EQUAL_2_1: - from torch_xla import runtime as xr - - return xr.using_pjrt() - - from torch_xla.experimental import pjrt - - return pjrt.using_pjrt() - - -def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]: - """Parses the TPU devices given in the format as accepted by the - :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`. - - Args: - devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used - An int 8 or string '8' indicates that all 8 cores with multi-processing should be used - A single element list of int or string can be used to indicate the specific TPU core to use. - - Returns: - A list of tpu cores to be used. - - """ - _check_data_type(devices) - if isinstance(devices, str): - devices = _parse_tpu_devices_str(devices) - _check_tpu_devices_valid(devices) - return devices - - -def _check_tpu_devices_valid(devices: object) -> None: - device_count = XLAAccelerator.auto_device_count() - if ( - # support number of devices - isinstance(devices, int) - and devices in {1, device_count} - # support picking a specific device - or isinstance(devices, (list, tuple)) - and len(devices) == 1 - and 0 <= devices[0] <= device_count - 1 - ): - return - raise ValueError( - f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}" - ) - - -def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]: - devices = devices.strip() - try: - return int(devices) - except ValueError: - try: - return [int(x.strip()) for x in devices.split(",") if len(x) > 0] - except ValueError: - raise ValueError(f"Could not parse the selected TPU devices: {devices!r}") diff --git a/src/lightning/fabric/plugins/environments/kubeflow.py b/src/lightning/fabric/plugins/environments/kubeflow.py index 23a1c0d1753af..8013ec2a65720 100644 --- a/src/lightning/fabric/plugins/environments/kubeflow.py +++ b/src/lightning/fabric/plugins/environments/kubeflow.py @@ -13,11 +13,11 @@ # limitations under the License. 
import logging -import os from typing_extensions import override from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment +from lightning.fabric.utilities.imports import _raise_enterprise_not_available log = logging.getLogger(__name__) @@ -33,20 +33,28 @@ class KubeflowEnvironment(ClusterEnvironment): """ + def __init__(self) -> None: + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.environments.kubeflow import ( + KubeflowEnvironment as EnterpriseKubeflowEnvironment, + ) + + self.kubeflow_impl = EnterpriseKubeflowEnvironment() + @property @override def creates_processes_externally(self) -> bool: - return True + return self.kubeflow_impl.creates_processes_externally @property @override def main_address(self) -> str: - return os.environ["MASTER_ADDR"] + return self.kubeflow_impl.main_address @property @override def main_port(self) -> int: - return int(os.environ["MASTER_PORT"]) + return self.kubeflow_impl.main_port @staticmethod @override @@ -55,24 +63,24 @@ def detect() -> bool: @override def world_size(self) -> int: - return int(os.environ["WORLD_SIZE"]) + return self.kubeflow_impl.world_size() @override def set_world_size(self, size: int) -> None: - log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + return self.kubeflow_impl.set_world_size(size) @override def global_rank(self) -> int: - return int(os.environ["RANK"]) + return self.kubeflow_impl.global_rank() @override def set_global_rank(self, rank: int) -> None: - log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.") + return self.kubeflow_impl.set_global_rank(rank) @override def local_rank(self) -> int: - return 0 + return self.kubeflow_impl.local_rank() @override def node_rank(self) -> int: - return self.global_rank() + return self.kubeflow_impl.node_rank() diff --git a/src/lightning/fabric/plugins/environments/lsf.py b/src/lightning/fabric/plugins/environments/lsf.py index f0a07d61d9f03..0b62bf502303a 100644 --- a/src/lightning/fabric/plugins/environments/lsf.py +++ b/src/lightning/fabric/plugins/environments/lsf.py @@ -13,12 +13,11 @@ # limitations under the License. 
import logging import os -import socket from typing_extensions import override from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment -from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.fabric.utilities.imports import _raise_enterprise_not_available log = logging.getLogger(__name__) @@ -50,36 +49,32 @@ class LSFEnvironment(ClusterEnvironment): def __init__(self) -> None: super().__init__() - self._main_address = self._get_main_address() - self._main_port = self._get_main_port() - self._node_rank = self._get_node_rank() - self._set_init_progress_group_env_vars() - - def _set_init_progress_group_env_vars(self) -> None: - # set environment variables needed for initializing torch distributed process group - os.environ["MASTER_ADDR"] = str(self._main_address) - log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - os.environ["MASTER_PORT"] = str(self._main_port) - log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.environments.lsf import ( + LSFEnvironment as EnterpriseLSFEnvironment, + ) + + self.lsf_impl = EnterpriseLSFEnvironment() @property @override def creates_processes_externally(self) -> bool: """LSF creates subprocesses, i.e., PyTorch Lightning does not need to spawn them.""" - return True + return self.lsf_impl.creates_processes_externally @property @override def main_address(self) -> str: """The main address is read from an OpenMPI host rank file in the environment variable ``LSB_DJOB_RANKFILE``.""" - return self._main_address + return self.lsf_impl.main_address @property @override def main_port(self) -> int: """The main port is calculated from the LSF job ID.""" - return self._main_port + return self.lsf_impl.main_port @staticmethod @override @@ -91,110 +86,28 @@ def detect() -> bool: @override def world_size(self) -> int: """The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``.""" - world_size = os.environ.get("JSM_NAMESPACE_SIZE") - if world_size is None: - raise ValueError( - "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found." - " Make sure you run your executable with `jsrun`." - ) - return int(world_size) + return self.lsf_impl.world_size() @override def set_world_size(self, size: int) -> None: - log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + return self.lsf_impl.set_world_size(size) @override def global_rank(self) -> int: """The world size is read from the environment variable ``JSM_NAMESPACE_RANK``.""" - global_rank = os.environ.get("JSM_NAMESPACE_RANK") - if global_rank is None: - raise ValueError( - "Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found." - " Make sure you run your executable with `jsrun`." - ) - return int(global_rank) + return self.lsf_impl.global_rank() @override def set_global_rank(self, rank: int) -> None: - log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.") + return self.lsf_impl.set_global_rank(rank) @override def local_rank(self) -> int: """The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`.""" - local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK") - if local_rank is None: - raise ValueError( - "Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found." - " Make sure you run your executable with `jsrun`." 
- ) - return int(local_rank) + return self.lsf_impl.local_rank() @override def node_rank(self) -> int: """The node rank is determined by the position of the current hostname in the OpenMPI host rank file stored in ``LSB_DJOB_RANKFILE``.""" - return self._node_rank - - def _get_node_rank(self) -> int: - """A helper method for getting the node rank. - - The node rank is determined by the position of the current node in the list of hosts used in the job. This is - calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list. - - """ - hosts = self._read_hosts() - count: dict[str, int] = {} - for host in hosts: - if host not in count: - count[host] = len(count) - return count[socket.gethostname()] - - @staticmethod - def _read_hosts() -> list[str]: - """Read compute hosts that are a part of the compute job. - - LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes. - Each job is assigned a launch node. This launch node will be the first node in the list contained in - ``LSB_DJOB_RANKFILE``. - - """ - var = "LSB_DJOB_RANKFILE" - rankfile = os.environ.get(var) - if rankfile is None: - raise ValueError("Did not find the environment variable `LSB_DJOB_RANKFILE`") - if not rankfile: - raise ValueError("The environment variable `LSB_DJOB_RANKFILE` is empty") - - fs = get_filesystem(rankfile) - with fs.open(rankfile, "r") as f: - ret = [line.strip() for line in f] - # remove the launch node (i.e. the first node in LSB_DJOB_RANKFILE) from the list - return ret[1:] - - def _get_main_address(self) -> str: - """A helper for getting the main address. - - The main address is assigned to the first node in the list of nodes used for the job. - - """ - hosts = self._read_hosts() - return hosts[0] - - @staticmethod - def _get_main_port() -> int: - """A helper function for accessing the main port. - - Uses the LSF job ID so all ranks can compute the main port. 
- - """ - # check for user-specified main port - if "MASTER_PORT" in os.environ: - log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}") - return int(os.environ["MASTER_PORT"]) - if "LSB_JOBID" in os.environ: - port = int(os.environ["LSB_JOBID"]) - # all ports should be in the 10k+ range - port = port % 1000 + 10000 - log.debug(f"calculated LSF main port: {port}") - return port - raise ValueError("Could not find job id in environment variable LSB_JOBID") + return self.lsf_impl.node_rank() diff --git a/src/lightning/fabric/plugins/environments/slurm.py b/src/lightning/fabric/plugins/environments/slurm.py index 4d98b7ed6a8eb..4ac79ce0014e9 100644 --- a/src/lightning/fabric/plugins/environments/slurm.py +++ b/src/lightning/fabric/plugins/environments/slurm.py @@ -14,7 +14,6 @@ import logging import os -import re import shutil import signal import sys @@ -23,7 +22,7 @@ from typing_extensions import override from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment -from lightning.fabric.utilities.imports import _IS_WINDOWS +from lightning.fabric.utilities.imports import _IS_WINDOWS, _raise_enterprise_not_available from lightning.fabric.utilities.rank_zero import rank_zero_warn from lightning.fabric.utilities.warnings import PossibleUserWarning @@ -46,57 +45,35 @@ class SLURMEnvironment(ClusterEnvironment): def __init__(self, auto_requeue: bool = True, requeue_signal: Optional[signal.Signals] = None) -> None: super().__init__() - self.auto_requeue = auto_requeue - if requeue_signal is None and not _IS_WINDOWS: - requeue_signal = signal.SIGUSR1 - self.requeue_signal = requeue_signal - self._validate_srun_used() - self._validate_srun_variables() + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.environments.slurm import ( + SLURMEnvironment as EnterpriseSLURMEnvironment, + ) + + self.slurm_impl = EnterpriseSLURMEnvironment(auto_requeue=auto_requeue, requeue_signal=requeue_signal) + + @property + def auto_requeue(self) -> bool: + return self.slurm_impl.auto_requeue + + @property + def requeue_signal(self) -> Optional[signal.Signals]: + return self.slurm_impl.requeue_signal @property @override def creates_processes_externally(self) -> bool: - return True + return self.slurm_impl.creates_processes_externally @property @override def main_address(self) -> str: - root_node = os.environ.get("MASTER_ADDR") - if root_node is None: - nodelist = os.environ.get("SLURM_NODELIST", "127.0.0.1") - root_node = self.resolve_root_node_address(nodelist) - os.environ["MASTER_ADDR"] = root_node - - log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - return root_node + return self.slurm_impl.main_address @property @override def main_port(self) -> int: - # ----------------------- - # SLURM JOB = PORT number - # ----------------------- - # this way every process knows what port to use - job_id = os.environ.get("SLURM_JOB_ID") - if job_id is not None: - # use the last 4 numbers in the job id as the id - default_port = job_id[-4:] - # all ports should be in the 10k+ range - default_port = int(default_port) + 15000 - else: - default_port = 12910 - - # ----------------------- - # PORT NUMBER = MASTER_PORT - # ----------------------- - # in case the user passed it in - if "MASTER_PORT" in os.environ: - default_port = int(os.environ["MASTER_PORT"]) - else: - os.environ["MASTER_PORT"] = str(default_port) - - log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") - return default_port + return self.slurm_impl.main_port @staticmethod 
@override @@ -118,58 +95,40 @@ def job_name() -> Optional[str]: @staticmethod def job_id() -> Optional[int]: - # in interactive mode, don't make logs use the same job id - if _is_slurm_interactive_mode(): - return None - - job_id = os.environ.get("SLURM_JOB_ID") - if job_id is None: - return None - try: - return int(job_id) - except ValueError: - return None + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.environments.slurm import ( + SLURMEnvironment as EnterpriseSLURMEnvironment, + ) + + return EnterpriseSLURMEnvironment.job_id() @override def world_size(self) -> int: - return int(os.environ["SLURM_NTASKS"]) + return self.slurm_impl.world_size() @override def set_world_size(self, size: int) -> None: - log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + return self.slurm_impl.set_world_size(size) @override def global_rank(self) -> int: - return int(os.environ["SLURM_PROCID"]) + return self.slurm_impl.global_rank() @override def set_global_rank(self, rank: int) -> None: - log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.") + return self.slurm_impl.set_global_rank(rank) @override def local_rank(self) -> int: - return int(os.environ["SLURM_LOCALID"]) + return self.slurm_impl.local_rank() @override def node_rank(self) -> int: - return int(os.environ["SLURM_NODEID"]) + return self.slurm_impl.node_rank() @override def validate_settings(self, num_devices: int, num_nodes: int) -> None: - if _is_slurm_interactive_mode(): - return - ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") - if ntasks_per_node is not None and int(ntasks_per_node) != num_devices: - raise ValueError( - f"You set `devices={num_devices}` in Lightning, but the number of tasks per node configured in SLURM" - f" `--ntasks-per-node={ntasks_per_node}` does not match. HINT: Set `devices={ntasks_per_node}`." - ) - nnodes = os.environ.get("SLURM_NNODES") - if nnodes is not None and int(nnodes) != num_nodes: - raise ValueError( - f"You set `num_nodes={num_nodes}` in Lightning, but the number of nodes configured in SLURM" - f" `--nodes={nnodes}` does not match. HINT: Set `num_nodes={nnodes}`." - ) + return self.slurm_impl.validate_settings(num_devices, num_nodes) @staticmethod def resolve_root_node_address(nodes: str) -> str: @@ -182,9 +141,12 @@ def resolve_root_node_address(nodes: str) -> str: - the range notation with brackets, e.g., 'host[5-9]' yields 'host5' as the root """ - nodes = re.sub(r"\[(.*?)[,-].*\]", "\\1", nodes) # Take the first node of every node range - nodes = re.sub(r"\[(.*?)\]", "\\1", nodes) # handle special case where node range is single number - return nodes.split(" ")[0].split(",")[0] + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.environments.slurm import ( + SLURMEnvironment as EnterpriseSLURMEnvironment, + ) + + return EnterpriseSLURMEnvironment.resolve_root_node_address(nodes) @staticmethod def _validate_srun_used() -> None: @@ -216,12 +178,12 @@ def _validate_srun_variables() -> None: for a complete list of supported srun variables. """ - ntasks = int(os.environ.get("SLURM_NTASKS", "1")) - if ntasks > 1 and "SLURM_NTASKS_PER_NODE" not in os.environ: - raise RuntimeError( - f"You set `--ntasks={ntasks}` in your SLURM bash script, but this variable is not supported." - f" HINT: Use `--ntasks-per-node={ntasks}` instead." 
- ) + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.environments.slurm import ( + SLURMEnvironment as EnterpriseSLURMEnvironment, + ) + + return EnterpriseSLURMEnvironment._validate_srun_variables() def _is_srun_used() -> bool: diff --git a/src/lightning/fabric/plugins/environments/torchelastic.py b/src/lightning/fabric/plugins/environments/torchelastic.py index 4003dd30dec09..62ddd442df7b2 100644 --- a/src/lightning/fabric/plugins/environments/torchelastic.py +++ b/src/lightning/fabric/plugins/environments/torchelastic.py @@ -13,13 +13,12 @@ # limitations under the License. import logging -import os import torch.distributed from typing_extensions import override from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment -from lightning.fabric.utilities.rank_zero import rank_zero_warn +from lightning.fabric.utilities.imports import _raise_enterprise_not_available log = logging.getLogger(__name__) @@ -27,29 +26,29 @@ class TorchElasticEnvironment(ClusterEnvironment): """Environment for fault-tolerant and elastic training with `torchelastic `_""" + def __init__(self) -> None: + super().__init__() + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.environments.torchelastic import ( + TorchElasticEnvironment as EnterpriseTorchElasticEnvironment, + ) + + self.torchelastic_impl = EnterpriseTorchElasticEnvironment() + @property @override def creates_processes_externally(self) -> bool: - return True + return self.torchelastic_impl.creates_processes_externally @property @override def main_address(self) -> str: - if "MASTER_ADDR" not in os.environ: - rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost") - os.environ["MASTER_ADDR"] = "127.0.0.1" - log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - return os.environ["MASTER_ADDR"] + return self.torchelastic_impl.main_address @property @override def main_port(self) -> int: - if "MASTER_PORT" not in os.environ: - rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910") - os.environ["MASTER_PORT"] = "12910" - log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") - - return int(os.environ["MASTER_PORT"]) + return self.torchelastic_impl.main_port @staticmethod @override @@ -60,34 +59,28 @@ def detect() -> bool: @override def world_size(self) -> int: - return int(os.environ["WORLD_SIZE"]) + return self.torchelastic_impl.world_size() @override def set_world_size(self, size: int) -> None: - log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + return self.torchelastic_impl.set_world_size(size) @override def global_rank(self) -> int: - return int(os.environ["RANK"]) + return self.torchelastic_impl.global_rank() @override def set_global_rank(self, rank: int) -> None: - log.debug( - "TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored." 
- ) + return self.torchelastic_impl.set_global_rank(rank) @override def local_rank(self) -> int: - return int(os.environ["LOCAL_RANK"]) + return self.torchelastic_impl.local_rank() @override def node_rank(self) -> int: - return int(os.environ.get("GROUP_RANK", 0)) + return self.torchelastic_impl.node_rank() @override def validate_settings(self, num_devices: int, num_nodes: int) -> None: - if num_devices * num_nodes != self.world_size(): - raise ValueError( - f"You set `devices={num_devices}` and `num_nodes={num_nodes}` in Lightning, but the product" - f" ({num_devices} * {num_nodes}) does not match the world size ({self.world_size()})." - ) + return self.torchelastic_impl.validate_settings(num_devices, num_nodes) diff --git a/src/lightning/fabric/plugins/environments/xla.py b/src/lightning/fabric/plugins/environments/xla.py index b8350872f22d9..515668fa2b154 100644 --- a/src/lightning/fabric/plugins/environments/xla.py +++ b/src/lightning/fabric/plugins/environments/xla.py @@ -17,8 +17,9 @@ from typing_extensions import override -from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1, XLAAccelerator +from lightning.fabric.accelerators.xla import XLAAccelerator from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment +from lightning.fabric.utilities.imports import _raise_enterprise_not_available log = logging.getLogger(__name__) @@ -32,26 +33,28 @@ class XLAEnvironment(ClusterEnvironment): """ def __init__(self, *args: Any, **kwargs: Any) -> None: - if not _XLA_AVAILABLE: - raise ModuleNotFoundError(str(_XLA_AVAILABLE)) - super().__init__(*args, **kwargs) + super().__init__() + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.environments.xla import ( + XLAEnvironment as EnterpriseXLAEnvironment, + ) + + self.xla_impl = EnterpriseXLAEnvironment(*args, **kwargs) @property @override def creates_processes_externally(self) -> bool: - return False + return self.xla_impl.creates_processes_externally @property @override def main_address(self) -> str: - # unused by lightning - raise NotImplementedError + return self.xla_impl.main_address @property @override def main_port(self) -> int: - # unused by lightning - raise NotImplementedError + return self.xla_impl.main_port @staticmethod @override @@ -66,18 +69,11 @@ def world_size(self) -> int: The output is cached for performance. """ - if _XLA_GREATER_EQUAL_2_1: - from torch_xla import runtime as xr - - return xr.world_size() - - import torch_xla.core.xla_model as xm - - return xm.xrt_world_size() + return self.xla_impl.world_size() @override def set_world_size(self, size: int) -> None: - log.debug("XLAEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + return self.xla_impl.set_world_size(size) @override @functools.lru_cache(maxsize=1) @@ -87,18 +83,11 @@ def global_rank(self) -> int: The output is cached for performance. """ - if _XLA_GREATER_EQUAL_2_1: - from torch_xla import runtime as xr - - return xr.global_ordinal() - - import torch_xla.core.xla_model as xm - - return xm.get_ordinal() + return self.xla_impl.global_rank() @override def set_global_rank(self, rank: int) -> None: - log.debug("XLAEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.") + return self.xla_impl.set_global_rank(rank) @override @functools.lru_cache(maxsize=1) @@ -108,14 +97,7 @@ def local_rank(self) -> int: The output is cached for performance. 
""" - if _XLA_GREATER_EQUAL_2_1: - from torch_xla import runtime as xr - - return xr.local_ordinal() - - import torch_xla.core.xla_model as xm - - return xm.get_local_ordinal() + return self.xla_impl.local_rank() @override @functools.lru_cache(maxsize=1) @@ -125,11 +107,4 @@ def node_rank(self) -> int: The output is cached for performance. """ - if _XLA_GREATER_EQUAL_2_1: - from torch_xla import runtime as xr - - return xr.host_index() - import torch_xla.core.xla_env_vars as xenv - from torch_xla.utils.utils import getenv_as - - return getenv_as(xenv.HOST_ORDINAL, int, 0) + return self.xla_impl.node_rank() diff --git a/src/lightning/fabric/plugins/io/xla.py b/src/lightning/fabric/plugins/io/xla.py index 146fa2f33b510..da6ea414cac97 100644 --- a/src/lightning/fabric/plugins/io/xla.py +++ b/src/lightning/fabric/plugins/io/xla.py @@ -12,17 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import os from typing import Any, Optional -import torch -from lightning_utilities.core.apply_func import apply_to_collection -from lightning_utilities.core.imports import RequirementCache from typing_extensions import override -from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO -from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import _PATH log = logging.getLogger(__name__) @@ -36,9 +31,11 @@ class XLACheckpointIO(TorchCheckpointIO): """ def __init__(self, *args: Any, **kwargs: Any) -> None: - if not _XLA_AVAILABLE: - raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__(*args, **kwargs) + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.io.xla import XLACheckpointIO as EnterpriseXLACheckpointIO + + self.xla_impl = EnterpriseXLACheckpointIO(*args, **kwargs) @override def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: @@ -54,21 +51,4 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio If ``storage_options`` arg is passed in """ - if storage_options is not None: - raise TypeError( - "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg" - f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`" - " to define how you'd like to use `storage_options`." - ) - fs = get_filesystem(path) - fs.makedirs(os.path.dirname(path), exist_ok=True) - if RequirementCache("omegaconf"): - # workaround for https://github.com/pytorch/xla/issues/2773 - from omegaconf import DictConfig, ListConfig, OmegaConf - - checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container) - import torch_xla.core.xla_model as xm - - cpu_data = xm._maybe_convert_to_cpu(checkpoint, convert=True) - log.debug(f"Saving checkpoint: {path}") - torch.save(cpu_data, path) + return self.xla_impl.save_checkpoint(checkpoint, path, storage_options) diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index 4c648f2b97181..4ff20ee9fab1e 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -11,34 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -import functools -import logging -import math -import os -import warnings -from collections import OrderedDict -from contextlib import AbstractContextManager, ExitStack -from functools import partial -from types import ModuleType -from typing import Any, Callable, Literal, Optional, cast +from contextlib import AbstractContextManager +from typing import Any, Literal, Optional import torch -from lightning_utilities import apply_to_collection from lightning_utilities.core.imports import RequirementCache -from torch import Tensor -from torch.nn import init -from torch.nn.modules.module import _IncompatibleKeys -from typing_extensions import Self, override +from typing_extensions import override from lightning.fabric.plugins.precision.precision import Precision -from lightning.fabric.plugins.precision.utils import ( - _ClassReplacementContextManager, - _convert_fp_tensor, - _DtypeContextManager, -) -from lightning.fabric.utilities.types import _DEVICE - -log = logging.getLogger(__name__) +from lightning.fabric.utilities.imports import _raise_enterprise_not_available _BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes") @@ -73,376 +54,34 @@ def __init__( dtype: Optional[torch.dtype] = None, ignore_modules: Optional[set[str]] = None, ) -> None: - _import_bitsandbytes() - - if dtype is None: - # try to be smart about the default selection - if mode.startswith("int8"): - dtype = torch.float16 - else: - dtype = ( - torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16 - ) - if mode.startswith("int8") and dtype is not torch.float16: - # this limitation is mentioned in https://huggingface.co/blog/hf-bitsandbytes-integration#usage - raise ValueError(f"{mode!r} only works with `dtype=torch.float16`, but you chose `{dtype}`") + super().__init__() + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.precision.bitsandbytes import ( + BitsandbytesPrecision as EnterpriseBitsandbytesPrecision, + ) - globals_ = globals() - mode_to_cls = { - "nf4": globals_["_NF4Linear"], - "nf4-dq": globals_["_NF4DQLinear"], - "fp4": globals_["_FP4Linear"], - "fp4-dq": globals_["_FP4DQLinear"], - "int8-training": globals_["_Linear8bitLt"], - "int8": globals_["_Int8LinearInference"], - } - self._linear_cls = mode_to_cls[mode] - self.dtype = dtype - self.ignore_modules = ignore_modules or set() + self.bitsandbytes_impl = EnterpriseBitsandbytesPrecision(mode=mode, dtype=dtype, ignore_modules=ignore_modules) @override def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: - # avoid naive users thinking they quantized their model - if not any(isinstance(m, torch.nn.Linear) for m in module.modules()): - raise TypeError( - "You are using the bitsandbytes precision plugin, but your model has no Linear layers. This plugin" - " won't work for your model." 
- ) - - # convert modules if they haven't been converted already - bnb = _import_bitsandbytes() - if not any(isinstance(m, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)) for m in module.modules()): - # this will not quantize the model but only replace the layer classes - _convert_layers(module, self._linear_cls, self.ignore_modules) - - # set the compute dtype if necessary - for m in module.modules(): - if isinstance(m, bnb.nn.Linear4bit): - m.compute_dtype = self.dtype - m.compute_type_is_set = False - return module + return self.bitsandbytes_impl.convert_module(module) @override def tensor_init_context(self) -> AbstractContextManager: - return _DtypeContextManager(self.dtype) + return self.bitsandbytes_impl.tensor_init_context() @override def module_init_context(self) -> AbstractContextManager: - if self.ignore_modules: - # cannot patch the Linear class if the user wants to skip some submodules - raise RuntimeError( - "Instantiating your model under the `init_module` context manager is not supported when used with" - f" `BitsandbytesPrecision(..., ignore_modules={self.ignore_modules})` as this" - " may initialize the layers on-device, defeating the purpose of quantization. You can remove" - " `ignore_modules` or remove the `init_module` context manager." - ) - dtype_ctx = self.tensor_init_context() - # TODO: this could also support replacing `Embedding` and `Conv1D` - context_manager = _ClassReplacementContextManager({"torch.nn.Linear": self._linear_cls}) - stack = ExitStack() - stack.enter_context(dtype_ctx) - stack.enter_context(context_manager) - return stack + return self.bitsandbytes_impl.module_init_context() @override def forward_context(self) -> AbstractContextManager: - return _DtypeContextManager(self.dtype) + return self.bitsandbytes_impl.forward_context() @override def convert_input(self, data: Any) -> Any: - return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.dtype) + return self.bitsandbytes_impl.convert_input(data) @override def convert_output(self, data: Any) -> Any: - return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) - - -def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_dict: OrderedDict, *_: Any) -> None: - # There is only one key that ends with `*.weight`, the other one is the bias - weight_key = next((name for name in state_dict if name.endswith("weight")), None) - if weight_key is None: - return - # Load the weight from the state dict and re-quantize it - weight = state_dict.pop(weight_key) - quantize_fn(weight) - - -def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _IncompatibleKeys) -> None: - # since we manually loaded the weight in the `_quantize_on_load_hook` hook, we need to avoid this missing key false - # positive - for key in reversed(incompatible_keys.missing_keys): - if key.endswith("weight"): - incompatible_keys.missing_keys.remove(key) - - -def _replace_param( - param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[tuple] = None -) -> torch.nn.Parameter: - bnb = _import_bitsandbytes() - - # doing `param.data = weight` raises a RuntimeError if param.data was on meta-device, so - # we need to re-create the parameters instead of overwriting the data - if param.device.type == "meta": - if isinstance(param, bnb.nn.Params4bit): - return bnb.nn.Params4bit( - data=data, - requires_grad=data.requires_grad, - quant_state=quant_state, - blocksize=param.blocksize, - 
compress_statistics=param.compress_statistics, - quant_type=param.quant_type, - quant_storage=param.quant_storage, - module=param.module, - bnb_quantized=param.bnb_quantized, - ) - return torch.nn.Parameter(data, requires_grad=data.requires_grad) - param.data = data - if isinstance(param, bnb.nn.Params4bit): - param.quant_state = quant_state - return param - - -@functools.lru_cache(maxsize=1) -def _import_bitsandbytes() -> ModuleType: - if not _BITSANDBYTES_AVAILABLE: - raise ModuleNotFoundError(str(_BITSANDBYTES_AVAILABLE)) - # configuration for bitsandbytes before import - nowelcome_set = "BITSANDBYTES_NOWELCOME" in os.environ - if not nowelcome_set: - os.environ["BITSANDBYTES_NOWELCOME"] = "1" - warnings.filterwarnings("ignore", message=r".*bitsandbytes was compiled without GPU support.*") - warnings.filterwarnings( - "ignore", message=r"MatMul8bitLt: inputs will be cast from .* to float16 during quantization" - ) - import bitsandbytes as bnb - - if not nowelcome_set: - del os.environ["BITSANDBYTES_NOWELCOME"] - - class _Linear8bitLt(bnb.nn.Linear8bitLt): - """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and re-quantizaton when loading - the state dict.""" - - def __init__(self, *args: Any, device: Optional[_DEVICE] = None, threshold: float = 6.0, **kwargs: Any) -> None: - super().__init__(*args, device=device, threshold=threshold, **kwargs) - self.weight = cast(bnb.nn.Int8Params, self.weight) # type: ignore[has-type] - self.bias: Optional[torch.nn.Parameter] = self.bias - # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up - # filling the device memory with float32 weights which could lead to OOM - if torch.tensor(0, device=device).device.type == "cuda": - self.quantize_() - self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_)) - self.register_load_state_dict_post_hook(_ignore_missing_weights_hook) - - def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None: - """Inplace quantize.""" - if weight is None: - weight = self.weight.data - if weight.data.dtype == torch.int8: - # already quantized - return - assert isinstance(self.weight, bnb.nn.Int8Params) - self.weight = self.quantize(self.weight, weight, device) - - @staticmethod - def quantize( - int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: Optional[torch.device] - ) -> bnb.nn.Int8Params: - device = device or torch.device("cuda") - if device.type != "cuda": - raise RuntimeError(f"Unexpected device type: {device.type}") - # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302 - B = weight.contiguous().to(device=device, dtype=torch.float16) - if int8params.has_fp16_weights: - int8params.data = B - else: - # bitsandbytes >= 0.45 supports an improved API - if hasattr(bnb.functional, "int8_vectorwise_quant"): - CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B) - else: # old method is deprecated in 0.45, removed in 0.46+. 
- CB, _, SCB, _, _ = bnb.functional.double_quant(B) - - int8params.data = CB - setattr(int8params, "CB", CB) - setattr(int8params, "SCB", SCB) - return int8params - - def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self: - if self.weight.device.type == "meta": - # need custom logic if int8params is on meta device - raise NotImplementedError - if self.weight.dtype == torch.uint8: # was quantized - # need the original shape here - raise NotImplementedError - device = torch.device(device) - weight = torch.empty_like(self.weight.data, device=device) - if device.type == "cuda": # re-quantize - self.quantize_(weight, device) - else: - self.weight = _replace_param(self.weight, weight) - if self.bias is not None: - self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device)) - return self - - def reset_parameters(self) -> None: - # from `torch.nn.Linear.reset_parameters` - if self.bias is not None: - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - init.uniform_(self.bias, -bound, bound) - - linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit) - if linear_init_finished and self.weight.dtype == torch.uint8: # was quantized - # need the original shape here - raise NotImplementedError - weight = self.weight.data - torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) - if linear_init_finished: - if self.weight.device.type == "meta": - # need custom logic if int8params is on meta device - raise NotImplementedError - if self.weight.device.type == "cuda": # re-quantize - self.quantize_(weight) - else: - self.weight = _replace_param(self.weight, weight) - - class _Linear4bit(bnb.nn.Linear4bit): - """Wraps `bnb.nn.Linear4bit` to enable: instantiation directly on the device, re-quantizaton when loading the - state dict, meta-device initialization, and materialization.""" - - def __init__(self, *args: Any, device: Optional[_DEVICE] = None, **kwargs: Any) -> None: - super().__init__(*args, device=device, **kwargs) - self.weight = cast(bnb.nn.Params4bit, self.weight) # type: ignore[has-type] - self.bias: Optional[torch.nn.Parameter] = self.bias - # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up - # filling the device memory with float32 weights which could lead to OOM - if torch.tensor(0, device=device).device.type == "cuda": - self.quantize_() - self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_)) - self.register_load_state_dict_post_hook(_ignore_missing_weights_hook) - - def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None: - """Inplace quantize.""" - if weight is None: - weight = self.weight.data - if weight.data.dtype == torch.uint8: - # already quantized - return - assert isinstance(self.weight, bnb.nn.Params4bit) - self.weight = self.quantize(self.weight, weight, device) - self.weight.bnb_quantized = True - - @staticmethod - def quantize( - params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: Optional[torch.device] - ) -> bnb.nn.Params4bit: - device = device or torch.device("cuda") - if device.type != "cuda": - raise RuntimeError(f"Unexpected device type: {device.type}") - # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L156-L159 - w = weight.contiguous().to(device=device, dtype=torch.half) - w_4bit, quant_state = bnb.functional.quantize_4bit( - w, - blocksize=params4bit.blocksize, - 
compress_statistics=params4bit.compress_statistics, - quant_type=params4bit.quant_type, - quant_storage=params4bit.quant_storage, - ) - return _replace_param(params4bit, w_4bit, quant_state) - - def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self: - if self.weight.dtype == torch.uint8: # was quantized - # cannot init the quantized params directly - weight = torch.empty(self.weight.quant_state.shape, device=device, dtype=torch.half) - else: - weight = torch.empty_like(self.weight.data, device=device) - device = torch.device(device) - if device.type == "cuda": # re-quantize - self.quantize_(weight, device) - else: - self.weight = _replace_param(self.weight, weight) - if self.bias is not None: - self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device)) - return self - - def reset_parameters(self) -> None: - # from `torch.nn.Linear.reset_parameters` - if self.bias is not None: - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - init.uniform_(self.bias, -bound, bound) - - linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit) - if linear_init_finished and self.weight.dtype == torch.uint8: # was quantized - # cannot init the quantized params directly - weight = torch.empty(self.weight.quant_state.shape, device=self.weight.device, dtype=torch.half) - else: - weight = self.weight.data - torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) - if linear_init_finished: - if self.weight.device.type == "cuda": # re-quantize - self.quantize_(weight) - else: - self.weight = _replace_param(self.weight, weight) - - # Use a class instead `functools.partial` to respect `isinstance` checks and attribute accesses - class _Int8LinearInference(_Linear8bitLt): - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, has_fp16_weights=False, **kwargs) - - class _FP4Linear(_Linear4bit): - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, quant_type="fp4", compress_statistics=False, **kwargs) - - class _FP4DQLinear(_Linear4bit): - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, quant_type="fp4", compress_statistics=True, **kwargs) - - class _NF4Linear(_Linear4bit): - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, quant_type="nf4", compress_statistics=False, **kwargs) - - class _NF4DQLinear(_Linear4bit): - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, quant_type="nf4", compress_statistics=True, **kwargs) - - # these classes are defined programmatically like this to avoid importing bitsandbytes in environments that have - # it available but will not use it - classes = { - "_Linear8bitLt": _Linear8bitLt, - "_Linear4bit": _Linear4bit, - "_Int8LinearInference": _Int8LinearInference, - "_FP4Linear": _FP4Linear, - "_FP4DQLinear": _FP4DQLinear, - "_NF4Linear": _NF4Linear, - "_NF4DQLinear": _NF4DQLinear, - } - globals().update(classes) - - return bnb - - -def _convert_layers(module: torch.nn.Module, linear_cls: type, ignore_modules: set[str], prefix: str = "") -> None: - for name, child in module.named_children(): - fullname = f"{prefix}.{name}" if prefix else name - if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): - log.debug(f"Replacing layer {fullname!r} with bitsandbytes equivalent") - has_bias = child.bias is not None - # since we are going to copy over the child's data, the device doesn't 
matter. I chose CPU - # to avoid spiking CUDA memory even though initialization is slower - # 4bit layers support quantizing from meta-device params so this is only relevant for 8-bit - _Linear4bit = globals()["_Linear4bit"] - device = torch.device("meta" if issubclass(linear_cls, _Linear4bit) else "cpu") - replacement = linear_cls( - child.in_features, - child.out_features, - bias=has_bias, - device=device, - ) - if has_bias: - replacement.bias = _replace_param(replacement.bias, child.bias.data.clone()) - state = {"quant_state": replacement.weight.quant_state if issubclass(linear_cls, _Linear4bit) else None} - replacement.weight = _replace_param(replacement.weight, child.weight.data.clone(), **state) - module.__setattr__(name, replacement) - else: - _convert_layers(child, linear_cls, ignore_modules, prefix=fullname) + return self.bitsandbytes_impl.convert_output(data) diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py index 526095008f376..837fb2dad3aac 100644 --- a/src/lightning/fabric/plugins/precision/deepspeed.py +++ b/src/lightning/fabric/plugins/precision/deepspeed.py @@ -11,17 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import AbstractContextManager, nullcontext +from contextlib import AbstractContextManager from typing import TYPE_CHECKING, Any, Literal import torch -from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor from torch.nn import Module -from typing_extensions import get_args, override +from typing_extensions import override from lightning.fabric.plugins.precision.precision import Precision -from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import Steppable if TYPE_CHECKING: @@ -44,51 +43,37 @@ class DeepSpeedPrecision(Precision): """ def __init__(self, precision: _PRECISION_INPUT) -> None: - supported_precision = get_args(_PRECISION_INPUT) - if precision not in supported_precision: - raise ValueError( - f"`precision={precision!r})` is not supported in DeepSpeed." - f" `precision` must be one of: {supported_precision}." 
- ) - self.precision = precision - - precision_to_type = { - "bf16-mixed": torch.bfloat16, - "16-mixed": torch.float16, - "bf16-true": torch.bfloat16, - "16-true": torch.float16, - "32-true": torch.float32, - } - self._desired_dtype = precision_to_type[self.precision] + super().__init__() + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.precision.deepspeed import ( + DeepSpeedPrecisionFabric as EnterpriseDeepSpeedPrecision, + ) + + self.deepspeed_impl = EnterpriseDeepSpeedPrecision(precision=precision) @override def convert_module(self, module: Module) -> Module: - if "true" in self.precision: - return module.to(dtype=self._desired_dtype) - return module + return self.deepspeed_impl.convert_module(module) @override def tensor_init_context(self) -> AbstractContextManager: - if "true" not in self.precision: - return nullcontext() - return _DtypeContextManager(self._desired_dtype) + return self.deepspeed_impl.tensor_init_context() @override def module_init_context(self) -> AbstractContextManager: - return self.tensor_init_context() + return self.deepspeed_impl.module_init_context() @override def convert_input(self, data: Any) -> Any: - return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype) + return self.deepspeed_impl.convert_input(data) @override def convert_output(self, data: Any) -> Any: - return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) + return self.deepspeed_impl.convert_output(data) @override def backward(self, tensor: Tensor, model: "DeepSpeedEngine", *args: Any, **kwargs: Any) -> None: - """Performs back-propagation using DeepSpeed's engine.""" - model.backward(tensor, *args, **kwargs) + return self.deepspeed_impl.backward(tensor, model, *args, **kwargs) @override def optimizer_step( @@ -96,5 +81,20 @@ def optimizer_step( optimizer: Steppable, **kwargs: Any, ) -> Any: - # DeepSpeed handles the optimizer step internally - return optimizer.step(**kwargs) + return self.deepspeed_impl.optimizer_step(optimizer, **kwargs) + + @property + def precision(self) -> _PRECISION_INPUT: + return self.deepspeed_impl.precision + + @precision.setter + def precision(self, precision: _PRECISION_INPUT) -> None: + self.deepspeed_impl.precision = precision + + @property + def _desired_dtype(self) -> torch.dtype: + return self.deepspeed_impl._desired_dtype + + @_desired_dtype.setter + def _desired_dtype(self, dtype: torch.dtype) -> None: + self.deepspeed_impl._desired_dtype = dtype diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py index bf1e51ea6b2b0..3d3d2491763f4 100644 --- a/src/lightning/fabric/plugins/precision/transformer_engine.py +++ b/src/lightning/fabric/plugins/precision/transformer_engine.py @@ -11,31 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import logging from collections.abc import Mapping -from contextlib import AbstractContextManager, ExitStack +from contextlib import AbstractContextManager from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch -from lightning_utilities import apply_to_collection -from lightning_utilities.core.imports import RequirementCache -from torch import Tensor from typing_extensions import override from lightning.fabric.plugins.precision.precision import Precision -from lightning.fabric.plugins.precision.utils import ( - _ClassReplacementContextManager, - _convert_fp_tensor, - _DtypeContextManager, -) -from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn +from lightning.fabric.utilities.imports import _raise_enterprise_not_available if TYPE_CHECKING: from transformer_engine.common.recipe import DelayedScaling -_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0") -log = logging.getLogger(__name__) - class TransformerEnginePrecision(Precision): """Plugin for training with fp8 precision via nvidia's @@ -72,111 +60,79 @@ def __init__( replace_layers: Optional[bool] = None, fallback_compute_dtype: Optional[torch.dtype] = None, ) -> None: - if not _TRANSFORMER_ENGINE_AVAILABLE: - raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE)) - from transformer_engine.common.recipe import DelayedScaling - - if recipe is None: - recipe = DelayedScaling() - elif isinstance(recipe, Mapping): - recipe = dict(recipe) # copy - if "fp8_format" in recipe: - from transformer_engine.common.recipe import Format - - recipe["fp8_format"] = getattr(Format, recipe["fp8_format"]) - recipe = DelayedScaling(**recipe) - - self.weights_dtype = weights_dtype - self.recipe = recipe - self.replace_layers = replace_layers - self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype + super().__init__() + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.precision.transformer_engine import ( + TransformerEnginePrecision as EnterpriseTransformerEnginePrecision, + ) + + self.transformer_engine_impl = EnterpriseTransformerEnginePrecision( + weights_dtype=weights_dtype, + recipe=recipe, + replace_layers=replace_layers, + fallback_compute_dtype=fallback_compute_dtype, + ) + + @property + def weights_dtype(self) -> torch.dtype: + return self.transformer_engine_impl.weights_dtype + + @weights_dtype.setter + def weights_dtype(self, value: torch.dtype) -> None: + self.transformer_engine_impl.weights_dtype = value + + @property + def recipe(self) -> Union[Mapping[str, Any], "DelayedScaling"]: + return self.transformer_engine_impl.recipe + + @recipe.setter + def recipe(self, value: Union[Mapping[str, Any], "DelayedScaling"]) -> None: + self.transformer_engine_impl.recipe = value + + @property + def replace_layers(self) -> bool: + return self.transformer_engine_impl.replace_layers + + @replace_layers.setter + def replace_layers(self, value: bool) -> None: + self.transformer_engine_impl.replace_layers = value + + @property + def fallback_compute_dtype(self) -> torch.dtype: + return self.transformer_engine_impl.fallback_compute_dtype + + @fallback_compute_dtype.setter + def fallback_compute_dtype(self, value: torch.dtype) -> None: + self.transformer_engine_impl.fallback_compute_dtype = value @override def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: - # avoid converting if any is found. 
assume the user took care of it - if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()): - if self.replace_layers is True: - # info level because this is expected with `init_module` - rank_zero_info( - "`TransformerEnginePrecision(replace_layers=True)` is set but the model already contains" - " TransformerEngine layers. Skipping" - ) - elif self.replace_layers in (None, True): - _convert_layers(module) - module = module.to(dtype=self.weights_dtype) - return module + return self.transformer_engine_impl.convert_module(module) @override def tensor_init_context(self) -> AbstractContextManager: - return _DtypeContextManager(self.weights_dtype) + return self.transformer_engine_impl.tensor_init_context() @override def module_init_context(self) -> AbstractContextManager: - dtype_ctx = self.tensor_init_context() - stack = ExitStack() - if self.replace_layers: - import transformer_engine.pytorch as te - - context_manager = _ClassReplacementContextManager({ - "torch.nn.Linear": te.Linear, - "torch.nn.LayerNorm": te.LayerNorm, - }) - stack.enter_context(context_manager) - stack.enter_context(dtype_ctx) - return stack + return self.transformer_engine_impl.module_init_context() @override def forward_context(self) -> AbstractContextManager: - dtype_ctx = _DtypeContextManager(self.weights_dtype) - fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype) - import transformer_engine.pytorch as te - - autocast_ctx = te.fp8_autocast(enabled=True, fp8_recipe=self.recipe) - stack = ExitStack() - stack.enter_context(dtype_ctx) - # enable an outer fallback autocast for operations that do not support fp8 - stack.enter_context(fallback_autocast_ctx) - stack.enter_context(autocast_ctx) - return stack + return self.transformer_engine_impl.forward_context() @override def convert_input(self, data: Any) -> Any: - return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype) + return self.transformer_engine_impl.convert_input(data) @override def convert_output(self, data: Any) -> Any: - return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) - - -def _convert_layers(module: torch.nn.Module) -> None: - import transformer_engine.pytorch as te - - for name, child in module.named_children(): - if isinstance(child, torch.nn.Linear): - if child.in_features % 8 != 0 or child.out_features % 16 != 0: - # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting - rank_zero_warn( - "Support for FP8 in the linear layers with this plugin is currently limited to" - " tensors with shapes where the dimensions are divisible by 8 and 16 respectively." - f" The layer {name!r} does not fit this criteria. You might want to add padding to your inputs." 
- ) - continue - has_bias = child.bias is not None - replacement = te.Linear(child.in_features, child.out_features, bias=has_bias) - replacement.weight.data = child.weight.data.clone() - if has_bias: - replacement.bias.data = child.bias.data.clone() - log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent") - module.__setattr__(name, replacement) - elif isinstance(child, torch.nn.LayerNorm): - replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps) - replacement.weight.data = child.weight.data.clone() - # Check if bias exists before attempting to clone its data - if child.bias is not None and replacement.bias is not None: - replacement.bias.data = child.bias.data.clone() - log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent") - module.__setattr__(name, replacement) - else: - # there are other transformer engine layers that we could convert but require fusion. full list at: - # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html - _convert_layers(child) + return self.transformer_engine_impl.convert_output(data) + + @property + def _desired_dtype(self) -> torch.dtype: + return self.transformer_engine_impl._desired_dtype + + @_desired_dtype.setter + def _desired_dtype(self, dtype: torch.dtype) -> None: + self.transformer_engine_impl._desired_dtype = dtype diff --git a/src/lightning/fabric/plugins/precision/xla.py b/src/lightning/fabric/plugins/precision/xla.py index fdb30032b3cdd..4055d87e14af7 100644 --- a/src/lightning/fabric/plugins/precision/xla.py +++ b/src/lightning/fabric/plugins/precision/xla.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import Any, Literal import torch -from typing_extensions import get_args, override +from typing_extensions import override -from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins.precision.precision import Precision +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import Optimizable _PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true"] @@ -37,24 +36,11 @@ class XLAPrecision(Precision): """ def __init__(self, precision: _PRECISION_INPUT) -> None: - if not _XLA_AVAILABLE: - raise ModuleNotFoundError(str(_XLA_AVAILABLE)) - supported_precision = get_args(_PRECISION_INPUT) - if precision not in supported_precision: - raise ValueError( - f"`precision={precision!r})` is not supported in XLA." - f" `precision` must be one of: {supported_precision}." 
- ) - self.precision = precision - - if precision == "16-true": - os.environ["XLA_USE_F16"] = "1" - self._desired_dtype = torch.float16 - elif precision == "bf16-true": - os.environ["XLA_USE_BF16"] = "1" - self._desired_dtype = torch.bfloat16 - else: - self._desired_dtype = torch.float32 + super().__init__() + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.precision.xla import XLAPrecision as EnterpriseXLAPrecision + + self.xla_impl = EnterpriseXLAPrecision(precision=precision) @override def optimizer_step( @@ -62,12 +48,24 @@ def optimizer_step( optimizer: Optimizable, **kwargs: Any, ) -> Any: - import torch_xla.core.xla_model as xm - - # you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True` - return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True) + return self.xla_impl.optimizer_step(optimizer, **kwargs) @override def teardown(self) -> None: - os.environ.pop("XLA_USE_BF16", None) - os.environ.pop("XLA_USE_F16", None) + return self.xla_impl.teardown() + + @property + def _desired_dtype(self) -> torch.dtype: + return self.xla_impl._desired_dtype + + @_desired_dtype.setter + def _desired_dtype(self, dtype: torch.dtype) -> None: + self.xla_impl._desired_dtype = dtype + + @property + def precision(self) -> _PRECISION_INPUT: + return self.xla_impl.precision + + @precision.setter + def precision(self, precision: _PRECISION_INPUT) -> None: + self.xla_impl.precision = precision diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 883546fea1f2d..a477a2b940afc 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -11,45 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
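# The rewritten code paths in this patch call `_raise_enterprise_not_available()` before each
# lazy `pytorch_lightning_enterprise` import. Its definition is not shown in this diff; the
# snippet below is only an assumption of what such a guard commonly looks like (a cached
# availability check that raises a ModuleNotFoundError with an install hint), not the actual helper.
from functools import lru_cache
from importlib.util import find_spec


@lru_cache(maxsize=1)
def _enterprise_available() -> bool:
    # hypothetical check; the real helper may rely on lightning_utilities' RequirementCache instead
    return find_spec("pytorch_lightning_enterprise") is not None


def _raise_enterprise_not_available() -> None:
    if not _enterprise_available():
        raise ModuleNotFoundError(
            "This functionality requires the `pytorch-lightning-enterprise` package."
            " Install it with `pip install pytorch-lightning-enterprise`."
        )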
-import argparse -import json import logging -import os -import platform -from collections.abc import Mapping -from contextlib import AbstractContextManager, ExitStack +from contextlib import AbstractContextManager from datetime import timedelta -from itertools import chain -from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch -from lightning_utilities.core.imports import RequirementCache from torch.nn import Module from torch.optim import Optimizer from typing_extensions import override -from lightning.fabric.accelerators import Accelerator, CUDAAccelerator +from lightning.fabric.accelerators import Accelerator from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment from lightning.fabric.plugins.precision import Precision from lightning.fabric.strategies.ddp import DDPStrategy from lightning.fabric.strategies.registry import _StrategyRegistry from lightning.fabric.strategies.strategy import _Sharded -from lightning.fabric.utilities.distributed import log -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 -from lightning.fabric.utilities.load import _move_state_into -from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn -from lightning.fabric.utilities.seed import reset_seed +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import _PATH if TYPE_CHECKING: from deepspeed import DeepSpeedEngine from torch.optim.lr_scheduler import _LRScheduler -_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") -_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0") - # TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced. class DeepSpeedStrategy(DDPStrategy, _Sharded): @@ -235,24 +220,11 @@ def __init__( exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints. """ - if not _DEEPSPEED_AVAILABLE: - raise ImportError( - "To use the `DeepSpeedStrategy`, you must have DeepSpeed installed." - " Install it by running `pip install -U deepspeed`." - ) - - if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16: - # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints. - # DeepSpeed added support for this behavior in version 0.16.0. - import deepspeed - - deepspeed_version = deepspeed.__version__ - raise ImportError( - f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. " - f"Detected DeepSpeed version: {deepspeed_version}. " - "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`." 
- ) + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.strategies.deepspeed import ( + DeepSpeedStrategyFabric as EnterpriseDeepSpeedStrategy, + ) super().__init__( accelerator=accelerator, @@ -261,77 +233,68 @@ def __init__( precision=precision, process_group_backend=process_group_backend, ) - self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally - self._timeout: Optional[timedelta] = timeout - - self.config = self._load_config(config) - if self.config is None: - # User has not overridden config, set defaults - self.config = self._create_default_config( - zero_optimization, - zero_allow_untested_optimizer, - logging_batch_size_per_gpu, - offload_optimizer=offload_optimizer, - offload_parameters=offload_parameters, - nvme_path=nvme_path, - offload_params_device=offload_params_device, - params_buffer_count=params_buffer_count, - params_buffer_size=params_buffer_size, - max_in_cpu=max_in_cpu, - pin_memory=pin_memory, - offload_optimizer_device=offload_optimizer_device, - optimizer_buffer_count=optimizer_buffer_count, - block_size=block_size, - queue_depth=queue_depth, - single_submit=single_submit, - overlap_events=overlap_events, - thread_count=thread_count, - partition_activations=partition_activations, - cpu_checkpointing=cpu_checkpointing, - contiguous_memory_optimization=contiguous_memory_optimization, - synchronize_checkpoint_boundary=synchronize_checkpoint_boundary, - stage=stage, - contiguous_gradients=contiguous_gradients, - overlap_comm=overlap_comm, - allgather_partitions=allgather_partitions, - reduce_scatter=reduce_scatter, - allgather_bucket_size=allgather_bucket_size, - reduce_bucket_size=reduce_bucket_size, - sub_group_size=sub_group_size, - ) - - import deepspeed - - self._config_initialized = False - deepspeed.utils.logging.logger.setLevel(logging_level) - - self.remote_device = remote_device - self.load_full_weights = load_full_weights - self.exclude_frozen_parameters = exclude_frozen_parameters - - # default FP16 parameters. 
- self.loss_scale = loss_scale - self.initial_scale_power = initial_scale_power - self.loss_scale_window = loss_scale_window - self.hysteresis = hysteresis - self.min_loss_scale = min_loss_scale - - self._deepspeed_engine: Optional[DeepSpeedEngine] = None + self.deepspeed_impl = EnterpriseDeepSpeedStrategy( + outer_object=self, + accelerator=accelerator, + zero_optimization=zero_optimization, + stage=stage, + remote_device=remote_device, + offload_optimizer=offload_optimizer, + offload_parameters=offload_parameters, + offload_params_device=offload_params_device, + nvme_path=nvme_path, + params_buffer_count=params_buffer_count, + params_buffer_size=params_buffer_size, + max_in_cpu=max_in_cpu, + offload_optimizer_device=offload_optimizer_device, + optimizer_buffer_count=optimizer_buffer_count, + block_size=block_size, + queue_depth=queue_depth, + single_submit=single_submit, + overlap_events=overlap_events, + thread_count=thread_count, + pin_memory=pin_memory, + sub_group_size=sub_group_size, + contiguous_gradients=contiguous_gradients, + overlap_comm=overlap_comm, + allgather_partitions=allgather_partitions, + reduce_scatter=reduce_scatter, + allgather_bucket_size=allgather_bucket_size, + reduce_bucket_size=reduce_bucket_size, + zero_allow_untested_optimizer=zero_allow_untested_optimizer, + logging_batch_size_per_gpu=logging_batch_size_per_gpu, + config=config, + logging_level=logging_level, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + loss_scale=loss_scale, + initial_scale_power=initial_scale_power, + loss_scale_window=loss_scale_window, + hysteresis=hysteresis, + min_loss_scale=min_loss_scale, + partition_activations=partition_activations, + cpu_checkpointing=cpu_checkpointing, + contiguous_memory_optimization=contiguous_memory_optimization, + synchronize_checkpoint_boundary=synchronize_checkpoint_boundary, + load_full_weights=load_full_weights, + precision=precision, + process_group_backend=process_group_backend, + timeout=timeout, + exclude_frozen_parameters=exclude_frozen_parameters, + ) @property def zero_stage_3(self) -> bool: - assert isinstance(self.config, dict) - zero_optimization = self.config.get("zero_optimization") - return zero_optimization is not None and zero_optimization.get("stage") == 3 + return self.deepspeed_impl.zero_stage_3 @property @override def distributed_sampler_kwargs(self) -> dict[str, int]: - return {"num_replicas": self.world_size, "rank": self.global_rank} + return self.deepspeed_impl.distributed_sampler_kwargs @property def model(self) -> "DeepSpeedEngine": - return self._deepspeed_engine + return self.deepspeed_impl.model @override def setup_module_and_optimizers( @@ -345,14 +308,9 @@ def setup_module_and_optimizers( deepspeed optimizer, and an optional learning rate scheduler. """ - if len(optimizers) != 1: - raise ValueError( - f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead." - ) - - self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler) - self._set_deepspeed_activation_checkpointing() - return self._deepspeed_engine, [optimizer], scheduler + return self.deepspeed_impl.setup_module_and_optimizers( + module=module, optimizers=optimizers, scheduler=scheduler + ) @override def setup_module(self, module: Module) -> "DeepSpeedEngine": @@ -361,8 +319,7 @@ def setup_module(self, module: Module) -> "DeepSpeedEngine": For training, see :meth:`setup_module_and_optimizers`. 
""" - self._deepspeed_engine, _, _ = self._initialize_engine(module) - return self._deepspeed_engine + return self.deepspeed_impl.setup_module(module=module) @override def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: @@ -371,34 +328,15 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer together. """ - raise NotImplementedError(self._err_msg_joint_setup_required()) + return self.deepspeed_impl.setup_optimizer(optimizer=optimizer) @override def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: - if self.zero_stage_3 and empty_init is False: - raise NotImplementedError( - f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled." - ) - module_sharded_ctx = self.module_sharded_context() - stack = ExitStack() - if not self.zero_stage_3: - stack.enter_context(super().module_init_context(empty_init=empty_init)) - stack.enter_context(module_sharded_ctx) - return stack + return self.deepspeed_impl.module_init_context(empty_init=empty_init) @override def module_sharded_context(self) -> AbstractContextManager: - # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context - # manager. Later modifications through e.g. `Fabric.setup()` won't have an effect here. - - import deepspeed - - assert self._config_initialized - return deepspeed.zero.Init( - enabled=self.zero_stage_3, - remote_device=self.remote_device, - config_dict_or_path=self.config, - ) + return self.deepspeed_impl.module_sharded_context() @override def save_checkpoint( @@ -425,46 +363,8 @@ def save_checkpoint( :class:`deepspeed.DeepSpeedEngine` objects were found. """ - if storage_options is not None: - raise TypeError( - "`DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not supported because" - " `DeepSpeedStrategy` does not use the `CheckpointIO`." - ) - if filter is not None: - raise TypeError( - "`DeepSpeedStrategy.save_checkpoint(..., filter=...)` is not supported because" - " `DeepSpeedStrategy` manages the state serialization internally." - ) - - engines = _get_deepspeed_engines_from_state(state) - if len(engines) == 0: - raise ValueError( - "Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as" - " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure" - " you set up the model (and optimizers if any) through the strategy before saving the checkpoint." - ) - if len(engines) > 1: - raise ValueError( - "Found multiple DeepSpeed engine modules in the given state. Saving checkpoints with DeepSpeed is" - " currently limited to a single model per checkpoint. To save multiple models, call the" - " save method for each model separately with a different path." 
- ) - engine = engines[0] - - # broadcast the path from rank 0 to ensure all the states are saved in a common path - path = self.broadcast(path) - - # split the checkpoint into two parts: - # 1) the deepspeed engine encapsulating both the model and optionally the optimizer(s) - # 2) the rest of the user's state, which in deepspeed is called `client state` - excluded_objects = (engine, engine.optimizer) if engine.optimizer is not None else (engine,) - state = {k: v for k, v in state.items() if v not in excluded_objects} - _validate_state_keys(state) - # there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict - state = self._convert_stateful_objects_in_state(state, filter={}) - # use deepspeed's internal checkpointing function to handle partitioned weights across processes - engine.save_checkpoint( - path, client_state=state, tag="checkpoint", exclude_frozen_parameters=self.exclude_frozen_parameters + return self.deepspeed_impl.save_checkpoint( + path=path, state=state, storage_options=storage_options, filter=filter ) @override @@ -495,59 +395,7 @@ def load_checkpoint( not in the expected DeepSpeed format. """ - if isinstance(state, (Module, Optimizer)) or self.load_full_weights and self.zero_stage_3: - # This code path to enables loading a checkpoint from a non-deepspeed checkpoint or from - # a consolidated checkpoint - path = self.broadcast(path) - return super().load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only) - - if not state: - raise ValueError( - f"Got DeepSpeedStrategy.load_checkpoint(..., state={state!r}) but a state with at least " - f" a model instance to reload is required. Pass it in like so:" - " DeepSpeedStrategy.load_checkpoint(..., state={'model': model, ...})" - ) - _validate_checkpoint_directory(path) - - engines = _get_deepspeed_engines_from_state(state) - if len(engines) == 0: - raise ValueError( - "Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as" - " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure" - " you set up the model (and optimizers if any) through the strategy before loading the checkpoint." - ) - if len(engines) > 1: - raise ValueError( - "Found multiple DeepSpeed engine modules in the given state. Saving and loading checkpoints" - " with DeepSpeed is currently limited to a single model per checkpoint. To load multiple model" - " states, call the load method for each model checkpoint separately." - ) - engine = engines[0] - - from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer - - optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values()) - - torch.cuda.empty_cache() - _, client_state = engine.load_checkpoint( - path, - tag="checkpoint", - load_optimizer_states=optimzer_state_requested, - load_lr_scheduler_states=False, - load_module_strict=strict, - ) - - if client_state is None: - raise RuntimeError( - "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint" - " or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`." 
- ) - - # `Engine.load_checkpoint` adds useless keys 'optimizer' and 'lr_scheduler' to the client state; remove - # them to avoid name collision with user state - keys = set(client_state) & set(state) - {"optimizer", "lr_scheduler"} - _move_state_into(source=client_state, destination=state, keys=keys) - return client_state + return self.deepspeed_impl.load_checkpoint(path=path, state=state, strict=strict) @override def clip_gradients_norm( @@ -558,19 +406,19 @@ def clip_gradients_norm( norm_type: Union[float, int] = 2.0, error_if_nonfinite: bool = True, ) -> torch.Tensor: - raise NotImplementedError( - "DeepSpeed handles gradient clipping automatically within the optimizer. " - "Make sure to set the `gradient_clipping` value in your Config." + return self.deepspeed_impl.clip_gradients_norm( + module=module, + optimizer=optimizer, + max_norm=max_norm, + norm_type=norm_type, + error_if_nonfinite=error_if_nonfinite, ) @override def clip_gradients_value( self, module: "DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int] ) -> None: - raise NotImplementedError( - "DeepSpeed handles gradient clipping automatically within the optimizer. " - "Make sure to set the `gradient_clipping` value in your Config." - ) + return self.deepspeed_impl.clip_gradients_value(module=module, optimizer=optimizer, clip_val=clip_val) @classmethod @override @@ -613,338 +461,18 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: offload_optimizer_device="nvme", ) - def _initialize_engine( - self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional["_LRScheduler"] = None - ) -> tuple["DeepSpeedEngine", Optimizer, Any]: - """Initialize one model and one optimizer with an optional learning rate scheduler. - - This calls ``deepspeed.initialize`` internally. - - """ - import deepspeed - - model_parameters = filter(lambda p: p.requires_grad, model.parameters()) - deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize( - args=argparse.Namespace(device_rank=self.root_device.index), - config=self.config, - model=model, - model_parameters=model_parameters, - optimizer=optimizer, - lr_scheduler=scheduler, - dist_init_required=False, - ) - return deepspeed_engine, deepspeed_optimizer, deepspeed_scheduler - @override def setup_environment(self) -> None: - if not isinstance(self.accelerator, CUDAAccelerator): - raise RuntimeError( - f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`" - " is used." 
- ) - super().setup_environment() + return self.deepspeed_impl.setup_environment() @override def _setup_distributed(self) -> None: - assert self.parallel_devices is not None - _validate_device_index_selection(self.parallel_devices) - reset_seed() - self._set_world_ranks() - self._init_deepspeed_distributed() - if not self._config_initialized: - self._format_config() - self._config_initialized = True - - def _init_deepspeed_distributed(self) -> None: - import deepspeed - - assert self.cluster_environment is not None - if platform.system() != "Windows": - # do not set env variables on windows, allow deepspeed to control setup - self._set_node_environment_variables() - log.info( - "initializing deepspeed distributed: " - f"GLOBAL_RANK: {self.global_rank}, " - f"MEMBER: {self.global_rank + 1}/{self.world_size}" - ) - self._process_group_backend = self._get_process_group_backend() - deepspeed.init_distributed( - self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout - ) - - def _set_node_environment_variables(self) -> None: - assert self.cluster_environment is not None - os.environ["MASTER_ADDR"] = self.cluster_environment.main_address - os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) - os.environ["RANK"] = str(self.global_rank) - os.environ["WORLD_SIZE"] = str(self.world_size) - os.environ["LOCAL_RANK"] = str(self.local_rank) - - def _set_deepspeed_activation_checkpointing(self) -> None: - import deepspeed - - assert isinstance(self.config, dict) - if self.config.get("activation_checkpointing"): - checkpoint_config = self.config["activation_checkpointing"] - deepspeed.checkpointing.configure( - mpu_=None, - partition_activations=checkpoint_config.get("partition_activations"), - contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"), - checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"), - profile=checkpoint_config.get("profile"), - ) - - def _format_config(self) -> None: - if self.config is None: - raise ValueError( - "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config." 
- " See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed" - ) - - self.config.setdefault("train_micro_batch_size_per_gpu", 1) - _format_precision_config( - config=self.config, - precision=self.precision.precision, - loss_scale=self.loss_scale, - loss_scale_window=self.loss_scale_window, - min_loss_scale=self.min_loss_scale, - initial_scale_power=self.initial_scale_power, - hysteresis=self.hysteresis, - ) - - def _create_default_config( - self, - zero_optimization: bool, - zero_allow_untested_optimizer: bool, - logging_batch_size_per_gpu: Optional[int], - partition_activations: bool, - cpu_checkpointing: bool, - contiguous_memory_optimization: bool, - synchronize_checkpoint_boundary: bool, - offload_optimizer: bool, - offload_parameters: bool, - nvme_path: str, - offload_params_device: str, - params_buffer_count: int, - params_buffer_size: int, - max_in_cpu: int, - offload_optimizer_device: str, - optimizer_buffer_count: int, - pin_memory: bool, - block_size: int, - queue_depth: int, - single_submit: bool, - overlap_events: bool, - thread_count: int, - **zero_kwargs: Any, - ) -> dict: - cfg = { - "activation_checkpointing": { - "partition_activations": partition_activations, - "cpu_checkpointing": cpu_checkpointing, - "contiguous_memory_optimization": contiguous_memory_optimization, - "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary, - }, - "aio": { - "block_size": block_size, - "queue_depth": queue_depth, - "single_submit": single_submit, - "overlap_events": overlap_events, - "thread_count": thread_count, - }, - } - if zero_optimization: - zero_config = zero_kwargs - - if offload_optimizer: - zero_config["offload_optimizer"] = { - "device": offload_optimizer_device, - "nvme_path": nvme_path, - "buffer_count": optimizer_buffer_count, - "pin_memory": pin_memory, - } - if offload_parameters: - zero_config["offload_param"] = { - "device": offload_params_device, - "nvme_path": nvme_path, - "buffer_count": params_buffer_count, - "buffer_size": params_buffer_size, - "max_in_cpu": max_in_cpu, - "pin_memory": pin_memory, - } - cfg.update({ - "zero_allow_untested_optimizer": zero_allow_untested_optimizer, - "zero_optimization": zero_config, - }) - if logging_batch_size_per_gpu: - cfg["train_micro_batch_size_per_gpu"] = logging_batch_size_per_gpu - return cfg - - def _restore_zero_state(self, module: Module, ckpt: Mapping[str, Any]) -> None: - """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be sharded - across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced - across processes. - - Args: - ckpt: The ckpt file. 
- - """ - import deepspeed - - def load(module: torch.nn.Module, prefix: str = "") -> None: - missing_keys: list[str] = [] - unexpected_keys: list[str] = [] - error_msgs: list[str] = [] - state_dict = ckpt["state_dict"] - - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, "_metadata", None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - # because zero3 puts placeholders in model params, this context - # manager gathers (unpartitions) the params of the current layer, then loads from - # the state dict and then re-partitions them again - with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): - if self.is_global_zero: - module._load_from_state_dict( - state_dict=state_dict, - prefix=prefix, - local_metadata=local_metadata, - strict=True, - missing_keys=missing_keys, - unexpected_keys=unexpected_keys, - error_msgs=error_msgs, - ) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - load(module, prefix="") - - def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]: - if config is None and self.DEEPSPEED_ENV_VAR in os.environ: - rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") - config = os.environ[self.DEEPSPEED_ENV_VAR] - if isinstance(config, (str, Path)): - if not os.path.isfile(config): - raise FileNotFoundError( - f"You passed in a path to a DeepSpeed config but the path does not exist: {config}" - ) - with open(config) as f: - config = json.load(f) - assert isinstance(config, dict) or config is None - return config - - -def _get_deepspeed_engines_from_state(state: dict[str, Any]) -> list["DeepSpeedEngine"]: - from deepspeed import DeepSpeedEngine - - modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module))) - return [engine for engine in modules if isinstance(engine, DeepSpeedEngine)] - - -def _validate_state_keys(state: dict[str, Any]) -> None: - # DeepSpeed merges the client state into its internal engine state when saving, but it does not check for - # colliding keys from the user. We explicitly check it here: - deepspeed_internal_keys = { - "module", - "buffer_names", - "optimizer", - "param_shapes", - "lr_scheduler", - "sparse_tensor_module_names", - "skipped_steps", - "global_steps", - "global_samples", - "dp_world_size", - "mp_world_size", - "ds_config", - "ds_version", - } - colliding_keys = deepspeed_internal_keys.intersection(state.keys()) - if colliding_keys: - rank_zero_warn( - "Your state has keys that collide with DeepSpeed's internal engine state. This could result in your" - " values being overwritten by DeepSpeed. Consider changing the name of these keys to something else: " - + ", ".join(colliding_keys) - ) - - -def _validate_device_index_selection(parallel_devices: list[torch.device]) -> None: - selected_device_indices = [device.index for device in parallel_devices] - expected_device_indices = list(range(len(parallel_devices))) - if selected_device_indices != expected_device_indices: - raise RuntimeError( - f"The selected device indices {selected_device_indices!r} don't match the local rank values of processes." - " If you need to select GPUs at a specific index, set the `CUDA_VISIBLE_DEVICES` environment variable" - f" instead. 
For example: `CUDA_VISIBLE_DEVICES={','.join(str(i) for i in selected_device_indices)}`." - ) + return self.deepspeed_impl._setup_distributed() + @property + def config(self) -> dict[str, Any]: + return self.deepspeed_impl.config -def _is_deepspeed_checkpoint(path: Path) -> bool: - """Heuristic check whether the path points to a top-level DeepSpeed checkpoint directory.""" - return path.is_dir() and (path / "checkpoint").is_dir() - - -def _validate_checkpoint_directory(path: _PATH) -> None: - """Validates that the path points to a DeepSpeed checkpoint directory and suggests fixes for user error.""" - # Example DeepSpeed checkpoint directory: - # - # epoch=5-step=10999.ckpt - # ├── checkpoint - # │ ├── zero_pp_rank_0_mp_rank_00_model_states.pt - # │ ├── zero_pp_rank_0_mp_rank_00_optim_states.pt - # │ ├── zero_pp_rank_1_mp_rank_00_model_states.pt - # │ └── zero_pp_rank_1_mp_rank_00_optim_states.pt - # ├── latest - # └── zero_to_fp32.py - - path = Path(path) - path_is_ds_checkpoint = _is_deepspeed_checkpoint(path) - default_message = f"The provided path is not a valid DeepSpeed checkpoint: {path}" - - if not path_is_ds_checkpoint: - # Case 1: User may have accidentally passed the subfolder "checkpoint" - parent_is_ds_checkpoint = _is_deepspeed_checkpoint(path.parent) - if parent_is_ds_checkpoint: - raise FileNotFoundError( - f"{default_message}. It looks like you passed the path to a subfolder." - f" Try to load using this parent directory instead: {path.parent}" - ) - # Case 2: User may have accidentally passed the path to a file inside the "checkpoint" subfolder - parent_parent_is_ds_checkpoint = path.is_file() and _is_deepspeed_checkpoint(path.parent.parent) - if parent_parent_is_ds_checkpoint: - raise FileNotFoundError( - f"{default_message}. It looks like you passed the path to a file inside a DeepSpeed checkpoint folder." - f" Try to load using this parent directory instead: {path.parent.parent}" - ) - raise FileNotFoundError(default_message) - - -def _format_precision_config( - config: dict[str, Any], - precision: str, - loss_scale: float, - loss_scale_window: int, - min_loss_scale: int, - initial_scale_power: int, - hysteresis: int, -) -> None: - if "fp16" not in config and precision in ("16-mixed", "16-true"): - # FP16 is a DeepSpeed standalone AMP implementation - rank_zero_info("Enabling DeepSpeed FP16. Model parameters and inputs will be cast to `float16`.") - config["fp16"] = { - "enabled": True, - "loss_scale": loss_scale, - "initial_scale_power": initial_scale_power, - "loss_scale_window": loss_scale_window, - "hysteresis": hysteresis, - "min_loss_scale": min_loss_scale, - } - elif "bf16" not in config and precision in ("bf16-mixed", "bf16-true"): - rank_zero_info("Enabling DeepSpeed BF16. Model parameters and inputs will be cast to `bfloat16`.") - config["bf16"] = {"enabled": True} + @config.setter + def config(self, config: dict[str, Any]) -> None: + self.deepspeed_impl.config = config diff --git a/src/lightning/fabric/strategies/launchers/xla.py b/src/lightning/fabric/strategies/launchers/xla.py index 639de55805646..566312754339b 100644 --- a/src/lightning/fabric/strategies/launchers/xla.py +++ b/src/lightning/fabric/strategies/launchers/xla.py @@ -11,17 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
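# A worked example of what the `_format_precision_config` helper removed above produced.
# With a 16-bit precision setting it added DeepSpeed's standalone fp16 AMP block built from the
# loss-scale arguments; with bf16 it only enabled `bf16`. The numeric values below are the
# customary dynamic-loss-scaling defaults and are shown for illustration only.
def format_precision_config_example(precision: str) -> dict:
    config: dict = {"train_micro_batch_size_per_gpu": 1}
    if precision in ("16-mixed", "16-true"):
        config["fp16"] = {
            "enabled": True,
            "loss_scale": 0,  # 0 selects dynamic loss scaling
            "initial_scale_power": 16,
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "min_loss_scale": 1,
        }
    elif precision in ("bf16-mixed", "bf16-true"):
        config["bf16"] = {"enabled": True}
    return config


assert "fp16" in format_precision_config_example("16-mixed")
assert format_precision_config_example("bf16-true")["bf16"] == {"enabled": True}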
-import queue -import time -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Union -import torch.multiprocessing as mp from typing_extensions import override -from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.strategies.launchers.launcher import _Launcher -from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot -from lightning.fabric.utilities.apply_func import move_data_to_device +from lightning.fabric.utilities.imports import _raise_enterprise_not_available if TYPE_CHECKING: from lightning.fabric.strategies import XLAFSDPStrategy, XLAStrategy @@ -45,15 +40,16 @@ class _XLALauncher(_Launcher): """ def __init__(self, strategy: Union["XLAStrategy", "XLAFSDPStrategy"]) -> None: - if not _XLA_AVAILABLE: - raise ModuleNotFoundError(str(_XLA_AVAILABLE)) - self._strategy = strategy - self._start_method = "fork" + super().__init__() + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.strategies.xla.launcher import XLALauncherFabric as EnterpriseXLALauncher + + self.xla_impl = EnterpriseXLALauncher(strategy=strategy) @property @override def is_interactive_compatible(self) -> bool: - return True + return self.xla_impl.is_interactive_compatible @override def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: @@ -68,61 +64,12 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: **kwargs: Optional keyword arguments to be passed to the given function. """ - return_queue: Union[queue.Queue, mp.SimpleQueue] - return_queue = mp.Manager().Queue() - - import torch_xla.distributed.xla_multiprocessing as xmp - - spawn_kwargs = {} - nprocs = self._strategy.num_processes - if nprocs == 1: - # avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly. 
- # otherwise it will use all devices - spawn_kwargs["nprocs"] = nprocs - - xmp.spawn( - self._wrapping_function, - args=(function, args, kwargs, return_queue), - start_method=self._start_method, - **spawn_kwargs, - ) - return return_queue.get() - - def _wrapping_function( - self, - # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing - # https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/distributed/xla_multiprocessing.py#L321 - process_idx: int, - function: Callable, - args: Any, - kwargs: Any, - return_queue: Union[mp.SimpleQueue, queue.Queue], - global_states: Optional[_GlobalStateSnapshot] = None, - ) -> None: - import torch_xla.core.xla_model as xm - - if len(xm.get_xla_supported_devices()) > 1: - # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for v4) - # so when there's more than one (multithreading), objects need to be deep-copied - import copy - - function, args, kwargs = copy.deepcopy((function, args, kwargs)) - - results = function(*args, **kwargs) - - if self._strategy.local_rank == 0: - return_queue.put(move_data_to_device(results, "cpu")) - - _rank_teardown(self._strategy.local_rank) - - -def _rank_teardown(rank: int) -> None: - import torch_xla.core.xla_model as xm - - # Make all processes wait for each other before joining - # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 - xm.rendezvous("end-process") - # Ensure that the rank 0 process is the one exiting last - # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 - if rank == 0: - time.sleep(1) + return self.xla_impl.launch(function=function, *args, **kwargs) + + @property + def _start_method(self) -> str: + return self.xla_impl._start_method + + @_start_method.setter + def _start_method(self, start_method: str) -> None: + self.xla_impl._start_method = start_method diff --git a/src/lightning/fabric/strategies/single_xla.py b/src/lightning/fabric/strategies/single_xla.py index ba2fce91f1146..59b38ff6157f6 100644 --- a/src/lightning/fabric/strategies/single_xla.py +++ b/src/lightning/fabric/strategies/single_xla.py @@ -13,15 +13,14 @@ # limitations under the License. 
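# The launcher hunk above forwards `_start_method` through a property pair rather than storing it
# on the wrapper, and the same idiom is applied to `precision` and `_desired_dtype` elsewhere in
# this patch. A minimal sketch of why: a plain attribute on the wrapper would drift out of sync
# with the object doing the work, whereas a property keeps a single source of truth.
class _Impl:
    def __init__(self) -> None:
        self._start_method = "fork"


class Wrapper:
    def __init__(self) -> None:
        self.impl = _Impl()

    @property
    def _start_method(self) -> str:
        return self.impl._start_method

    @_start_method.setter
    def _start_method(self, value: str) -> None:
        self.impl._start_method = value


w = Wrapper()
w._start_method = "spawn"               # callers keep assigning on the wrapper...
assert w.impl._start_method == "spawn"  # ...and the backend observes the change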
from typing import Optional -import torch from typing_extensions import override from lightning.fabric.accelerators import Accelerator -from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision from lightning.fabric.plugins.io.xla import XLACheckpointIO from lightning.fabric.strategies import _StrategyRegistry from lightning.fabric.strategies.single_device import SingleDeviceStrategy +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import _DEVICE @@ -35,20 +34,16 @@ def __init__( checkpoint_io: Optional[XLACheckpointIO] = None, precision: Optional[XLAPrecision] = None, ): - if not _XLA_AVAILABLE: - raise ModuleNotFoundError(str(_XLA_AVAILABLE)) - if isinstance(device, torch.device): - # unwrap the `torch.device` in favor of `xla_device` - device = device.index - - import torch_xla.core.xla_model as xm + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.strategies.xla.single import validate_xla_strategy super().__init__( accelerator=accelerator, - device=xm.xla_device(device), + device=device, checkpoint_io=checkpoint_io, precision=precision, ) + validate_xla_strategy(strategy=self, device=device) @property @override diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py index 3a571fef37f00..5a877a5c0dcc8 100644 --- a/src/lightning/fabric/strategies/xla.py +++ b/src/lightning/fabric/strategies/xla.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch @@ -22,14 +21,13 @@ from typing_extensions import override from lightning.fabric.accelerators import Accelerator -from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1 from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision from lightning.fabric.plugins.environments import XLAEnvironment from lightning.fabric.plugins.io.xla import XLACheckpointIO from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry from lightning.fabric.strategies.launchers.xla import _XLALauncher from lightning.fabric.strategies.strategy import TBroadcast -from lightning.fabric.utilities.rank_zero import rank_zero_only +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import _PATH, ReduceOp if TYPE_CHECKING: @@ -55,22 +53,19 @@ def __init__( checkpoint_io=checkpoint_io, precision=precision, ) - self._backward_sync_control = None # XLA synchronizes gradients in the optimizer.step() call - self._launched = False - self._sync_module_states = sync_module_states + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.strategies.xla.ddp import XLAStrategyFabric as EnterpriseXLAStrategy + + self.xla_strategy_impl = EnterpriseXLAStrategy(outer_object=self, sync_module_states=sync_module_states) @property @override def root_device(self) -> torch.device: - if not self._launched: - raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.") - import torch_xla.core.xla_model as xm - - return xm.xla_device() + return self.xla_strategy_impl.root_device @property def num_processes(self) -> int: - return len(self.parallel_devices) if self.parallel_devices is not None else 0 + return 
self.xla_strategy_impl.num_processes @property @override @@ -107,22 +102,22 @@ def precision(self, precision: Optional[Precision]) -> None: @property @override def global_rank(self) -> int: - return super().global_rank if self._launched else 0 + return self.xla_strategy_impl.global_rank @property @override def local_rank(self) -> int: - return super().local_rank if self._launched else 0 + return self.xla_strategy_impl.local_rank @property @override def node_rank(self) -> int: - return super().node_rank if self._launched else 0 + return self.xla_strategy_impl.node_rank @property @override def world_size(self) -> int: - return super().world_size if self._launched else 1 + return self.xla_strategy_impl.world_size @override def _configure_launcher(self) -> None: @@ -130,48 +125,19 @@ def _configure_launcher(self) -> None: @override def setup_environment(self) -> None: - assert self.parallel_devices is not None - if len(self.parallel_devices) == 1: - # spawning only 1 device with PjRT is not supported: - # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732 - raise NotImplementedError( - f"The {type(self).__name__} does not support running on a single device with the PjRT runtime." - " Try using all devices or the `SingleDeviceXLAStrategy` strategy" - ) - - self._launched = True - rank_zero_only.rank = self.global_rank - super().setup_environment() + return self.xla_strategy_impl.setup_environment() @override def setup_module(self, module: Module) -> Module: - if self._sync_module_states: - if _XLA_GREATER_EQUAL_2_1: - from torch_xla.core.xla_model import broadcast_master_param - else: - from torch_xla.experimental.pjrt import broadcast_master_param - - broadcast_master_param(module) - - return module + return self.xla_strategy_impl.setup_module(module=module) @override def module_to_device(self, module: Module) -> None: - module.to(self.root_device) + return self.xla_strategy_impl.module_to_device(module=module) @override def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader": - from torch_xla.distributed.parallel_loader import MpDeviceLoader - - if isinstance(dataloader, MpDeviceLoader): - # dataloader is already wrapped by MpDeviceLoader - return dataloader - - dataloader = MpDeviceLoader(dataloader, self.root_device) - # Mimic interface to torch.utils.data.DataLoader - dataloader.dataset = dataloader._loader.dataset - dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None) - return dataloader + return self.xla_strategy_impl.process_dataloader(dataloader=dataloader) @override def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: @@ -185,92 +151,21 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo A tensor of shape (world_size, ...) """ - if not self._launched: - return tensor - if not isinstance(tensor, Tensor): - raise NotImplementedError( - f"`{type(self).__name__}.all_gather` is only implemented for tensors. 
Given {tensor}" - ) - if tensor.dim() == 0: - tensor = tensor.unsqueeze(0) - original_device = tensor.device - tensor = tensor.to(self.root_device) - - import torch_xla.core.functions as xf - import torch_xla.core.xla_model as xm - - tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor) - tensor = tensor.to(original_device) - return tensor + return self.xla_strategy_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads) @override def all_reduce( self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None ) -> Tensor: - if not isinstance(output, Tensor): - output = torch.tensor(output, device=self.root_device) - - invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM - invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") - if invalid_reduce_op or invalid_reduce_op_str: - raise ValueError( - "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:" - f" {reduce_op}" - ) - import torch_xla.core.xla_model as xm - - output = xm.mesh_reduce("reduce", output, sum) - - if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): - output = output / self.world_size - - return output + return self.xla_strategy_impl.all_reduce(output=output, group=group, reduce_op=reduce_op) @override def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: - if not self._launched: - return - import torch_xla.core.xla_model as xm - - if name is None: - # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments" - name = "" - xm.rendezvous(name) + return self.xla_strategy_impl.barrier(name=name, *args, **kwargs) @override def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - if not self._launched: - return obj - - import torch_xla.core.xla_model as xm - - is_tensor = isinstance(obj, Tensor) - if is_tensor: - if obj.dim() == 0: - obj = obj.unsqueeze(0) - original_device = obj.device - # XLA distributed requires that the data is on the XLA device - obj = obj.to(self.root_device) - else: - # support for arbitrary pickle-ables - buffer = io.BytesIO() - torch.save(obj, buffer) - obj = torch.tensor( # type: ignore[assignment] - bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float - ) - - obj = [obj] - xm.collective_broadcast(obj, root_ordinal=src) - obj = obj[0] - - if not is_tensor: - # this will preserve the dtype and device of any tensors - buffer = io.BytesIO(obj.cpu().byte().numpy()) - obj = torch.load(buffer) - else: - obj = obj.to(original_device) - - return obj + return self.xla_strategy_impl.broadcast(obj=obj, src=src) @override def save_checkpoint( @@ -291,12 +186,9 @@ def save_checkpoint( boolean indicating whether the given parameter should be saved (``True``) or filtered out (``False``). 
""" - import torch_xla.core.xla_model as xm - - # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs - xm.mark_step() - # save on global rank zero only - super().save_checkpoint(path, state, storage_options=storage_options, filter=filter) + return self.xla_strategy_impl.save_checkpoint( + path=path, state=state, storage_options=storage_options, filter=filter + ) @classmethod @override diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index 51b528eff26ff..f54267ca59a72 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -11,10 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io -from contextlib import AbstractContextManager, ExitStack, nullcontext -from functools import partial -from pathlib import Path +from contextlib import AbstractContextManager from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union import torch @@ -25,22 +22,16 @@ from typing_extensions import override from lightning.fabric.accelerators import Accelerator -from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins import CheckpointIO, Precision, XLAPrecision from lightning.fabric.plugins.environments import XLAEnvironment from lightning.fabric.plugins.io.xla import XLACheckpointIO from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry -from lightning.fabric.strategies.fsdp import _apply_filter from lightning.fabric.strategies.launchers.xla import _XLALauncher from lightning.fabric.strategies.strategy import ( TBroadcast, - _BackwardSyncControl, _Sharded, - _validate_keys_for_strict_loading, ) -from lightning.fabric.utilities.cloud_io import get_filesystem -from lightning.fabric.utilities.init import _EmptyInit -from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import _PATH, Optimizable, ReduceOp if TYPE_CHECKING: @@ -93,8 +84,6 @@ def __init__( sequential_save: bool = False, **kwargs: Any, ) -> None: - if not _XLA_AVAILABLE: - raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, @@ -102,27 +91,28 @@ def __init__( checkpoint_io=checkpoint_io, precision=precision, ) - self._backward_sync_control = _XLAFSDPBackwardSyncControl() + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.strategies.xla.fsdp import ( + XLAFSDPStrategyFabric as EnterpriseXLAFSDPStrategy, + ) - self._auto_wrap_policy = auto_wrap_policy - self._activation_checkpointing_policy = activation_checkpointing_policy - self._fsdp_kwargs = kwargs - self._state_dict_type = state_dict_type - self._sequential_save = sequential_save - self._launched = False + self.xla_fsdp_impl = EnterpriseXLAFSDPStrategy( + outer_object=self, + auto_wrap_policy=auto_wrap_policy, + activation_checkpointing_policy=activation_checkpointing_policy, + state_dict_type=state_dict_type, + sequential_save=sequential_save, + **kwargs, + ) @property @override def root_device(self) -> torch.device: - if not self._launched: - raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.") - import torch_xla.core.xla_model as xm - - return xm.xla_device() + return 
self.xla_fsdp_impl.root_device @property def num_processes(self) -> int: - return len(self.parallel_devices) if self.parallel_devices is not None else 0 + return self.xla_fsdp_impl.num_processes @property @override @@ -159,22 +149,22 @@ def precision(self, precision: Optional[Precision]) -> None: @property @override def global_rank(self) -> int: - return super().global_rank if self._launched else 0 + return self.xla_fsdp_impl.global_rank @property @override def local_rank(self) -> int: - return super().local_rank if self._launched else 0 + return self.xla_fsdp_impl.local_rank @property @override def node_rank(self) -> int: - return super().node_rank if self._launched else 0 + return self.xla_fsdp_impl.node_rank @property @override def world_size(self) -> int: - return super().world_size if self._launched else 1 + return self.xla_fsdp_impl.world_size @override def _configure_launcher(self) -> None: @@ -182,108 +172,40 @@ def _configure_launcher(self) -> None: @override def setup_environment(self) -> None: - assert self.parallel_devices is not None - if len(self.parallel_devices) == 1: - # spawning only 1 device with PjRT is not supported: - # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732 - raise NotImplementedError( - f"The {type(self).__name__} does not support running on a single device with the PjRT runtime." - " Try using all devices or the `SingleDeviceXLAStrategy` strategy" - ) - - self._launched = True - rank_zero_only.rank = self.global_rank - super().setup_environment() + return self.xla_fsdp_impl.setup_environment() @override def setup_module_and_optimizers( self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None ) -> tuple[Module, list[Optimizer], Optional["_LRScheduler"]]: - """Returns NotImplementedError since for XLAFSDP optimizer setup must happen after module setup.""" - raise NotImplementedError( - f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)." - " Please do it in this order: Create the model, call `setup_module`, create the optimizer," - " call `setup_optimizer`." - ) + return self.xla_fsdp_impl.setup_module_and_optimizers(module=module, optimizers=optimizers, scheduler=scheduler) @override def setup_module(self, module: Module) -> Module: - from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP - - kwargs = self._parse_fsdp_kwargs() - if any(isinstance(mod, XLAFSDP) for mod in module.modules()) and "auto_wrap_policy" in kwargs: - rank_zero_warn( - "A XLAFSDP `auto_wrap_policy` is set, but at least one submodule is already wrapped." - " The policy will be ignored." 
- ) - del kwargs["auto_wrap_policy"] - # XLA FSDP requires that the root is wrapped, even if submodules are already wrapped - if not isinstance(module, XLAFSDP): - module = XLAFSDP(module=module, **kwargs) - return module + return self.xla_fsdp_impl.setup_module(module=module) @override def module_to_device(self, module: Module) -> None: - pass + return self.xla_fsdp_impl.module_to_device(module=module) def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: - precision_init_ctx = self.precision.module_init_context() - module_sharded_ctx = self.module_sharded_context() - stack = ExitStack() - stack.enter_context(_EmptyInit(enabled=bool(empty_init))) - stack.enter_context(precision_init_ctx) - stack.enter_context(module_sharded_ctx) - return stack + return self.xla_fsdp_impl.module_init_context(empty_init=empty_init) @override def module_sharded_context(self) -> AbstractContextManager: - return nullcontext() + return self.xla_fsdp_impl.module_sharded_context() @override def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader": - from torch_xla.distributed.parallel_loader import MpDeviceLoader - - if isinstance(dataloader, MpDeviceLoader): - # dataloader is already wrapped by MpDeviceLoader - return dataloader - - dataloader = MpDeviceLoader(dataloader, self.root_device) - # Mimic interface to torch.utils.data.DataLoader - dataloader.dataset = dataloader._loader.dataset - dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None) - return dataloader + return self.xla_fsdp_impl.process_dataloader(dataloader=dataloader) @override def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: - """Set up an optimizer for a model wrapped with XLAFSDP. - - This setup method doesn't modify the optimizer or wrap the optimizer. The only thing it currently does is verify - that the optimizer was created after the model was wrapped with :meth:`setup_module` with a reference to the - flattened parameters. - - """ - if any(getattr(p, "_is_sharded", False) for group in optimizer.param_groups for p in group["params"]): - return optimizer - raise ValueError( - "The optimizer does not seem to reference any XLAFSDP parameters. HINT: Make sure to create the optimizer" - " after setting up the model." - ) + return self.xla_fsdp_impl.setup_optimizer(optimizer=optimizer) @override def optimizer_step(self, optimizer: Optimizable, **kwargs: Any) -> Any: - """Overrides default tpu optimizer_step since FSDP should not call `torch_xla.core.xla_model.optimizer_step`. - Performs the actual optimizer step. 
- - Args: - optimizer: the optimizer performing the step - **kwargs: Any extra arguments to ``optimizer.step`` - - """ - loss = optimizer.step(**kwargs) - import torch_xla.core.xla_model as xm - - xm.mark_step() - return loss + return self.xla_fsdp_impl.optimizer_step(optimizer=optimizer, **kwargs) @override def clip_gradients_norm( @@ -295,17 +217,18 @@ def clip_gradients_norm( error_if_nonfinite: bool = True, ) -> Tensor: """Clip gradients by norm.""" - self.precision.unscale_gradients(optimizer) - assert callable(module.clip_grad_norm_) - return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type) + return self.xla_fsdp_impl.clip_gradients_norm( + module=module, + optimizer=optimizer, + max_norm=max_norm, + norm_type=norm_type, + error_if_nonfinite=error_if_nonfinite, + ) @override def clip_gradients_value(self, module: Module, optimizer: Optimizer, clip_val: Union[float, int]) -> None: """Clip gradients by value.""" - raise NotImplementedError( - "XLA's FSDP strategy does not support to clip gradients by value." - " Consider clipping by norm instead or choose another strategy!" - ) + return self.xla_fsdp_impl.clip_gradients_value(module=module, optimizer=optimizer, clip_val=clip_val) @override def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: @@ -319,92 +242,21 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo A tensor of shape (world_size, ...) """ - if not self._launched: - return tensor - if not isinstance(tensor, Tensor): - raise NotImplementedError( - f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}" - ) - if tensor.dim() == 0: - tensor = tensor.unsqueeze(0) - original_device = tensor.device - tensor = tensor.to(self.root_device) - - import torch_xla.core.functions as xf - import torch_xla.core.xla_model as xm - - tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor) - tensor = tensor.to(original_device) - return tensor + return self.xla_fsdp_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads) @override def all_reduce( self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None ) -> Tensor: - if not isinstance(output, Tensor): - output = torch.tensor(output, device=self.root_device) - - invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM - invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") - if invalid_reduce_op or invalid_reduce_op_str: - raise ValueError( - "Currently, the XLAFSDPStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:" - f" {reduce_op}" - ) - import torch_xla.core.xla_model as xm - - output = xm.mesh_reduce("reduce", output, sum) - - if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): - output = output / self.world_size - - return output + return self.xla_fsdp_impl.all_reduce(output=output, group=group, reduce_op=reduce_op) @override def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: - if not self._launched: - return - import torch_xla.core.xla_model as xm - - if name is None: - # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments" - name = "" - xm.rendezvous(name) + return self.xla_fsdp_impl.barrier(name=name, *args, **kwargs) @override def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - if not self._launched: - return obj - - import 
torch_xla.core.xla_model as xm - - is_tensor = isinstance(obj, Tensor) - if is_tensor: - if obj.dim() == 0: - obj = obj.unsqueeze(0) - original_device = obj.device - # XLA distributed requires that the data is on the XLA device - obj = obj.to(self.root_device) - else: - # support for arbitrary pickle-ables - buffer = io.BytesIO() - torch.save(obj, buffer) - obj = torch.tensor( # type: ignore[assignment] - bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float - ) - - obj = [obj] - xm.collective_broadcast(obj, root_ordinal=src) - obj = obj[0] - - if not is_tensor: - # this will preserve the dtype and device of any tensors - buffer = io.BytesIO(obj.cpu().byte().numpy()) - obj = torch.load(buffer) - else: - obj = obj.to(original_device) - - return obj + return self.xla_fsdp_impl.broadcast(obj=obj, src=src) @override def save_checkpoint( @@ -421,93 +273,8 @@ def save_checkpoint( consolidated checkpoint combining all of the sharded checkpoints. """ - # broadcast the path from rank 0 to ensure all the states are saved in a common path - path = Path(self.broadcast(path)) - if path.is_dir() and any(path.iterdir()): - raise FileExistsError(f"The checkpoint directory already exists and is not empty: {path}") - from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP - - modules = [module for module in state.values() if isinstance(module, XLAFSDP)] - if len(modules) == 0: - raise ValueError( - "Could not find a XLAFSDP model in the provided checkpoint state. Please provide the model as" - " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure" - " you set up the model (and optimizers if any) through the strategy before saving the checkpoint." - ) - if len(modules) > 1: - raise ValueError( - "Found multiple XLAFSDP modules in the given state. Saving checkpoints with FSDP is" - " currently limited to a single model per checkpoint. To save multiple models, call the" - " save method for each model separately with a different path." - ) - import torch_xla.core.xla_model as xm - - # ensure model parameters are updated - xm.mark_step() - - parallel_devices = self.parallel_devices - assert parallel_devices is not None - if self._sequential_save: - # each host runs this in parallel, but the ranks in the host run it sequentially - for rank in range(len(parallel_devices)): - if rank == self.local_rank: - self._save_checkpoint_shard(path, state, storage_options, filter) - self.barrier(f"wait-for-{rank}-save") - else: - self._save_checkpoint_shard(path, state, storage_options, filter) - - if self._state_dict_type == "full": - ckpt_prefix = str(path / "checkpoint") - ckpt_suffix = "_rank-*-of-*.pth" - if len(parallel_devices) != self.world_size: # multihost - raise OSError( - "Multihost setups do not have a shared filesystem, so the checkpoint shards cannot be consolidated" - " into a single checkpoint after saving them. Please switch to" - " `XLAFSDPStrategy(state_dict_type='sharded')`. TIP: You can consolidate them manually by getting" - " them together into a single directory and running `python -m" - f" torch_xla.distributed.fsdp.consolidate_sharded_ckpts --ckpt_prefix {ckpt_prefix!r} --ckpt_suffix" - f" {ckpt_suffix!r} --save_path 'path/to/consolidated.ckpt'`." 
- ) - - from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints - - self.barrier("before_ckpt_consolidation") - if self.is_global_zero: - save_path = path.parent / "consolidated.ckpt" - # save consolidated checkpoint separate to the shards - consolidate_sharded_model_checkpoints(ckpt_prefix, ckpt_suffix, str(save_path)) - # remove the shards directory - self.checkpoint_io.remove_checkpoint(path) - # mv the consolidated checkpoint where the user would expect it - get_filesystem(save_path).mv(str(save_path), str(path)) - self.barrier("after_ckpt_consolidation") - - def _save_checkpoint_shard( - self, - path: Path, - state: dict[str, Union[Module, Optimizer, Any]], - storage_options: Optional[Any], - filter: Optional[dict[str, Callable[[str, Any], bool]]], - ) -> None: - from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP - - converted_state: dict[str, Any] = {} - for key, obj in state.items(): - # convert the state - if isinstance(obj, Module) and isinstance(obj, XLAFSDP): - converted = obj.state_dict() - # add shard_metadata to state - converted_state["shard_metadata"] = obj.get_shard_metadata() - elif isinstance(obj, Optimizer): - converted = obj.state_dict() - else: - converted = obj - _apply_filter(key, filter or {}, converted, converted_state) - - self.checkpoint_io.save_checkpoint( - converted_state, - path / f"checkpoint_rank-{self.global_rank:08d}-of-{self.world_size:08d}.pth", - storage_options=storage_options, + return self.xla_fsdp_impl.save_checkpoint( + path=path, state=state, storage_options=storage_options, filter=filter ) @override @@ -524,164 +291,9 @@ def load_checkpoint( directory of multiple files rather than a single file. """ - if not state: - raise ValueError( - f"Got `XLAFSDPStrategy.load_checkpoint(..., state={state!r})` but a state with at least " - " a model instance to reload is required. Pass it in like so:" - " `FSDPStrategy.load_checkpoint(..., state={'model': model, ...})`" - ) - - # broadcast the path from rank 0 to ensure all the states are loaded from a common path - path = Path(self.broadcast(path)) - - if isinstance(state, (Module, Optimizer)): - raise NotImplementedError( - "Loading a single module or optimizer object from a checkpoint" - " is not supported yet with the XLAFSDP strategy." - ) - - from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP - - modules = {key: module for key, module in state.items() if isinstance(module, XLAFSDP)} - optimizers = {key: optim for key, optim in state.items() if isinstance(optim, Optimizer)} - if self._state_dict_type == "sharded": - file = path / f"checkpoint_rank-{self.global_rank:08d}-of-{self.world_size:08d}.pth" - if not file.is_file(): - raise ValueError( - f"The path {str(file)!r} does not point to valid sharded checkpoints. Make sure the path points to" - " a directory with XLAFSDP checkpoint shards." - ) - if len(modules) == 0: - raise ValueError( - "Could not find a XLAFSDP model in the provided checkpoint state. Please provide the model as" - " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure" - " you set up the model (and optimizers if any) through the strategy before loading the checkpoint." - ) - if len(modules) > 1: - raise ValueError( - "Found multiple XLAFSDP modules in the given state. Loading checkpoints with FSDP is" - " currently limited to a single model per checkpoint. To load multiple models, call the" - " load method for each model separately with a different path." 
- ) - - _, module = list(modules.items())[0] - sharded_ckpt = torch.load(file) - - module.load_state_dict(sharded_ckpt["model"], strict=strict) - for opt_key, opt in optimizers.items(): - opt.load_state_dict(sharded_ckpt[opt_key]) - - # Load anything leftover from sharded_ckpt - loaded_metadata_keys = sharded_ckpt.keys() - modules.keys() - optimizers.keys() - requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() - _validate_keys_for_strict_loading(requested_metadata_keys, loaded_metadata_keys, strict=strict) - for key in requested_metadata_keys: - if key in loaded_metadata_keys: - state[key] = sharded_ckpt[key] - loaded_metadata_keys.remove(key) - - metadata = {} - if len(loaded_metadata_keys): - for key in loaded_metadata_keys: - metadata[key] = sharded_ckpt[key] - - # remove "shard_metadata" that is loaded in - if "shard_metadata" in metadata: - metadata.pop("shard_metadata") - - return metadata - - if self._state_dict_type == "full": - if not path.is_file(): - raise ValueError( - f"The path {str(path)!r} does not point to a valid full checkpoint. Make sure the path points to a" - " directory with a full XLAFSDP checkpoint." - ) - if len(optimizers) > 0 or len(state.keys() - modules.keys() - optimizers.keys()) > 0: - rank_zero_warn( - "Loading a full checkpoint will only load the full model." - " The optimizer and any additional metadata are not included." - ) - if len(modules) > 0: - raise ValueError( - "Found a XLAFSDP model in the provided checkpoint state." - " Please provide the model without any XLAFSDP wrapper." - ) - if "model" not in state or not isinstance(model := state["model"], torch.nn.Module): - raise NotImplementedError("XLAFSDP only supports a single model instance with 'model' as the key.") - full_ckpt = torch.load(path, weights_only=weights_only) - model.load_state_dict(full_ckpt.pop("model"), strict=strict) - return full_ckpt - - raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") + return self.xla_fsdp_impl.load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only) @classmethod @override def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: strategy_registry.register("xla_fsdp", cls, description=cls.__name__) - - def _parse_fsdp_kwargs(self) -> dict: - # this needs to be delayed because `self.precision` isn't available at init - kwargs = self._fsdp_kwargs.copy() - precision = self.precision - if isinstance(precision, XLAPrecision): - # the `compute_dtype` will be passed to the `auto_wrapper_callable` automatically, so we don't need to pass - # it when creating it - kwargs.setdefault("compute_dtype", precision._desired_dtype) - kwargs = _auto_wrap_policy_kwargs(self._auto_wrap_policy, kwargs) - return _activation_checkpointing_kwargs(self._activation_checkpointing_policy, kwargs) - - -def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: dict) -> dict: - if policy is None: - return kwargs - if isinstance(policy, set): - from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy - - # this is not transformer specific despite the name - policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=policy) - kwargs["auto_wrap_policy"] = policy - return kwargs - - -def _activation_checkpointing_auto_wrapper(policy: _POLICY_SET, module: Module, *args: Any, **kwargs: Any) -> Module: - from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP - from torch_xla.distributed.fsdp import checkpoint_module - - module = checkpoint_module(module) 
if isinstance(module, tuple(policy)) else module - return XLAFSDP(module, *args, **kwargs) - - -def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: dict) -> dict: - if not policy: - return kwargs - if "auto_wrapper_callable" in kwargs: - raise ValueError( - "You cannot set both `auto_wrapper_callable` and `activation_checkpointing_policy`. Choose one" - ) - if not isinstance(policy, set): - raise TypeError( - f"`activation_checkpointing_policy` must be a set, found {policy}. You can try defining and" - " passing `auto_wrapper_callable` instead." - ) - auto_wrapper_callable = partial(_activation_checkpointing_auto_wrapper, policy) - kwargs["auto_wrapper_callable"] = auto_wrapper_callable - return kwargs - - -class _XLAFSDPBackwardSyncControl(_BackwardSyncControl): - @override - def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: - """Blocks gradient synchronization inside the :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel` - wrapper.""" - if not enabled: - return nullcontext() - from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP - - if not isinstance(module, XLAFSDP): - raise TypeError( - "Blocking backward sync is only possible if the module passed to" - f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `XlaFullyShardedDataParallel`." - f" Got: {module.__class__.__name__}." - ) - return module.no_sync() diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index ac706cc403a5d..dd815d9dc8f5f 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -40,3 +40,22 @@ _TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0") _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) + +_WANDB_AVAILABLE = RequirementCache("wandb>=0.12.10") +_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4") +_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0") +_MLFLOW_SYNCHRONOUS_AVAILABLE = RequirementCache("mlflow>=2.8.0") +_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0") + +_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") +_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0") +_ENTERPRISE_AVAILABLE = RequirementCache("pytorch_lightning_enterprise") +_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0") + + +def _raise_enterprise_not_available() -> None: + if not _ENTERPRISE_AVAILABLE: + raise ModuleNotFoundError( + "pytorch_lightning_enterprise is required to use the XLA accelerator. 
" + "Install it with `pip install pytorch-lightning-enterprise`" + ) diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index d085e4138d742..92521362a2885 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -23,8 +23,7 @@ from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.accelerators.cuda import num_cuda_devices from lightning.fabric.accelerators.mps import MPSAccelerator -from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _DEEPSPEED_AVAILABLE, _TORCH_GREATER_EQUAL_2_4 def _runif_reasons( diff --git a/src/lightning/pytorch/accelerators/xla.py b/src/lightning/pytorch/accelerators/xla.py index 700a29c2e3e7b..c8b322f477d4d 100644 --- a/src/lightning/pytorch/accelerators/xla.py +++ b/src/lightning/pytorch/accelerators/xla.py @@ -39,15 +39,7 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: A dictionary mapping the metrics (free memory and peak memory) to their values. """ - import torch_xla.core.xla_model as xm - - memory_info = xm.get_memory_info(device) - free_memory = memory_info["kb_free"] - peak_memory = memory_info["kb_total"] - free_memory - return { - "avg. free memory (MB)": free_memory, - "avg. peak memory (MB)": peak_memory, - } + return self.accelerator_impl.get_device_stats(device) @staticmethod @override diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index b544212e755e2..5e57e38dc4072 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -17,18 +17,15 @@ """ import logging -import os from argparse import Namespace from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Literal, Optional, Union -from lightning_utilities.core.imports import RequirementCache from torch import Tensor from torch.nn import Module from typing_extensions import override -from lightning.fabric.utilities.logger import _convert_params -from lightning.fabric.utilities.rank_zero import _get_rank +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment from lightning.pytorch.utilities.rank_zero import rank_zero_only @@ -36,7 +33,6 @@ from comet_ml import ExistingExperiment, Experiment, OfflineExperiment log = logging.getLogger(__name__) -_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4", module="comet_ml") FRAMEWORK_NAME = "pytorch-lightning" comet_experiment = Union["Experiment", "ExistingExperiment", "OfflineExperiment"] @@ -208,99 +204,23 @@ def __init__( prefix: Optional[str] = None, **kwargs: Any, ): - if not _COMET_AVAILABLE: - raise ModuleNotFoundError(str(_COMET_AVAILABLE)) + _raise_enterprise_not_available() super().__init__() - ################################################## - # HANDLE PASSED OLD TYPE PARAMS - - # handle old "experiment_name" param - if "experiment_name" in kwargs: - log.warning("The parameter `experiment_name` is deprecated, please use `name` instead.") - experiment_name = kwargs.pop("experiment_name") - - if "name" not in kwargs: - kwargs["name"] = experiment_name - else: - log.warning("You specified both `experiment_name` and `name` parameters, please use `name` only") - - # handle old "project_name" param - if "project_name" in kwargs: - log.warning("The 
parameter `project_name` is deprecated, please use `project` instead.") - if project is None: - project = kwargs.pop("project_name") - else: - log.warning("You specified both `project_name` and `project` parameters, please use `project` only") - - # handle old "offline" experiment flag - if "offline" in kwargs: - log.warning("The parameter `offline is deprecated, please use `online` instead.") - if online is None: - online = kwargs.pop("offline") - else: - log.warning("You specified both `offline` and `online` parameters, please use `online` only") - - # handle old "save_dir" param - if "save_dir" in kwargs: - log.warning("The parameter `save_dir` is deprecated, please use `offline_directory` instead.") - if "offline_directory" not in kwargs: - kwargs["offline_directory"] = kwargs.pop("save_dir") - else: - log.warning( - "You specified both `save_dir` and `offline_directory` parameters, " - "please use `offline_directory` only" - ) - ################################################## - - self._api_key: Optional[str] = api_key - self._experiment: Optional[comet_experiment] = None - self._workspace: Optional[str] = workspace - self._mode: Optional[Literal["get_or_create", "get", "create"]] = mode - self._online: Optional[bool] = online - self._project_name: Optional[str] = project - self._experiment_key: Optional[str] = experiment_key - self._prefix: Optional[str] = prefix - self._kwargs: dict[str, Any] = kwargs - - # needs to be set before the first `comet_ml` import - # because comet_ml imported after another machine learning libraries (Torch) - os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1" - - import comet_ml - - config_kwargs = self._kwargs.copy() - if online is False: - config_kwargs["disabled"] = True - self._comet_config = comet_ml.ExperimentConfig(**config_kwargs) - - # create real experiment only on main node/process (when strategy=auto/ddp) - if _get_rank() is not None and _get_rank() != 0: - return - - self._create_experiment() - - def _create_experiment(self) -> None: - import comet_ml - - self._experiment = comet_ml.start( - api_key=self._api_key, - workspace=self._workspace, - project=self._project_name, - experiment_key=self._experiment_key, - mode=self._mode, - online=self._online, - experiment_config=self._comet_config, + from pytorch_lightning_enterprise.loggers.comet import CometLogger as EnterpriseCometLogger + + self.logger_impl = EnterpriseCometLogger( + api_key=api_key, + workspace=workspace, + project=project, + experiment_key=experiment_key, + mode=mode, + online=online, + prefix=prefix, + **kwargs, ) - if self._experiment is None: - raise comet_ml.exceptions.ExperimentNotFound("Failed to create Comet experiment.") - - self._experiment_key = self._experiment.get_key() - self._project_name = self._experiment.project_name - self._experiment.log_other("Created from", FRAMEWORK_NAME) - @property @rank_zero_experiment def experiment(self) -> comet_experiment: @@ -313,55 +233,24 @@ def experiment(self) -> comet_experiment: """ - # if by some chance there is no experiment created yet (for example, when strategy=ddp_spawn) - # then we will create a new one - if not self._experiment: - self._create_experiment() - - return self._experiment + return self.logger_impl.experiment @override @rank_zero_only def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: - params = _convert_params(params) - self.experiment.__internal_api__log_parameters__( - parameters=params, - framework=FRAMEWORK_NAME, - flatten_nested=True, - source="manual", - ) + return 
self.logger_impl.log_hyperparams(params) @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None: - assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" - # Comet.com expects metrics to be a dictionary of detached tensors on CPU - metrics_without_epoch = metrics.copy() - for key, val in metrics_without_epoch.items(): - if isinstance(val, Tensor): - metrics_without_epoch[key] = val.cpu().detach() - - epoch = metrics_without_epoch.pop("epoch", None) - self.experiment.__internal_api__log_metrics__( - metrics_without_epoch, - step=step, - epoch=epoch, - prefix=self._prefix, - framework=FRAMEWORK_NAME, - ) + return self.logger_impl.log_metrics(metrics, step) @override @rank_zero_only def finalize(self, status: str) -> None: """We will not end experiment (will not call self._experiment.end()) here to have an ability to continue using it after training is complete but instead of ending we will upload/save all the data.""" - if self._experiment is None: - # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been - # initialized there - return - - # just save the data - self.experiment.flush() + return self.logger_impl.finalize(status) @property @override @@ -372,7 +261,7 @@ def save_dir(self) -> Optional[str]: The path to the save directory. """ - return self._comet_config.offline_directory + return self.logger_impl.save_dir @property @override @@ -383,7 +272,7 @@ def name(self) -> Optional[str]: The project name if it is specified. """ - return self._project_name + return self.logger_impl.name @property @override @@ -395,27 +284,8 @@ def version(self) -> Optional[str]: """ # Don't create an experiment if we don't have one - if self._experiment is not None: - return self._experiment.get_key() - - def __getstate__(self) -> dict[str, Any]: - state = self.__dict__.copy() - - # Save the experiment id in case an experiment object already exists, - # this way we could create an ExistingExperiment pointing to the same - # experiment - state["_experiment_key"] = self._experiment.get_key() if self._experiment is not None else None - - # Remove the experiment object as it contains hard to pickle objects - # (like network connections), the experiment object will be recreated if - # needed later - state["_experiment"] = None - return state + return self.logger_impl.version @override def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: - if self._experiment is not None: - self._experiment.__internal_api__set_model_graph__( - graph=model, - framework=FRAMEWORK_NAME, - ) + return self.logger_impl.log_graph(model, input_array) diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index ff9b2b0d7e542..725fd2cb8bfbb 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -16,35 +16,21 @@ ------------- """ -import logging import os -import re -import tempfile from argparse import Namespace from collections.abc import Mapping -from pathlib import Path -from time import time -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union -import yaml -from lightning_utilities.core.imports import RequirementCache -from torch import Tensor from typing_extensions import override -from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict +from 
lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment -from lightning.pytorch.loggers.utilities import _scan_checkpoints -from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn +from lightning.pytorch.utilities.rank_zero import rank_zero_only if TYPE_CHECKING: from mlflow.tracking import MlflowClient -log = logging.getLogger(__name__) -LOCAL_FILE_URI_PREFIX = "file:" -_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow") -_MLFLOW_SYNCHRONOUS_AVAILABLE = RequirementCache("mlflow>=2.8.0", "mlflow") - class MLFlowLogger(Logger): """Log using `MLflow `_. @@ -126,31 +112,22 @@ def __init__( run_id: Optional[str] = None, synchronous: Optional[bool] = None, ): - if not _MLFLOW_AVAILABLE: - raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE)) - if synchronous is not None and not _MLFLOW_SYNCHRONOUS_AVAILABLE: - raise ModuleNotFoundError("`synchronous` requires mlflow>=2.8.0") + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.loggers.mlflow import MLFlowLogger as EnterpriseMLFlowLogger + super().__init__() - if not tracking_uri: - tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}" - - self._experiment_name = experiment_name - self._experiment_id: Optional[str] = None - self._tracking_uri = tracking_uri - self._run_name = run_name - self._run_id = run_id - self.tags = tags - self._log_model = log_model - self._logged_model_time: dict[str, float] = {} - self._checkpoint_callback: Optional[ModelCheckpoint] = None - self._prefix = prefix - self._artifact_location = artifact_location - self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous} - self._initialized = False - - from mlflow.tracking import MlflowClient - - self._mlflow_client = MlflowClient(tracking_uri) + self.logger_impl = EnterpriseMLFlowLogger( + experiment_name=experiment_name, + run_name=run_name, + tracking_uri=tracking_uri, + tags=tags, + save_dir=save_dir, + log_model=log_model, + prefix=prefix, + artifact_location=artifact_location, + run_id=run_id, + synchronous=synchronous, + ) @property @rank_zero_experiment @@ -163,46 +140,7 @@ def experiment(self) -> "MlflowClient": self.logger.experiment.some_mlflow_function() """ - import mlflow - - if self._initialized: - return self._mlflow_client - - mlflow.set_tracking_uri(self._tracking_uri) - - if self._run_id is not None: - run = self._mlflow_client.get_run(self._run_id) - self._experiment_id = run.info.experiment_id - self._initialized = True - return self._mlflow_client - - if self._experiment_id is None: - expt = self._mlflow_client.get_experiment_by_name(self._experiment_name) - if expt is not None and expt.lifecycle_stage != "deleted": - self._experiment_id = expt.experiment_id - else: - log.warning(f"Experiment with name {self._experiment_name} not found. Creating it.") - self._experiment_id = self._mlflow_client.create_experiment( - name=self._experiment_name, artifact_location=self._artifact_location - ) - - if self._run_id is None: - if self._run_name is not None: - self.tags = self.tags or {} - - from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME - - if MLFLOW_RUN_NAME in self.tags: - log.warning( - f"The tag {MLFLOW_RUN_NAME} is found in tags. The value will be overridden by {self._run_name}." 
- ) - self.tags[MLFLOW_RUN_NAME] = self._run_name - - resolve_tags = _get_resolve_tags() - run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags)) - self._run_id = run.info.run_id - self._initialized = True - return self._mlflow_client + return self.logger_impl.experiment @property def run_id(self) -> Optional[str]: @@ -212,8 +150,7 @@ def run_id(self) -> Optional[str]: The run id. """ - _ = self.experiment - return self._run_id + return self.logger_impl.run_id @property def experiment_id(self) -> Optional[str]: @@ -223,71 +160,22 @@ def experiment_id(self) -> Optional[str]: The experiment id. """ - _ = self.experiment - return self._experiment_id + return self.logger_impl.experiment_id @override @rank_zero_only def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: - params = _convert_params(params) - params = _flatten_dict(params) - - from mlflow.entities import Param - - # Truncate parameter values to 250 characters. - # TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 - params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()] - - # Log in chunks of 100 parameters (the maximum allowed by MLflow). - for idx in range(0, len(params_list), 100): - self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100], **self._log_batch_kwargs) + return self.logger_impl.log_hyperparams(params) @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: - assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" - - from mlflow.entities import Metric - - metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) - metrics_list: list[Metric] = [] - - timestamp_ms = int(time() * 1000) - for k, v in metrics.items(): - if isinstance(v, str): - log.warning(f"Discarding metric with string value {k}={v}.") - continue - - new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k) - if k != new_k: - rank_zero_warn( - "MLFlow only allows '_', '/', '.' and ' ' special characters in metric name." - f" Replacing {k} with {new_k}.", - category=RuntimeWarning, - ) - k = new_k - metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0)) - - self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list, **self._log_batch_kwargs) + return self.logger_impl.log_metrics(metrics, step) @override @rank_zero_only def finalize(self, status: str = "success") -> None: - if not self._initialized: - return - if status == "success": - status = "FINISHED" - elif status == "failed": - status = "FAILED" - elif status == "finished": - status = "FINISHED" - - # log checkpoints as artifacts - if self._checkpoint_callback: - self._scan_and_log_checkpoints(self._checkpoint_callback) - - if self.experiment.get_run(self.run_id): - self.experiment.set_terminated(self.run_id, status) + return self.logger_impl.finalize(status) @property @override @@ -299,9 +187,7 @@ def save_dir(self) -> Optional[str]: Otherwise returns `None`. """ - if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX): - return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :] - return None + return self.logger_impl.save_dir @property @override @@ -312,7 +198,7 @@ def name(self) -> Optional[str]: The experiment id. """ - return self.experiment_id + return self.logger_impl.name @property @override @@ -323,76 +209,8 @@ def version(self) -> Optional[str]: The run id. 
""" - return self.run_id + return self.logger_impl.version @override def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: - # log checkpoints as artifacts - if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: - self._scan_and_log_checkpoints(checkpoint_callback) - elif self._log_model is True: - self._checkpoint_callback = checkpoint_callback - - def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: - # get checkpoints to be saved with associated score - checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time) - - # log iteratively all new checkpoints - for t, p, s, tag in checkpoints: - metadata = { - # Ensure .item() is called to store Tensor contents - "score": s.item() if isinstance(s, Tensor) else s, - "original_filename": Path(p).name, - "Checkpoint": { - k: getattr(checkpoint_callback, k) - for k in [ - "monitor", - "mode", - "save_last", - "save_top_k", - "save_weights_only", - "_every_n_train_steps", - "_every_n_val_epochs", - ] - # ensure it does not break if `Checkpoint` args change - if hasattr(checkpoint_callback, k) - }, - } - aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] - - # Artifact path on mlflow - artifact_path = Path(p).stem - - # Log the checkpoint - self.experiment.log_artifact(self._run_id, p, artifact_path) - - # Create a temporary directory to log on mlflow - with tempfile.TemporaryDirectory() as tmp_dir: - # Log the metadata - with open(f"{tmp_dir}/metadata.yaml", "w") as tmp_file_metadata: - yaml.dump(metadata, tmp_file_metadata, default_flow_style=False) - - # Log the aliases - with open(f"{tmp_dir}/aliases.txt", "w") as tmp_file_aliases: - tmp_file_aliases.write(str(aliases)) - - # Log the metadata and aliases - self.experiment.log_artifacts(self._run_id, tmp_dir, artifact_path) - - # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) - self._logged_model_time[p] = t - - -def _get_resolve_tags() -> Callable: - from mlflow.tracking import context - - # before v1.1.0 - if hasattr(context, "resolve_tags"): - from mlflow.tracking.context import resolve_tags - # since v1.1.0 - elif hasattr(context, "registry"): - from mlflow.tracking.context.registry import resolve_tags - else: - resolve_tags = lambda tags: tags - - return resolve_tags + return self.logger_impl.after_save_checkpoint(checkpoint_callback) diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index bf9669c824784..c2dd2a9b5a142 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -16,23 +16,17 @@ -------------- """ -import contextlib import logging -import os from argparse import Namespace -from collections.abc import Generator -from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union -from lightning_utilities.core.imports import RequirementCache from torch import Tensor from typing_extensions import override import lightning.pytorch as pl -from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.pytorch.callbacks import Checkpoint from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment -from lightning.pytorch.utilities.model_summary import ModelSummary 
from lightning.pytorch.utilities.rank_zero import rank_zero_only if TYPE_CHECKING: @@ -41,27 +35,6 @@ log = logging.getLogger(__name__) -# Neptune is available with two names on PyPI : `neptune` and `neptune-client` -# `neptune` was introduced as a name transition of neptune-client and the long-term target is to get -# rid of Neptune-client package completely someday. It was introduced as a part of breaking-changes with a release -# of neptune-client==1.0. neptune-client>=1.0 is just an alias of neptune package and have some breaking-changes -# in compare to neptune-client<1.0.0. -_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0") -_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning" - - -# Neptune client throws `InactiveRunException` when trying to log to an inactive run. -# This may happen when the run was stopped through the UI and the logger is still trying to log to it. -def _catch_inactive(func: Callable) -> Callable: - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - from neptune.exceptions import InactiveRunException - - with contextlib.suppress(InactiveRunException): - return func(*args, **kwargs) - - return wrapper - class NeptuneLogger(Logger): r"""Log using `Neptune `_. @@ -242,113 +215,19 @@ def __init__( prefix: str = "training", **neptune_run_kwargs: Any, ): - if not _NEPTUNE_AVAILABLE: - raise ModuleNotFoundError(str(_NEPTUNE_AVAILABLE)) - - # verify if user passed proper init arguments - self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs) + _raise_enterprise_not_available() super().__init__() - self._log_model_checkpoints = log_model_checkpoints - self._prefix = prefix - self._run_name = name - self._project_name = project - self._api_key = api_key - self._run_instance = run - self._neptune_run_kwargs = neptune_run_kwargs - self._run_short_id: Optional[str] = None - - if self._run_instance is not None: - self._retrieve_run_data() - - from neptune.handler import Handler - - # make sure that we've log integration version for outside `Run` instances - root_obj = self._run_instance - if isinstance(root_obj, Handler): - root_obj = root_obj.get_root_object() - - root_obj[_INTEGRATION_VERSION_KEY] = pl.__version__ - - def _retrieve_run_data(self) -> None: - from neptune.handler import Handler - - assert self._run_instance is not None - root_obj = self._run_instance - if isinstance(root_obj, Handler): - root_obj = root_obj.get_root_object() - - root_obj.wait() - - if root_obj.exists("sys/id"): - self._run_short_id = root_obj["sys/id"].fetch() - self._run_name = root_obj["sys/name"].fetch() - else: - self._run_short_id = "OFFLINE" - self._run_name = "offline-name" - - @property - def _neptune_init_args(self) -> dict: - args: dict = {} - # Backward compatibility in case of previous version retrieval - with contextlib.suppress(AttributeError): - args = self._neptune_run_kwargs - - if self._project_name is not None: - args["project"] = self._project_name - - if self._api_key is not None: - args["api_token"] = self._api_key - - if self._run_short_id is not None: - args["run"] = self._run_short_id - - # Backward compatibility in case of previous version retrieval - with contextlib.suppress(AttributeError): - if self._run_name is not None: - args["name"] = self._run_name - - return args - - def _construct_path_with_prefix(self, *keys: str) -> str: - """Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined.""" - if self._prefix: - return self.LOGGER_JOIN_CHAR.join([self._prefix, *keys]) - 
return self.LOGGER_JOIN_CHAR.join(keys) - - @staticmethod - def _verify_input_arguments( - api_key: Optional[str], - project: Optional[str], - name: Optional[str], - run: Optional[Union["Run", "Handler"]], - neptune_run_kwargs: dict, - ) -> None: - from neptune import Run - from neptune.handler import Handler - - # check if user passed the client `Run`/`Handler` object - if run is not None and not isinstance(run, (Run, Handler)): - raise ValueError("Run parameter expected to be of type `neptune.Run`, or `neptune.handler.Handler`.") - - # check if user passed redundant neptune.init_run arguments when passed run - any_neptune_init_arg_passed = any(arg is not None for arg in [api_key, project, name]) or neptune_run_kwargs - if run is not None and any_neptune_init_arg_passed: - raise ValueError( - "When an already initialized run object is provided, you can't provide other `neptune.init_run()`" - " parameters." - ) - - def __getstate__(self) -> dict[str, Any]: - state = self.__dict__.copy() - # Run instance can't be pickled - state["_run_instance"] = None - return state - - def __setstate__(self, state: dict[str, Any]) -> None: - import neptune - - self.__dict__ = state - self._run_instance = neptune.init_run(**self._neptune_init_args) + from pytorch_lightning_enterprise.loggers.neptune import NeptuneLogger as EnterpriseNeptuneLogger + + self.logger_impl = EnterpriseNeptuneLogger( + api_key=api_key, + project=project, + name=name, + run=run, + log_model_checkpoints=log_model_checkpoints, + prefix=prefix, + **neptune_run_kwargs, + ) @property @rank_zero_experiment @@ -378,72 +257,20 @@ def training_step(self, batch, batch_idx): with NeptuneLogger. """ - return self.run + return self.logger_impl.experiment @property @rank_zero_experiment def run(self) -> "Run": - import neptune - - if not self._run_instance: - self._run_instance = neptune.init_run(**self._neptune_init_args) - self._retrieve_run_data() - # make sure that we've log integration version for newly created - self._run_instance[_INTEGRATION_VERSION_KEY] = pl.__version__ - - return self._run_instance + return self.logger_impl.run @override @rank_zero_only - @_catch_inactive def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: - r"""Log hyperparameters to the run. - - Hyperparameters will be logged under the "/hyperparams" namespace. - - Note: - - You can also log parameters by directly using the logger instance: - ``neptune_logger.experiment["model/hyper-parameters"] = params_dict``. - - In this way you can keep hierarchical structure of the parameters. - - Args: - params: `dict`. - Python dictionary structure with parameters. - - Example:: - - from lightning.pytorch.loggers import NeptuneLogger - import neptune - - PARAMS = { - "batch_size": 64, - "lr": 0.07, - "decay_factor": 0.97, - } - - neptune_logger = NeptuneLogger( - api_key=neptune.ANONYMOUS_API_TOKEN, - project="common/pytorch-lightning-integration" - ) - - neptune_logger.log_hyperparams(PARAMS) - - """ - from neptune.utils import stringify_unsupported - - params = _convert_params(params) - params = _sanitize_callable_params(params) - - parameters_key = self.PARAMETERS_KEY - parameters_key = self._construct_path_with_prefix(parameters_key) - - self.run[parameters_key] = stringify_unsupported(params) + return self.logger_impl.log_hyperparams(params) @override @rank_zero_only - @_catch_inactive def log_metrics(self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None: """Log metrics (numeric values) in Neptune runs. 
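
The hunks above and below reduce NeptuneLogger (like the Comet, MLflow, and W&B loggers earlier in this patch) to thin forwards onto a `self.logger_impl` object built from the enterprise package. A minimal sketch of this guard-and-delegate pattern, with hypothetical names (`fancy_backend`, `BackendLogger`) standing in for the optional dependency, not the real enterprise API:

    import importlib.util
    from typing import Any, Optional


    def _raise_backend_not_available(package: str) -> None:
        # Same role as `_raise_enterprise_not_available`, parameterized for this sketch.
        if importlib.util.find_spec(package) is None:
            raise ModuleNotFoundError(
                f"`{package}` is required for this logger. Install it with `pip install {package}`."
            )


    class DelegatingLogger:
        """Keeps the public signature; every method body is a plain forward."""

        def __init__(self, project: Optional[str] = None, **kwargs: Any) -> None:
            _raise_backend_not_available("fancy_backend")  # hypothetical optional dependency
            from fancy_backend import BackendLogger  # hypothetical backend-side class

            self.logger_impl = BackendLogger(project=project, **kwargs)

        def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None:
            return self.logger_impl.log_metrics(metrics, step)

        @property
        def name(self) -> Optional[str]:
            return self.logger_impl.name

The guard fails fast at construction time, so a missing optional package surfaces as a single clear ModuleNotFoundError instead of a late AttributeError deep inside a training run.
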
@@ -452,26 +279,12 @@ def log_metrics(self, metrics: dict[str, Union[Tensor, float]], step: Optional[i step: Step number at which the metrics should be recorded """ - if rank_zero_only.rank != 0: - raise ValueError("run tried to log from global_rank != 0") - - metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) - - for key, val in metrics.items(): - self.run[key].append(val, step=step) + return self.logger_impl.log_metrics(metrics, step) @override @rank_zero_only - @_catch_inactive def finalize(self, status: str) -> None: - if not self._run_instance: - # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been - # initialized there - return - if status: - self.run[self._construct_path_with_prefix("status")] = status - - super().finalize(status) + return self.logger_impl.finalize(status) @property @override @@ -483,21 +296,14 @@ def save_dir(self) -> Optional[str]: the root directory where experiment logs get saved """ - return os.path.join(os.getcwd(), ".neptune") + return self.logger_impl.save_dir @rank_zero_only - @_catch_inactive def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None: - from neptune.types import File - - model_str = str(ModelSummary(model=model, max_depth=max_depth)) - self.run[self._construct_path_with_prefix("model/summary")] = File.from_content( - content=model_str, extension="txt" - ) + return self.logger_impl.log_model_summary(model, max_depth) @override @rank_zero_only - @_catch_inactive def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint. @@ -505,83 +311,13 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: checkpoint_callback: the model checkpoint callback instance """ - if not self._log_model_checkpoints: - return - - file_names = set() - checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints") - - # save last model - if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path: - model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback) - file_names.add(model_last_name) - self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path) - - # save best k models - if hasattr(checkpoint_callback, "best_k_models"): - for key in checkpoint_callback.best_k_models: - model_name = self._get_full_model_name(key, checkpoint_callback) - file_names.add(model_name) - self.run[f"{checkpoints_namespace}/{model_name}"].upload(key) - - # log best model path and checkpoint - if hasattr(checkpoint_callback, "best_model_path") and checkpoint_callback.best_model_path: - self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path - - model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback) - file_names.add(model_name) - self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path) - - # remove old models logged to experiment if they are not part of best k models at this point - if self.run.exists(checkpoints_namespace): - exp_structure = self.run.get_structure() - uploaded_model_names = self._get_full_model_names_from_exp_structure(exp_structure, checkpoints_namespace) - - for file_to_drop in list(uploaded_model_names - file_names): - del self.run[f"{checkpoints_namespace}/{file_to_drop}"] - - # log 
best model score - if hasattr(checkpoint_callback, "best_model_score") and checkpoint_callback.best_model_score: - self.run[self._construct_path_with_prefix("model/best_model_score")] = ( - checkpoint_callback.best_model_score.cpu().detach().numpy() - ) - - @staticmethod - def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> str: - """Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`.""" - if hasattr(checkpoint_callback, "dirpath"): - model_path = os.path.normpath(model_path) - expected_model_path = os.path.normpath(checkpoint_callback.dirpath) - if not model_path.startswith(expected_model_path): - raise ValueError(f"{model_path} was expected to start with {expected_model_path}.") - # Remove extension from filepath - filepath, _ = os.path.splitext(model_path[len(expected_model_path) + 1 :]) - return filepath.replace(os.sep, "/") - return model_path.replace(os.sep, "/") - - @classmethod - def _get_full_model_names_from_exp_structure(cls, exp_structure: dict[str, Any], namespace: str) -> set[str]: - """Returns all paths to properties which were already logged in `namespace`""" - structure_keys: list[str] = namespace.split(cls.LOGGER_JOIN_CHAR) - for key in structure_keys: - exp_structure = exp_structure[key] - uploaded_models_dict = exp_structure - return set(cls._dict_paths(uploaded_models_dict)) - - @classmethod - def _dict_paths(cls, d: dict[str, Any], path_in_build: Optional[str] = None) -> Generator: - for k, v in d.items(): - path = f"{path_in_build}/{k}" if path_in_build is not None else k - if not isinstance(v, dict): - yield path - else: - yield from cls._dict_paths(v, path) + return self.logger_impl.after_save_checkpoint(checkpoint_callback) @property @override def name(self) -> Optional[str]: """Return the experiment name or 'offline-name' when exp is run in offline mode.""" - return self._run_name + return self.logger_impl.name @property @override @@ -591,4 +327,4 @@ def version(self) -> Optional[str]: It's Neptune Run's short_id """ - return self._run_short_id + return self.logger_impl.version diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 37ca362fa40c1..41d38b460d0f0 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -16,37 +16,24 @@ ------------------------- """ -import os from argparse import Namespace from collections.abc import Mapping -from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch.nn as nn -from lightning_utilities.core.imports import RequirementCache -from torch import Tensor from typing_extensions import override -from lightning.fabric.utilities.logger import ( - _add_prefix, - _convert_json_serializable, - _convert_params, - _sanitize_callable_params, -) +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import _PATH from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment -from lightning.pytorch.loggers.utilities import _scan_checkpoints -from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn +from lightning.pytorch.utilities.rank_zero import rank_zero_only if TYPE_CHECKING: from wandb import Artifact from wandb.sdk.lib import RunDisabled from wandb.wandb_run import Run -_WANDB_AVAILABLE = 
RequirementCache("wandb>=0.12.10") - class WandbLogger(Logger): r"""Log using `Weights and Biases `_. @@ -308,70 +295,27 @@ def __init__( add_file_policy: Literal["mutable", "immutable"] = "mutable", **kwargs: Any, ) -> None: - if not _WANDB_AVAILABLE: - raise ModuleNotFoundError(str(_WANDB_AVAILABLE)) - - if offline and log_model: - raise MisconfigurationException( - f"Providing log_model={log_model} and offline={offline} is an invalid configuration" - " since model checkpoints cannot be uploaded in offline mode.\n" - "Hint: Set `offline=False` to log your model." - ) + _raise_enterprise_not_available() super().__init__() - self._offline = offline - self._log_model = log_model - self._prefix = prefix - self._experiment = experiment - self._logged_model_time: dict[str, float] = {} - self._checkpoint_callbacks: dict[int, ModelCheckpoint] = {} - self.add_file_policy = add_file_policy - - # paths are processed as strings - if save_dir is not None: - save_dir = os.fspath(save_dir) - elif dir is not None: - dir = os.fspath(dir) - - project = project or os.environ.get("WANDB_PROJECT", "lightning_logs") - - # set wandb init arguments - self._wandb_init: dict[str, Any] = { - "name": name, - "project": project, - "dir": save_dir or dir, - "id": version or id, - "resume": "allow", - "anonymous": ("allow" if anonymous else None), - } - self._wandb_init.update(**kwargs) - # extract parameters - self._project = self._wandb_init.get("project") - self._save_dir = self._wandb_init.get("dir") - self._name = self._wandb_init.get("name") - self._id = self._wandb_init.get("id") - self._checkpoint_name = checkpoint_name - - def __getstate__(self) -> dict[str, Any]: - import wandb - - # Hack: If the 'spawn' launch method is used, the logger will get pickled and this `__getstate__` gets called. - # We create an experiment here in the main process, and attach to it in the worker process. - # Using wandb-service, we persist the same experiment even if multiple `Trainer.fit/test/validate` calls - # are made. - wandb.require("service") - _ = self.experiment - - state = self.__dict__.copy() - # args needed to reload correct experiment - if self._experiment is not None: - state["_id"] = getattr(self._experiment, "id", None) - state["_attach_id"] = getattr(self._experiment, "_attach_id", None) - state["_name"] = self._experiment.name - - # cannot be pickled - state["_experiment"] = None - return state + from pytorch_lightning_enterprise.loggers.wandb import WandbLogger as EnterpriseWandbLogger + + self.logger_impl = EnterpriseWandbLogger( + name=name, + save_dir=save_dir, + version=version, + offline=offline, + dir=dir, + id=id, + anonymous=anonymous, + project=project, + log_model=log_model, + experiment=experiment, + prefix=prefix, + checkpoint_name=checkpoint_name, + add_file_policy=add_file_policy, + **kwargs, + ) @property @rank_zero_experiment @@ -386,40 +330,7 @@ def experiment(self) -> Union["Run", "RunDisabled"]: self.logger.experiment.some_wandb_function() """ - import wandb - from wandb.sdk.lib import RunDisabled - from wandb.wandb_run import Run - - if self._experiment is None: - if self._offline: - os.environ["WANDB_MODE"] = "dryrun" - - attach_id = getattr(self, "_attach_id", None) - if wandb.run is not None: - # wandb process already created in this instance - rank_zero_warn( - "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse" - " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`." 
- ) - self._experiment = wandb.run - elif attach_id is not None and hasattr(wandb, "_attach"): - # attach to wandb process referenced - self._experiment = wandb._attach(attach_id) - else: - # create new wandb process - self._experiment = wandb.init(**self._wandb_init) - - # define default x-axis - if isinstance(self._experiment, (Run, RunDisabled)) and getattr( - self._experiment, "define_metric", None - ): - if self._wandb_init.get("sync_tensorboard"): - self._experiment.define_metric("*", step_metric="global_step") - else: - self._experiment.define_metric("trainer/global_step") - self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True) - - return self._experiment + return self.logger_impl.experiment def watch( self, model: nn.Module, log: Optional[str] = "gradients", log_freq: int = 100, log_graph: bool = True @@ -429,21 +340,12 @@ def watch( @override @rank_zero_only def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: - params = _convert_params(params) - params = _sanitize_callable_params(params) - params = _convert_json_serializable(params) - self.experiment.config.update(params, allow_val_change=True) + return self.logger_impl.log_hyperparams(params) @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: - assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" - - metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) - if step is not None and not self._wandb_init.get("sync_tensorboard"): - self.experiment.log(dict(metrics, **{"trainer/global_step": step})) - else: - self.experiment.log(metrics) + return self.logger_impl.log_metrics(metrics, step) @rank_zero_only def log_table( @@ -459,10 +361,7 @@ def log_table( Can be defined either with `columns` and `data` or with `dataframe`. """ - import wandb - - metrics = {key: wandb.Table(columns=columns, data=data, dataframe=dataframe)} - self.log_metrics(metrics, step) + return self.logger_impl.log_table(key, columns, data, dataframe, step) @rank_zero_only def log_text( @@ -478,8 +377,7 @@ def log_text( Can be defined either with `columns` and `data` or with `dataframe`. """ - - self.log_table(key, columns, data, dataframe, step) + return self.logger_impl.log_text(key, columns, data, dataframe, step) @rank_zero_only def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: @@ -488,18 +386,7 @@ def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **k Optional kwargs are lists passed to each image (ex: caption, masks, boxes). """ - if not isinstance(images, list): - raise TypeError(f'Expected a list as "images", found {type(images)}') - n = len(images) - for k, v in kwargs.items(): - if len(v) != n: - raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") - kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)] - - import wandb - - metrics = {key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]} - self.log_metrics(metrics, step) # type: ignore[arg-type] + return self.logger_impl.log_image(key, images, step, **kwargs) @rank_zero_only def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: @@ -514,18 +401,7 @@ def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **k Optional kwargs are lists passed to each audio (ex: caption, sample_rate). 
""" - if not isinstance(audios, list): - raise TypeError(f'Expected a list as "audios", found {type(audios)}') - n = len(audios) - for k, v in kwargs.items(): - if len(v) != n: - raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") - kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)] - - import wandb - - metrics = {key: [wandb.Audio(audio, **kwarg) for audio, kwarg in zip(audios, kwarg_list)]} - self.log_metrics(metrics, step) # type: ignore[arg-type] + return self.logger_impl.log_audio(key, audios, step, **kwargs) @rank_zero_only def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: @@ -540,18 +416,7 @@ def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **k Optional kwargs are lists passed to each video (ex: caption, fps, format). """ - if not isinstance(videos, list): - raise TypeError(f'Expected a list as "videos", found {type(videos)}') - n = len(videos) - for k, v in kwargs.items(): - if len(v) != n: - raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") - kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)] - - import wandb - - metrics = {key: [wandb.Video(video, **kwarg) for video, kwarg in zip(videos, kwarg_list)]} - self.log_metrics(metrics, step) # type: ignore[arg-type] + return self.logger_impl.log_video(key, videos, step, **kwargs) @property @override @@ -562,7 +427,7 @@ def save_dir(self) -> Optional[str]: The path to the save directory. """ - return self._save_dir + return self.logger_impl.save_dir @property @override @@ -574,7 +439,7 @@ def name(self) -> Optional[str]: name. To access wandb's internal experiment name, use ``logger.experiment.name`` instead. """ - return self._project + return self.logger_impl.name @property @override @@ -586,15 +451,12 @@ def version(self) -> Optional[str]: """ # don't create an experiment if we don't have one - return self._experiment.id if self._experiment else self._id + return self.logger_impl.version @override def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: # log checkpoints as artifacts - if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: - self._scan_and_log_checkpoints(checkpoint_callback) - elif self._log_model is True: - self._checkpoint_callbacks[id(checkpoint_callback)] = checkpoint_callback + return self.logger_impl.after_save_checkpoint(checkpoint_callback) @staticmethod @rank_zero_only @@ -616,16 +478,10 @@ def download_artifact( The path to the downloaded artifact. """ - import wandb - - if wandb.run is not None and use_artifact: - artifact = wandb.run.use_artifact(artifact) - else: - api = wandb.Api() - artifact = api.artifact(artifact, type=artifact_type) + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.loggers.wandb import WandbLogger as EnterpriseWandbLogger - save_dir = None if save_dir is None else os.fspath(save_dir) - return artifact.download(root=save_dir) + return EnterpriseWandbLogger.download_artifact(artifact, save_dir, artifact_type, use_artifact) def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "Artifact": """Logs to the wandb dashboard that the mentioned artifact is used by the run. @@ -638,49 +494,9 @@ def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "A wandb Artifact object for the artifact. 
""" - return self.experiment.use_artifact(artifact, type=artifact_type) + return self.logger_impl.use_artifact(artifact, artifact_type) @override @rank_zero_only def finalize(self, status: str) -> None: - if status != "success": - # Currently, checkpoints only get logged on success - return - # log checkpoints as artifacts - if self._experiment is not None: - for checkpoint_callback in self._checkpoint_callbacks.values(): - self._scan_and_log_checkpoints(checkpoint_callback) - - def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: - import wandb - - # get checkpoints to be saved with associated score - checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time) - - # log iteratively all new checkpoints - for t, p, s, tag in checkpoints: - metadata = { - "score": s.item() if isinstance(s, Tensor) else s, - "original_filename": Path(p).name, - checkpoint_callback.__class__.__name__: { - k: getattr(checkpoint_callback, k) - for k in [ - "monitor", - "mode", - "save_last", - "save_top_k", - "save_weights_only", - "_every_n_train_steps", - ] - # ensure it does not break if `ModelCheckpoint` args change - if hasattr(checkpoint_callback, k) - }, - } - if not self._checkpoint_name: - self._checkpoint_name = f"model-{self.experiment.id}" - artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata) - artifact.add_file(p, name="model.ckpt", policy=self.add_file_policy) - aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] - self.experiment.log_artifact(artifact, aliases=aliases) - # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) - self._logged_model_time[p] = t + return self.logger_impl.finalize(status) diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index 9225e3bb9e7be..98580ee61e112 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -11,30 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from contextlib import AbstractContextManager, nullcontext -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from contextlib import AbstractContextManager +from typing import Any, Callable, Optional, Union import torch -from lightning_utilities import apply_to_collection from torch import Tensor from torch.nn import Module -from torch.optim import LBFGS, Optimizer -from typing_extensions import get_args, override +from torch.optim import Optimizer +from typing_extensions import override import lightning.pytorch as pl from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT -from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import Steppable from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.utilities import GradClipAlgorithmType -from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.model_helpers import is_overridden -from lightning.pytorch.utilities.rank_zero import WarningCache - -if TYPE_CHECKING: - import deepspeed - -warning_cache = WarningCache() class DeepSpeedPrecision(Precision): @@ -53,41 +44,29 @@ class DeepSpeedPrecision(Precision): """ def __init__(self, precision: _PRECISION_INPUT) -> None: - supported_precision = get_args(_PRECISION_INPUT) - if precision not in supported_precision: - raise ValueError( - f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported." - f" `precision` must be one of: {supported_precision}." - ) - self.precision = precision - precision_to_type = { - "bf16-mixed": torch.bfloat16, - "16-mixed": torch.float16, - "bf16-true": torch.bfloat16, - "16-true": torch.float16, - "32-true": torch.float32, - } - self._desired_dtype = precision_to_type[self.precision] + super().__init__() + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.precision.deepspeed import ( + DeepSpeedPrecisionTrainer as EnterpriseDeepSpeedPrecision, + ) + + self.deepspeed_precision_impl = EnterpriseDeepSpeedPrecision(outer_object=self, precision=precision) @override def convert_module(self, module: Module) -> Module: - if "true" in self.precision: - return module.to(dtype=self._desired_dtype) - return module + return self.deepspeed_precision_impl.convert_module(module=module) @override def convert_input(self, data: Any) -> Any: - return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype) + return self.deepspeed_precision_impl.convert_input(data=data) @override def tensor_init_context(self) -> AbstractContextManager: - if "true" not in self.precision: - return nullcontext() - return _DtypeContextManager(self._desired_dtype) + return self.deepspeed_precision_impl.tensor_init_context() @override def module_init_context(self) -> AbstractContextManager: - return self.tensor_init_context() + return self.deepspeed_precision_impl.module_init_context() @override def backward( # type: ignore[override] @@ -98,7 +77,7 @@ def backward( # type: ignore[override] *args: Any, **kwargs: Any, ) -> None: - r"""Performs back-propagation using DeepSpeed's engine. + r"""Performs back-propagation. 
Args: tensor: the loss tensor @@ -108,13 +87,7 @@ def backward( # type: ignore[override] \**kwargs: additional keyword arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call """ - if is_overridden("backward", model): - warning_cache.warn( - "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles" - " the backward logic internally." - ) - deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model - deepspeed_engine.backward(tensor, *args, **kwargs) + return self.deepspeed_precision_impl.backward(tensor=tensor, model=model, optimizer=optimizer, *args, **kwargs) @override def optimizer_step( # type: ignore[override] @@ -124,19 +97,7 @@ def optimizer_step( # type: ignore[override] closure: Callable[[], Any], **kwargs: Any, ) -> Any: - if isinstance(optimizer, LBFGS): - raise MisconfigurationException("DeepSpeed and the LBFGS optimizer are not compatible.") - closure_result = closure() - self._after_closure(model, optimizer) - skipped_backward = closure_result is None - # in manual optimization, the closure does not return a value - if model.automatic_optimization and skipped_backward: - raise MisconfigurationException( - "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`" - ) - # DeepSpeed handles the optimizer step internally - deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model - return deepspeed_engine.step(**kwargs) + return self.deepspeed_precision_impl.optimizer_step(optimizer=optimizer, model=model, closure=closure, **kwargs) @override def clip_gradients( @@ -145,4 +106,22 @@ def clip_gradients( clip_val: Union[int, float] = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: - """DeepSpeed handles gradient clipping internally.""" + return self.deepspeed_precision_impl.clip_gradients( + optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm + ) + + @property + def precision(self) -> str: + return self.deepspeed_precision_impl.precision + + @precision.setter + def precision(self, value: str) -> None: + self.deepspeed_precision_impl.precision = value + + @property + def _desired_dtype(self) -> torch.dtype: + return self.deepspeed_precision_impl._desired_dtype + + @_desired_dtype.setter + def _desired_dtype(self, value: torch.dtype) -> None: + self.deepspeed_precision_impl._desired_dtype = value diff --git a/src/lightning/pytorch/plugins/precision/xla.py b/src/lightning/pytorch/plugins/precision/xla.py index 6890cc4c1d825..5f84399401fe3 100644 --- a/src/lightning/pytorch/plugins/precision/xla.py +++ b/src/lightning/pytorch/plugins/precision/xla.py @@ -11,19 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os -from functools import partial from typing import Any, Callable import torch -from typing_extensions import get_args, override +from typing_extensions import override import lightning.pytorch as pl -from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins.precision.xla import _PRECISION_INPUT +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import Optimizable from lightning.pytorch.plugins.precision.precision import Precision -from lightning.pytorch.utilities.exceptions import MisconfigurationException class XLAPrecision(Precision): @@ -39,25 +36,12 @@ class XLAPrecision(Precision): """ def __init__(self, precision: _PRECISION_INPUT = "32-true") -> None: - if not _XLA_AVAILABLE: - raise ModuleNotFoundError(str(_XLA_AVAILABLE)) - - supported_precision = get_args(_PRECISION_INPUT) - if precision not in supported_precision: - raise ValueError( - f"`precision={precision!r})` is not supported in XLA." - f" `precision` must be one of: {supported_precision}." - ) - self.precision = precision - - if precision == "16-true": - os.environ["XLA_USE_F16"] = "1" - self._desired_dtype = torch.float16 - elif precision == "bf16-true": - os.environ["XLA_USE_BF16"] = "1" - self._desired_dtype = torch.bfloat16 - else: - self._desired_dtype = torch.float32 + super().__init__() + + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.plugins.precision.xla import XLAPrecision as EnterpriseXLAPrecision + + self.xla_impl = EnterpriseXLAPrecision(precision) @override def optimizer_step( # type: ignore[override] @@ -67,31 +51,24 @@ def optimizer_step( # type: ignore[override] closure: Callable[[], Any], **kwargs: Any, ) -> Any: - import torch_xla.core.xla_model as xm - - closure = partial(self._xla_wrap_closure, optimizer, closure) - closure = partial(self._wrap_closure, model, optimizer, closure) - closure_result = optimizer.step(closure=closure, **kwargs) - xm.mark_step() - skipped_backward = closure_result is None - # in manual optimization, the closure does not return a value - if model.automatic_optimization and skipped_backward: - # we lack coverage here so disable this - something to explore if there's demand - raise MisconfigurationException( - "Skipping backward by returning `None` from your `training_step` is not implemented with XLA." - " Please, open an issue in `https://github.com/Lightning-AI/pytorch-lightning/issues`" - " requesting this feature." 
- ) - return closure_result + return self.xla_impl.optimizer_step(optimizer, model, closure, **kwargs) - @override - def teardown(self) -> None: - os.environ.pop("XLA_USE_BF16", None) - os.environ.pop("XLA_USE_F16", None) + @property + def precision(self) -> _PRECISION_INPUT: + return self.xla_impl.precision - def _xla_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any: - import torch_xla.core.xla_model as xm + @precision.setter + def precision(self, precision: _PRECISION_INPUT) -> None: + self.xla_impl.precision = precision - closure_result = closure() - xm.reduce_gradients(optimizer) - return closure_result + @property + def _desired_dtype(self) -> torch.dtype: + return self.xla_impl._desired_dtype + + @_desired_dtype.setter + def _desired_dtype(self, dtype: torch.dtype) -> None: + self.xla_impl._desired_dtype = dtype + + @override + def teardown(self) -> None: + return self.xla_impl.teardown() diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 1fce7b06887cd..856760815b698 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -11,52 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse -import json + import logging -import os -import platform from collections import OrderedDict from collections.abc import Generator, Mapping from contextlib import contextmanager from datetime import timedelta -from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union import torch from torch.nn import Module from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau from typing_extensions import override import lightning.pytorch as pl from lightning.fabric.plugins import ClusterEnvironment from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout from lightning.fabric.strategies import _StrategyRegistry -from lightning.fabric.strategies.deepspeed import ( - _DEEPSPEED_AVAILABLE, - _DEEPSPEED_GREATER_EQUAL_0_16, - _format_precision_config, - _validate_checkpoint_directory, - _validate_device_index_selection, -) -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 -from lightning.fabric.utilities.optimizer import _optimizers_to_device -from lightning.fabric.utilities.seed import reset_seed +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import _PATH -from lightning.pytorch.accelerators.cuda import CUDAAccelerator -from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers from lightning.pytorch.plugins.precision import Precision from lightning.pytorch.strategies.ddp import DDPStrategy -from lightning.pytorch.trainer.states import TrainerFn -from lightning.pytorch.utilities import GradClipAlgorithmType -from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.model_helpers import is_overridden -from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn -from lightning.pytorch.utilities.types import LRSchedulerConfig - -log = logging.getLogger(__name__) -warning_cache = WarningCache() if TYPE_CHECKING: import deepspeed @@ -258,25 +233,6 @@ def __init__( exclude_frozen_parameters: Exclude frozen parameters when saving 
checkpoints. """ - if not _DEEPSPEED_AVAILABLE: - raise MisconfigurationException( - "To use the `DeepSpeedStrategy`, you must have DeepSpeed installed." - " Install it by running `pip install -U deepspeed`." - ) - - if _TORCH_GREATER_EQUAL_2_6 and not _DEEPSPEED_GREATER_EQUAL_0_16: - # Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` when loading full checkpoints. - # DeepSpeed added support for this behavior in version 0.16.0. - import deepspeed - - deepspeed_version = deepspeed.__version__ - - raise ImportError( - f"PyTorch >= 2.6 requires DeepSpeed >= 0.16.0. " - f"Detected DeepSpeed version: {deepspeed_version}. " - "Please upgrade by running `pip install -U 'deepspeed>=0.16.0'`." - ) - super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, @@ -284,124 +240,77 @@ def __init__( precision_plugin=precision_plugin, process_group_backend=process_group_backend, ) - self._timeout: Optional[timedelta] = timeout - - self.config = self._load_config(config) - if self.config is None: - # User has not overridden config, set defaults - self.config = self._create_default_config( - zero_optimization, - zero_allow_untested_optimizer, - logging_batch_size_per_gpu, - offload_optimizer=offload_optimizer, - offload_parameters=offload_parameters, - nvme_path=nvme_path, - offload_params_device=offload_params_device, - params_buffer_count=params_buffer_count, - params_buffer_size=params_buffer_size, - max_in_cpu=max_in_cpu, - pin_memory=pin_memory, - offload_optimizer_device=offload_optimizer_device, - optimizer_buffer_count=optimizer_buffer_count, - block_size=block_size, - queue_depth=queue_depth, - single_submit=single_submit, - overlap_events=overlap_events, - thread_count=thread_count, - partition_activations=partition_activations, - cpu_checkpointing=cpu_checkpointing, - contiguous_memory_optimization=contiguous_memory_optimization, - synchronize_checkpoint_boundary=synchronize_checkpoint_boundary, - stage=stage, - contiguous_gradients=contiguous_gradients, - overlap_comm=overlap_comm, - allgather_partitions=allgather_partitions, - reduce_scatter=reduce_scatter, - allgather_bucket_size=allgather_bucket_size, - reduce_bucket_size=reduce_bucket_size, - sub_group_size=sub_group_size, - ) - import deepspeed - - self._config_initialized = False - deepspeed.utils.logging.logger.setLevel(logging_level) - - self.remote_device = remote_device - self.load_full_weights = load_full_weights - self.exclude_frozen_parameters = exclude_frozen_parameters - - # default FP16 parameters. 
- self.loss_scale = loss_scale - self.initial_scale_power = initial_scale_power - self.loss_scale_window = loss_scale_window - self.hysteresis = hysteresis - self.min_loss_scale = min_loss_scale + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.strategies.deepspeed import ( + DeepSpeedStrategyTrainer as EnterpriseDeepSpeedStrategy, + ) + + self.deepspeed_strategy_impl = EnterpriseDeepSpeedStrategy( + outer_object=self, + accelerator=accelerator, + zero_optimization=zero_optimization, + stage=stage, + remote_device=remote_device, + offload_optimizer=offload_optimizer, + offload_parameters=offload_parameters, + offload_params_device=offload_params_device, + nvme_path=nvme_path, + params_buffer_count=params_buffer_count, + params_buffer_size=params_buffer_size, + max_in_cpu=max_in_cpu, + offload_optimizer_device=offload_optimizer_device, + optimizer_buffer_count=optimizer_buffer_count, + block_size=block_size, + queue_depth=queue_depth, + single_submit=single_submit, + overlap_events=overlap_events, + thread_count=thread_count, + pin_memory=pin_memory, + sub_group_size=sub_group_size, + contiguous_gradients=contiguous_gradients, + overlap_comm=overlap_comm, + allgather_partitions=allgather_partitions, + reduce_scatter=reduce_scatter, + allgather_bucket_size=allgather_bucket_size, + reduce_bucket_size=reduce_bucket_size, + zero_allow_untested_optimizer=zero_allow_untested_optimizer, + logging_batch_size_per_gpu=logging_batch_size_per_gpu, + config=config, + logging_level=logging_level, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + loss_scale=loss_scale, + initial_scale_power=initial_scale_power, + loss_scale_window=loss_scale_window, + hysteresis=hysteresis, + min_loss_scale=min_loss_scale, + partition_activations=partition_activations, + cpu_checkpointing=cpu_checkpointing, + contiguous_memory_optimization=contiguous_memory_optimization, + synchronize_checkpoint_boundary=synchronize_checkpoint_boundary, + load_full_weights=load_full_weights, + precision_plugin=precision_plugin, + process_group_backend=process_group_backend, + timeout=timeout, + exclude_frozen_parameters=exclude_frozen_parameters, + ) @override def setup_environment(self) -> None: - if not isinstance(self.accelerator, CUDAAccelerator): - raise RuntimeError( - f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`" - " is used." 
- ) - super().setup_environment() + return self.deepspeed_strategy_impl.setup_environment() @override def setup_distributed(self) -> None: - assert self.parallel_devices is not None - _validate_device_index_selection(self.parallel_devices) - reset_seed() - self.set_world_ranks() - self._init_deepspeed_distributed() + return self.deepspeed_strategy_impl.setup_distributed() @override def setup(self, trainer: "pl.Trainer") -> None: - self._init_config_if_needed() - assert self.accelerator is not None - self.accelerator.setup(trainer) - - assert self.model is not None - self.model = self.precision_plugin.convert_module(self.model) - self.model = self._setup_model(self.model) - - if trainer.state.fn == TrainerFn.FITTING: - self.setup_optimizers(trainer) - self.setup_precision_plugin() - if trainer.state.fn == TrainerFn.FITTING: - _optimizers_to_device(self.optimizers, self.root_device) - - self.init_deepspeed() - self.barrier() - - def _init_deepspeed_distributed(self) -> None: - import deepspeed - - assert self.cluster_environment is not None - if platform.system() != "Windows": - # do not set env variables on windows, allow deepspeed to control setup - self._set_node_environment_variables() - log.info( - "initializing deepspeed distributed: " - f"GLOBAL_RANK: {self.global_rank}, " - f"MEMBER: {self.global_rank + 1}/{self.world_size}" - ) - self._process_group_backend = self._get_process_group_backend() - deepspeed.init_distributed( - self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout - ) - - def _set_node_environment_variables(self) -> None: - assert self.cluster_environment is not None - os.environ["MASTER_ADDR"] = self.cluster_environment.main_address - os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) - os.environ["RANK"] = str(self.global_rank) - os.environ["WORLD_SIZE"] = str(self.world_size) - os.environ["LOCAL_RANK"] = str(self.local_rank) + return self.deepspeed_strategy_impl.setup(trainer=trainer) @property @override def restore_checkpoint_after_setup(self) -> bool: - return True + return self.deepspeed_strategy_impl.restore_checkpoint_after_setup @override def _setup_model_and_optimizers( @@ -416,188 +325,28 @@ def _setup_model_and_optimizers( deepspeed optimizer. """ - if len(optimizers) != 1: - raise ValueError( - f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead." - ) - - # train_micro_batch_size_per_gpu is used for throughput logging purposes - # normally we set this to the batch size, but it is not available here unless the user provides it - # as part of the config - assert self.config is not None - self.config.setdefault("train_micro_batch_size_per_gpu", 1) - self.model, optimizer = self._setup_model_and_optimizer(model, optimizers[0]) - self._set_deepspeed_activation_checkpointing() - return self.model, [optimizer] - - def _setup_model_and_optimizer( - self, - model: Module, - optimizer: Optional[Optimizer], - lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None, - ) -> tuple["deepspeed.DeepSpeedEngine", Optimizer]: - """Initialize one model and one optimizer with an optional learning rate scheduler. - - This calls ``deepspeed.initialize`` internally. 
- - """ - import deepspeed - - model_parameters = filter(lambda p: p.requires_grad, model.parameters()) - deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize( - args=argparse.Namespace(device_rank=self.root_device.index), - config=self.config, - model=model, - model_parameters=model_parameters, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - dist_init_required=False, - ) - return deepspeed_engine, deepspeed_optimizer - - def init_deepspeed(self) -> None: - assert self.lightning_module is not None - # deepspeed handles gradient clipping internally - if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule): - rank_zero_warn( - "Since DeepSpeed handles gradient clipping internally, the default" - " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients." - " The hook will still be called. Consider setting" - " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`" - " which will use the internal mechanism." - ) - - if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE: - raise MisconfigurationException("DeepSpeed does not support clipping gradients by value.") - - assert isinstance(self.model, pl.LightningModule) - if self.lightning_module.trainer and self.lightning_module.trainer.training: - self._initialize_deepspeed_train(self.model) - else: - self._initialize_deepspeed_inference(self.model) - - def _init_optimizers(self) -> tuple[Optimizer, Optional[LRSchedulerConfig]]: - assert self.lightning_module is not None - optimizers, lr_schedulers = _init_optimizers_and_lr_schedulers(self.lightning_module) - if len(optimizers) > 1 or len(lr_schedulers) > 1: - raise MisconfigurationException( - "DeepSpeed currently only supports single optimizer, single optional scheduler." - ) - return optimizers[0], lr_schedulers[0] if lr_schedulers else None + return self.deepspeed_strategy_impl._setup_model_and_optimizers(model=model, optimizers=optimizers) @property def zero_stage_3(self) -> bool: - assert isinstance(self.config, dict) - zero_optimization = self.config.get("zero_optimization") - return zero_optimization is not None and zero_optimization.get("stage") == 3 - - def _initialize_deepspeed_train(self, model: Module) -> None: - optimizer, scheduler = None, None - assert isinstance(self.config, dict) - if "optimizer" in self.config: - rank_zero_info( - "You have specified an optimizer and/or scheduler within the DeepSpeed config." - " It is recommended to define it in `LightningModule.configure_optimizers`." 
- ) - lr_scheduler = None - else: - ( - optimizer, - lr_scheduler, - ) = self._init_optimizers() - if lr_scheduler is not None: - scheduler = lr_scheduler.scheduler - - model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler) - self._set_deepspeed_activation_checkpointing() - - # although we set these here, deepspeed manages the specific optimizer logic - self.optimizers = [deepspeed_optimizer] - - deepspeed_scheduler = model.lr_scheduler - if deepspeed_scheduler is not None: - # disable deepspeed lr scheduling as lightning manages scheduling - model.lr_scheduler = None - if lr_scheduler is None: - lr_scheduler = LRSchedulerConfig(deepspeed_scheduler, interval="step") - else: - lr_scheduler.scheduler = deepspeed_scheduler - self.lr_scheduler_configs = [lr_scheduler] - self.model = model + return self.deepspeed_strategy_impl.zero_stage_3 @contextmanager @override def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]: - if self.zero_stage_3: - if empty_init is False: - raise NotImplementedError( - f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled." - ) - yield - return - with super().tensor_init_context(empty_init=empty_init): + with self.deepspeed_strategy_impl.tensor_init_context(empty_init=empty_init): yield @contextmanager @override def model_sharded_context(self) -> Generator[None, None, None]: - import deepspeed - - self._init_config_if_needed() - with deepspeed.zero.Init( - enabled=self.zero_stage_3, - remote_device=self.remote_device, - config_dict_or_path=self.config, - ): + with self.deepspeed_strategy_impl.model_sharded_context(): yield - def _set_deepspeed_activation_checkpointing(self) -> None: - import deepspeed - - assert isinstance(self.config, dict) - if self.config.get("activation_checkpointing"): - checkpoint_config = self.config["activation_checkpointing"] - deepspeed.checkpointing.configure( - mpu_=None, - partition_activations=checkpoint_config.get("partition_activations"), - contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"), - checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"), - profile=checkpoint_config.get("profile"), - ) - - def _initialize_deepspeed_inference(self, model: Module) -> None: - import deepspeed - - assert isinstance(self.config, dict) - - # todo: this is required for DeepSpeed throughput timers - inference_config = {"train_micro_batch_size_per_gpu": 1} - if "fp16" in self.config: - inference_config.update({"fp16": self.config["fp16"]}) - if "bf16" in self.config: - inference_config.update({"bf16": self.config["bf16"]}) - if self.zero_stage_3: - inference_config.update({ - "zero_allow_untested_optimizer": self.config["zero_allow_untested_optimizer"], - "zero_optimization": self.config["zero_optimization"], - }) - # Remove all module hooks before initializing new model - remove_module_hooks(model) - model, _, _, _ = deepspeed.initialize( - args=argparse.Namespace(device_rank=self.root_device.index), - config=inference_config, - model=model, - optimizer=None, - lr_scheduler=None, - model_parameters=[], - dist_init_required=False, - ) - self.model = model - @property @override def distributed_sampler_kwargs(self) -> dict[str, int]: - return {"num_replicas": self.world_size, "rank": self.global_rank} + return self.deepspeed_strategy_impl.distributed_sampler_kwargs() @override def setup_optimizers(self, trainer: "pl.Trainer") -> None: @@ -607,29 +356,21 @@ def setup_optimizers(self, trainer: 
"pl.Trainer") -> None: trainer: the Trainer, these optimizers should be connected to """ - # Skip initializing optimizers here as DeepSpeed handles optimizers via config. - # User may have specified config options instead in configure_optimizers, but this is handled - # via `_initialize_deepspeed_train` - # empty optimizers, schedulers - self.optimizers = [] - self.lr_scheduler_configs = [] - - def _setup_model(self, model: Module) -> Module: # type: ignore[override] - return model + return self.deepspeed_strategy_impl.setup_optimizers(trainer=trainer) @property @override def handles_gradient_accumulation(self) -> bool: """Whether the strategy handles gradient accumulation internally.""" - return True + return self.deepspeed_strategy_impl.handles_gradient_accumulation @property def deepspeed_engine(self) -> "deepspeed.DeepSpeedEngine": - return self.model + return self.deepspeed_strategy_impl.deepspeed_engine @property def _multi_device(self) -> bool: - return self.num_processes > 1 or self.num_nodes > 1 + return self.deepspeed_strategy_impl._multi_device @override def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None: @@ -645,135 +386,27 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op If ``storage_options`` arg is passed in """ - # broadcast the filepath from rank 0 to ensure all the states are saved in a common filepath - filepath = self.broadcast(filepath) - - if storage_options is not None: - raise TypeError( - "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg" - f" is not supported for `{self.__class__.__name__}` as `CheckpointIO` is not used." - ) - - if self.zero_stage_3 and self._multi_device and self.is_global_zero: - warning_cache.warn( - "When saving the DeepSpeed Stage 3 checkpoint, " - "each worker will save a shard of the checkpoint within a directory. " - "If a single file is required after training, " - "see https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#" - "deepspeed-zero-stage-3-single-file for instructions." 
- ) - # Use deepspeed's internal checkpointing function to handle partitioned weights across processes - # dump states as a checkpoint dictionary object - _exclude_keys = ["state_dict", "optimizer_states"] - checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} - self.deepspeed_engine.save_checkpoint( - filepath, - client_state=checkpoint, - tag="checkpoint", - exclude_frozen_parameters=self.exclude_frozen_parameters, + return self.deepspeed_strategy_impl.save_checkpoint( + checkpoint=checkpoint, filepath=filepath, storage_options=storage_options ) @override def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: - if self.load_full_weights and self.zero_stage_3: - # Broadcast to ensure we load from the rank 0 checkpoint - # This doesn't have to be the case when using deepspeed sharded checkpointing - checkpoint_path = self.broadcast(checkpoint_path) - return super().load_checkpoint(checkpoint_path, weights_only) - - _validate_checkpoint_directory(checkpoint_path) - - # Rely on deepspeed to load the checkpoint and necessary information - assert self.lightning_module is not None - - from lightning.pytorch.trainer.states import TrainerFn - - is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING - - _, client_state = self.deepspeed_engine.load_checkpoint( - checkpoint_path, - load_optimizer_states=is_fitting, - load_lr_scheduler_states=False, - load_module_strict=self.lightning_module.strict_loading, - ) - if client_state is None: - raise MisconfigurationException( - "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint " - "or a single checkpoint file with `Trainer(strategy=DeepSpeedStrategy(load_full_weights=True))`." - ) - return client_state + return self.deepspeed_strategy_impl.load_checkpoint(checkpoint_path=checkpoint_path, weights_only=weights_only) @property @override def lightning_restore_optimizer(self) -> bool: - assert self.lightning_module is not None - # managed by DeepSpeed - if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: - rank_zero_warn( - "A single checkpoint file has been given. This means optimizer states cannot be restored." - " If you'd like to restore these states, you must provide a path to the originally saved DeepSpeed" - " checkpoint. When using ZeRO 3, the original path should be a directory." - ) - return False + return self.deepspeed_strategy_impl.lightning_restore_optimizer @override def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()` - if self.load_full_weights and self.zero_stage_3: - self.model_to_device() - self._restore_zero_state(checkpoint, strict=strict) - - def _restore_zero_state(self, ckpt: Mapping[str, Any], strict: bool) -> None: - """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be sharded - across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced - across processes. - - Args: - ckpt: The ckpt file. 
- - """ - import deepspeed - - assert self.lightning_module is not None - - def load(module: torch.nn.Module, prefix: str = "") -> None: - missing_keys: list[str] = [] - unexpected_keys: list[str] = [] - error_msgs: list[str] = [] - state_dict = ckpt["state_dict"] - - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, "_metadata", None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - # because zero3 puts placeholders in model params, this context - # manager gathers (unpartitions) the params of the current layer, then loads from - # the state dict and then re-partitions them again - with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): - if self.is_global_zero: - module._load_from_state_dict( - state_dict=state_dict, - prefix=prefix, - local_metadata=local_metadata, - strict=strict, - missing_keys=missing_keys, - unexpected_keys=unexpected_keys, - error_msgs=error_msgs, - ) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - load(self.lightning_module, prefix="") + return self.deepspeed_strategy_impl.load_model_state_dict(checkpoint=checkpoint, strict=strict) @override def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - # Override to do nothing, the deepspeed engine already loaded the states in `load_checkpoint()` - pass + return self.deepspeed_strategy_impl.load_optimizer_state_dict(checkpoint=checkpoint) @classmethod @override @@ -809,137 +442,6 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: offload_optimizer_device="nvme", ) - def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]: - if config is None and self.DEEPSPEED_ENV_VAR in os.environ: - rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") - config = os.environ[self.DEEPSPEED_ENV_VAR] - if isinstance(config, (str, Path)): - if not os.path.isfile(config): - raise MisconfigurationException( - f"You passed in a path to a DeepSpeed config but the path does not exist: {config}" - ) - with open(config) as f: - config = json.load(f) - assert isinstance(config, dict) or config is None - return config - - def _init_config_if_needed(self) -> None: - if not self._config_initialized: - self._format_config() - self._config_initialized = True - - def _format_config(self) -> None: - if self.config is None: - raise MisconfigurationException( - "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config." 
- " See: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#deepspeed" - ) - self._format_batch_size_and_grad_accum_config() - _format_precision_config( - config=self.config, - precision=self.precision_plugin.precision, - loss_scale=self.loss_scale, - loss_scale_window=self.loss_scale_window, - min_loss_scale=self.min_loss_scale, - initial_scale_power=self.initial_scale_power, - hysteresis=self.hysteresis, - ) - - def _create_default_config( - self, - zero_optimization: bool, - zero_allow_untested_optimizer: bool, - logging_batch_size_per_gpu: Union[str, int], - partition_activations: bool, - cpu_checkpointing: bool, - contiguous_memory_optimization: bool, - synchronize_checkpoint_boundary: bool, - offload_optimizer: bool, - offload_parameters: bool, - nvme_path: str, - offload_params_device: str, - params_buffer_count: int, - params_buffer_size: int, - max_in_cpu: int, - offload_optimizer_device: str, - optimizer_buffer_count: int, - pin_memory: bool, - block_size: int, - queue_depth: int, - single_submit: bool, - overlap_events: bool, - thread_count: int, - **zero_kwargs: Any, - ) -> dict: - cfg = { - "activation_checkpointing": { - "partition_activations": partition_activations, - "cpu_checkpointing": cpu_checkpointing, - "contiguous_memory_optimization": contiguous_memory_optimization, - "synchronize_checkpoint_boundary": synchronize_checkpoint_boundary, - }, - "aio": { - "block_size": block_size, - "queue_depth": queue_depth, - "single_submit": single_submit, - "overlap_events": overlap_events, - "thread_count": thread_count, - }, - } - if zero_optimization: - zero_config = zero_kwargs - - if offload_optimizer: - zero_config["offload_optimizer"] = { - "device": offload_optimizer_device, - "nvme_path": nvme_path, - "buffer_count": optimizer_buffer_count, - "pin_memory": pin_memory, - } - if offload_parameters: - zero_config["offload_param"] = { - "device": offload_params_device, - "nvme_path": nvme_path, - "buffer_count": params_buffer_count, - "buffer_size": params_buffer_size, - "max_in_cpu": max_in_cpu, - "pin_memory": pin_memory, - } - cfg = { - "zero_allow_untested_optimizer": zero_allow_untested_optimizer, - "zero_optimization": zero_config, - **cfg, - } - if logging_batch_size_per_gpu != "auto": - cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg} - return cfg - - def _format_batch_size_and_grad_accum_config(self) -> None: - # TODO: Using Fabric, we do not support these variables within the config - assert isinstance(self.config, dict) - if self.lightning_module is None: - return - - if "gradient_accumulation_steps" in self.config: - raise MisconfigurationException( - "Do not set `gradient_accumulation_steps` in the DeepSpeed config" - " as this will be set with the `accumulate_grad_batches` argument passed via the Lightning Trainer." 
- ) - self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches - if "train_micro_batch_size_per_gpu" not in self.config: - batch_size = self._auto_select_batch_size() - self.config["train_micro_batch_size_per_gpu"] = batch_size - if "gradient_clipping" not in self.config: - self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0 - - def _auto_select_batch_size(self) -> int: - # train_micro_batch_size_per_gpu is used for throughput logging purposes - # by default we try to use the batch size of the loader - assert self.lightning_module is not None - batch_size = 1 - data_source = self.lightning_module.trainer.fit_loop._data_source - if data_source.is_defined(): - train_dataloader = data_source.dataloader() - if hasattr(train_dataloader, "batch_sampler"): - batch_size = train_dataloader.batch_sampler.batch_size - return batch_size + @property + def config(self) -> dict[str, Any]: + return self.deepspeed_strategy_impl.config diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py index 066fecc79f208..63842d5557d7e 100644 --- a/src/lightning/pytorch/strategies/launchers/xla.py +++ b/src/lightning/pytorch/strategies/launchers/xla.py @@ -11,23 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import queue from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch.multiprocessing as mp from typing_extensions import override -from lightning.fabric.accelerators.xla import _XLA_AVAILABLE -from lightning.fabric.strategies.launchers.xla import _rank_teardown -from lightning.fabric.utilities import move_data_to_device +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.pytorch.strategies.launchers.multiprocessing import ( _GlobalStateSnapshot, _MultiProcessingLauncher, _WorkerOutput, ) -from lightning.pytorch.trainer.states import TrainerFn -from lightning.pytorch.utilities.rank_zero import rank_zero_debug if TYPE_CHECKING: import lightning.pytorch as pl @@ -51,14 +46,16 @@ class _XLALauncher(_MultiProcessingLauncher): """ def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None: - if not _XLA_AVAILABLE: - raise ModuleNotFoundError(str(_XLA_AVAILABLE)) - super().__init__(strategy=strategy, start_method="fork") + super().__init__(strategy) + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.strategies.xla.launcher import _XLALauncherTrainer as EnterpriseXLALauncher + + self.xla_launcher_impl = EnterpriseXLALauncher(strategy) @property @override def is_interactive_compatible(self) -> bool: - return True + return self.xla_launcher_impl.is_interactive_compatible @override def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any: @@ -75,46 +72,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] **kwargs: Optional keyword arguments to be passed to the given function. """ - if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING: - # resolving https://github.com/Lightning-AI/pytorch-lightning/issues/18775 will lift this restriction - raise NotImplementedError( - "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not" - " supported. 
You can work around this by creating a new Trainer instance and passing the" - " `fit(ckpt_path=...)` argument." - ) - - # pjrt requires that the queue is serializable - return_queue = mp.Manager().Queue() - - import torch_xla.distributed.xla_multiprocessing as xmp - - spawn_kwargs = {} - nprocs = self._strategy.num_processes - if nprocs == 1: - # avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly. - # otherwise it will use all devices - spawn_kwargs["nprocs"] = nprocs - - process_context = xmp.spawn( - self._wrapping_function, - args=(trainer, function, args, kwargs, return_queue), - start_method=self._start_method, - join=False, # we will join ourselves to get the process references - **spawn_kwargs, - ) - # xla will not actually create processes if only 1 device - if process_context is not None: - self.procs = process_context.processes - while not process_context.join(): - pass - - worker_output = return_queue.get() - if trainer is None: - return worker_output - - self._already_fit |= trainer.state.fn == TrainerFn.FITTING - self._recover_results_in_main_process(worker_output, trainer) - return worker_output.trainer_results + return self.xla_launcher_impl.launch(function, *args, trainer=trainer, **kwargs) @override def _wrapping_function( @@ -129,48 +87,10 @@ def _wrapping_function( return_queue: Union[mp.SimpleQueue, queue.Queue], global_states: Optional[_GlobalStateSnapshot] = None, ) -> None: - import torch_xla.core.xla_model as xm - - if len(xm.get_xla_supported_devices()) > 1: - # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for v4) - # so when there's more than one (multithreading), objects need to be deep-copied - import copy - - trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs)) - - results = function(*args, **kwargs) - - if trainer is not None: - results = self._collect_rank_zero_results(trainer, results) - - if self._strategy.local_rank == 0: - return_queue.put(move_data_to_device(results, "cpu")) - - _rank_teardown(self._strategy.local_rank) + return self.xla_launcher_impl._wrapping_function( + process_idx, trainer, function, args, kwargs, return_queue, global_states + ) @override def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]: - rank_zero_debug("Collecting results from rank 0 process.") - checkpoint_callback = trainer.checkpoint_callback - best_model_path = ( - checkpoint_callback.best_model_path - if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path") - else None - ) - - # save the last weights - weights_path = None - if trainer.state.fn == TrainerFn.FITTING: - # requires to compute the state_dict on all processes in case Metrics are present - state_dict = self._strategy.lightning_module_state_dict() - weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") - self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) - - # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training - if self._strategy.local_rank != 0: - return None - - # add extra result data from trainer to send to main process - extra = self.get_extra_results(trainer) - - return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra) + return self.xla_launcher_impl._collect_rank_zero_results(trainer, results) diff --git a/src/lightning/pytorch/strategies/single_xla.py b/src/lightning/pytorch/strategies/single_xla.py index 
2a5e2f3a85b96..6e687ac815639 100644 --- a/src/lightning/pytorch/strategies/single_xla.py +++ b/src/lightning/pytorch/strategies/single_xla.py @@ -11,23 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import Optional, Union -import torch from typing_extensions import override import lightning.pytorch as pl -from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO from lightning.fabric.strategies import _StrategyRegistry -from lightning.fabric.utilities.optimizer import _optimizers_to_device +from lightning.fabric.utilities.imports import _raise_enterprise_not_available from lightning.fabric.utilities.types import _DEVICE from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO from lightning.pytorch.plugins.precision.xla import XLAPrecision from lightning.pytorch.strategies.single_device import SingleDeviceStrategy -from lightning.pytorch.trainer.states import TrainerFn -from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters class SingleDeviceXLAStrategy(SingleDeviceStrategy): @@ -41,20 +36,18 @@ def __init__( precision_plugin: Optional[XLAPrecision] = None, debug: bool = False, ): - if not _XLA_AVAILABLE: - raise ModuleNotFoundError(str(_XLA_AVAILABLE)) - if isinstance(device, torch.device): - # unwrap the `torch.device` in favor of `xla_device` - device = device.index - import torch_xla.core.xla_model as xm - super().__init__( accelerator=accelerator, - device=xm.xla_device(device), + device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin, ) - self.debug = debug + _raise_enterprise_not_available() + from pytorch_lightning_enterprise.strategies.xla.single import ( + SingleDeviceXLAStrategyTrainer as EnterpriseSingleDeviceXLAStrategy, + ) + + self.single_xla_strategy_impl = EnterpriseSingleDeviceXLAStrategy(outer_object=self, device=device, debug=debug) @property @override @@ -90,26 +83,7 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: @override def setup(self, trainer: "pl.Trainer") -> None: - if self.debug: - os.environ["PT_XLA_DEBUG"] = str(1) - - assert self.accelerator is not None - self.accelerator.setup(trainer) - - assert self.model is not None - self.precision_plugin.convert_module(self.model) - - shared_params = find_shared_parameters(self.model) - self.model_to_device() - set_shared_parameters(self.model, shared_params) - - self.model = self._setup_model(self.model) - - if trainer.state.fn == TrainerFn.FITTING: - self.setup_optimizers(trainer) - self.setup_precision_plugin() - if trainer.state.fn == TrainerFn.FITTING: - _optimizers_to_device(self.optimizers, self.root_device) + return self.single_xla_strategy_impl.setup(trainer=trainer) @classmethod @override @@ -118,5 +92,4 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: @override def teardown(self) -> None: - super().teardown() - os.environ.pop("PT_XLA_DEBUG", None) + return self.single_xla_strategy_impl.teardown() diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py index cbdc890a1ca32..e57c627f374d7 100644 --- a/src/lightning/pytorch/strategies/xla.py +++ b/src/lightning/pytorch/strategies/xla.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import io
-import os
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
@@ -21,20 +19,16 @@
 from typing_extensions import override
 
 import lightning.pytorch as pl
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _XLA_GREATER_EQUAL_2_1
 from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO
 from lightning.fabric.plugins.environments import XLAEnvironment
 from lightning.fabric.strategies import _StrategyRegistry
-from lightning.fabric.utilities.optimizer import _optimizers_to_device
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
 from lightning.fabric.utilities.types import _PATH, ReduceOp
 from lightning.pytorch.plugins import XLAPrecision
 from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
 from lightning.pytorch.strategies.ddp import DDPStrategy
 from lightning.pytorch.strategies.launchers.xla import _XLALauncher
 from lightning.pytorch.strategies.strategy import TBroadcast
-from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters
-from lightning.pytorch.utilities.rank_zero import rank_zero_only
 
 if TYPE_CHECKING:
     from torch_xla.distributed.parallel_loader import MpDeviceLoader
@@ -56,8 +50,6 @@ def __init__(
         sync_module_states: bool = True,
         **_: Any,
     ) -> None:
-        if not _XLA_AVAILABLE:
-            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
@@ -66,9 +58,12 @@ def __init__(
             precision_plugin=precision_plugin,
             start_method="fork",
         )
-        self.debug = debug
-        self._launched = False
-        self._sync_module_states = sync_module_states
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.strategies.xla.ddp import XLAStrategyTrainer as EnterpriseXLAStrategy
+
+        self.xla_strategy_impl = EnterpriseXLAStrategy(
+            outer_object=self, debug=debug, sync_module_states=sync_module_states
+        )
 
     @property
     @override
@@ -105,31 +100,27 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
     @property
     @override
     def root_device(self) -> torch.device:
-        if not self._launched:
-            raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
-        import torch_xla.core.xla_model as xm
-
-        return xm.xla_device()
+        return self.xla_strategy_impl.root_device
 
     @property
     @override
     def global_rank(self) -> int:
-        return super().global_rank if self._launched else 0
+        return self.xla_strategy_impl.global_rank
 
     @property
     @override
     def local_rank(self) -> int:
-        return super().local_rank if self._launched else 0
+        return self.xla_strategy_impl.local_rank
 
     @property
     @override
     def node_rank(self) -> int:
-        return super().node_rank if self._launched else 0
+        return self.xla_strategy_impl.node_rank
 
     @property
     @override
     def world_size(self) -> int:
-        return super().world_size if self._launched else 1
+        return self.xla_strategy_impl.world_size
 
     @override
     def _configure_launcher(self) -> None:
@@ -137,113 +128,36 @@ def _configure_launcher(self) -> None:
 
     @override
     def setup(self, trainer: "pl.Trainer") -> None:
-        assert self.accelerator is not None
-        self.accelerator.setup(trainer)
-
-        if self.debug:
-            os.environ["PT_XLA_DEBUG"] = "1"
-
-        assert self.model is not None
-        self.precision_plugin.convert_module(self.model)
-
-        shared_params = find_shared_parameters(self.model)
-        self.model_to_device()
-        set_shared_parameters(self.model, shared_params)
-
-        self.model = self._setup_model(self.model)
-
-        if self._sync_module_states:
-            if _XLA_GREATER_EQUAL_2_1:
-                from torch_xla.core.xla_model import broadcast_master_param
-            else:
-                from torch_xla.experimental.pjrt import broadcast_master_param
-
-            broadcast_master_param(self.model)
-
-        if trainer.state.fn == TrainerFn.FITTING:
-            self.setup_optimizers(trainer)
-        self.setup_precision_plugin()
-        if trainer.state.fn == TrainerFn.FITTING:
-            _optimizers_to_device(self.optimizers, self.root_device)
+        return self.xla_strategy_impl.setup(trainer=trainer)
 
     @override
     def _setup_model(self, model: Module) -> Module:  # type: ignore
-        return model
+        return self.xla_strategy_impl._setup_model(model=model)
 
     @property
     @override
     def distributed_sampler_kwargs(self) -> dict[str, int]:
-        return {"num_replicas": self.world_size, "rank": self.global_rank}
+        return self.xla_strategy_impl.distributed_sampler_kwargs
 
     @override
     def process_dataloader(self, dataloader: object) -> "MpDeviceLoader":
-        from torch_xla.distributed.parallel_loader import MpDeviceLoader
-
-        if isinstance(dataloader, MpDeviceLoader):
-            # dataloader is already wrapped by MpDeviceLoader
-            return dataloader
-
-        dataloader = MpDeviceLoader(dataloader, self.root_device)
-        # Mimic interface to torch.utils.data.DataLoader
-        dataloader.dataset = dataloader._loader.dataset
-        dataloader.batch_sampler = getattr(dataloader._loader, "batch_sampler", None)
-        return dataloader
+        return self.xla_strategy_impl.process_dataloader(dataloader=dataloader)
 
     @override
     def configure_ddp(self) -> None:
-        pass
+        return self.xla_strategy_impl.configure_ddp()
 
     @override
     def model_to_device(self) -> None:
-        assert self.model is not None
-        self.model = self.model.to(self.root_device)
+        return self.xla_strategy_impl.model_to_device()
 
     @override
     def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
-        if not self._launched:
-            return
-
-        import torch_xla.core.xla_model as xm
-
-        if name is None:
-            # `None` is not supported: "TypeError: _xla_rendezvous(): incompatible function arguments"
-            name = ""
-        xm.rendezvous(name)
+        return self.xla_strategy_impl.barrier(name=name, *args, **kwargs)
 
     @override
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not self._launched:
-            return obj
-
-        import torch_xla.core.xla_model as xm
-
-        is_tensor = isinstance(obj, Tensor)
-        if is_tensor:
-            if obj.dim() == 0:
-                obj = obj.unsqueeze(0)
-            original_device = obj.device
-            # XLA distributed requires that the data is on the XLA device
-            obj = obj.to(self.root_device)
-        else:
-            # support for arbitrary pickle-ables
-            buffer = io.BytesIO()
-            torch.save(obj, buffer)
-            obj = torch.tensor(  # type: ignore[assignment]
-                bytearray(buffer.getbuffer()), device=self.root_device, dtype=torch.float
-            )
-
-        obj = [obj]
-        xm.collective_broadcast(obj, root_ordinal=src)
-        obj = obj[0]
-
-        if not is_tensor:
-            # this will preserve the dtype and device of any tensors
-            buffer = io.BytesIO(obj.cpu().byte().numpy())
-            obj = torch.load(buffer)
-        else:
-            obj = obj.to(original_device)
-
-        return obj
+        return self.xla_strategy_impl.broadcast(obj=obj, src=src)
 
     @override
     def reduce(
@@ -252,60 +166,27 @@ def reduce(
         group: Optional[Any] = None,
         reduce_op: Optional[Union[ReduceOp, str]] = "mean",
     ) -> Tensor:
-        if not isinstance(output, Tensor):
-            output = torch.tensor(output, device=self.root_device)
-
-        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
-        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
-        if invalid_reduce_op or invalid_reduce_op_str:
-            raise ValueError(
-                "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
-                f" {reduce_op}"
-            )
-
-        import torch_xla.core.xla_model as xm
-
-        output = xm.mesh_reduce("reduce", output, sum)
-
-        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
-            output = output / self.world_size
-
-        return output
+        return self.xla_strategy_impl.reduce(output=output, group=group, reduce_op=reduce_op)
 
     @override
     def setup_environment(self) -> None:
-        self._launched = True
-        super().setup_environment()
+        return self.xla_strategy_impl.setup_environment()
 
     @override
     def setup_distributed(self) -> None:
-        assert self.parallel_devices is not None
-        if len(self.parallel_devices) == 1:
-            # spawning only 1 device with PjRT is not supported:
-            # https://github.com/Lightning-AI/pytorch-lightning/pull/17408#discussion_r1170671732
-            raise NotImplementedError(
-                "The `XLAStrategy` does not support running on a single device with the PjRT runtime."
-                " Try using all devices or the `SingleDeviceXLAStrategy` strategy"
-            )
-        rank_zero_only.rank = self.global_rank
+        return self.xla_strategy_impl.setup_distributed()
 
     @override
     def set_world_ranks(self) -> None:
-        # accessing global_rank will initialize the XLA computation client. since this is called outside of the spawned
-        # processes (by the accelerator connector), we cannot run the code that would normally be here.
-        # instead it's done in `setup_distributed`
-        pass
+        return self.xla_strategy_impl.set_world_ranks()
 
     @override
     def save_checkpoint(
         self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
     ) -> None:
-        import torch_xla.core.xla_model as xm
-
-        # sync any pending lazy tensors on all ranks before saving to prevent potential collective hangs
-        xm.mark_step()
-        # save on global rank zero only
-        super().save_checkpoint(checkpoint, filepath, storage_options=storage_options)
+        return self.xla_strategy_impl.save_checkpoint(
+            checkpoint=checkpoint, filepath=filepath, storage_options=storage_options
+        )
 
     @override
     def remove_checkpoint(self, filepath: _PATH) -> None:
@@ -315,8 +196,7 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
             filepath: Path to checkpoint
 
         """
-        if self.local_rank == 0:
-            self.checkpoint_io.remove_checkpoint(filepath)
+        return self.xla_strategy_impl.remove_checkpoint(filepath=filepath)
 
     @override
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -330,29 +210,11 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
             A tensor of shape (world_size, ...)
 
         """
-        if not self._launched:
-            return tensor
-        if not isinstance(tensor, Tensor):
-            raise NotImplementedError(
-                f"`{type(self).__name__}.all_gather` is only implemented for tensors. Given {tensor}"
-            )
-        if tensor.dim() == 0:
-            tensor = tensor.unsqueeze(0)
-        original_device = tensor.device
-        tensor = tensor.to(self.root_device)
-
-        import torch_xla.core.functions as xf
-        import torch_xla.core.xla_model as xm
-
-        tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-        tensor = tensor.to(original_device)
-        return tensor
+        return self.xla_strategy_impl.all_gather(tensor=tensor, group=group, sync_grads=sync_grads)
 
     @override
     def teardown(self) -> None:
-        super().teardown()
-        self._launched = False  # after the Trainer finishes, we aren't inside the spawned region
-        os.environ.pop("PT_XLA_DEBUG", None)
+        return self.xla_strategy_impl.teardown()
 
     @classmethod
     @override
diff --git a/src/lightning/pytorch/utilities/deepspeed.py b/src/lightning/pytorch/utilities/deepspeed.py
index 20b418437c681..21a9190552f43 100644
--- a/src/lightning/pytorch/utilities/deepspeed.py
+++ b/src/lightning/pytorch/utilities/deepspeed.py
@@ -20,8 +20,8 @@
 
 import torch
 
+from lightning.fabric.utilities.imports import _DEEPSPEED_AVAILABLE
 from lightning.fabric.utilities.types import _PATH
-from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 
 CPU_DEVICE = torch.device("cpu")
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index 9d4a0b9462f2e..1f4dc37c1a6b7 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -19,9 +19,13 @@
 from unittest.mock import Mock
 
 import pytest
+import pytorch_lightning_enterprise.utils.imports
 import torch.distributed
 
 import lightning.fabric
+import lightning.fabric.plugins.environments.xla
+import lightning.fabric.plugins.io.xla
+import lightning.fabric.plugins.precision.xla
 from lightning.fabric.accelerators import XLAAccelerator
 from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver
 from lightning.fabric.utilities.distributed import _destroy_dist_connection
@@ -144,17 +148,27 @@ def reset_cudnn_benchmark():
 
 
 def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
-    monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.plugins.precision.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.strategies.single_xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.strategies.xla_fsdp, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)
+    # First, mock torch_xla modules in sys.modules so imports succeed
     monkeypatch.setitem(sys.modules, "torch_xla", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock())
+    monkeypatch.setitem(sys.modules, "torch_xla._internal", Mock())
+    monkeypatch.setitem(sys.modules, "torch_xla._internal.tpu", Mock())
+
+    # Then patch the _XLA_AVAILABLE flags in various modules
+    monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_AVAILABLE", value)
+    monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_1", value)
+    monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_5", value)
+    monkeypatch.setattr(lightning.fabric.accelerators.xla,
"_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", value) + monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_5", value) + # Patch in the modules where they're used after import + monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", value) + monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_1", value) + monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_5", value) + monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_AVAILABLE", value) + monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", value) @pytest.fixture @@ -166,6 +180,14 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N mock_xla_available(monkeypatch, value) monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: value) monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8) + # Also mock the enterprise XLAAccelerator methods + import pytorch_lightning_enterprise.accelerators.xla + + monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "is_available", lambda: value) + monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8) + monkeypatch.setitem(sys.modules, "torch_xla", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock()) @pytest.fixture diff --git a/tests/tests_fabric/graveyard/test_tpu.py b/tests/tests_fabric/graveyard/test_tpu.py index 5b72d60491df7..7aae8ab5d2297 100644 --- a/tests/tests_fabric/graveyard/test_tpu.py +++ b/tests/tests_fabric/graveyard/test_tpu.py @@ -15,7 +15,7 @@ def test_graveyard_single_tpu(import_path, name): module = import_module(import_path) cls = getattr(module, name) device = torch.device("cpu") - with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"): + with pytest.deprecated_call(match="is deprecated"): cls(device) @@ -37,5 +37,5 @@ def test_graveyard_single_tpu(import_path, name): def test_graveyard_no_device(import_path, name): module = import_module(import_path) cls = getattr(module, name) - with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"): + with pytest.deprecated_call(match="is deprecated"): cls() diff --git a/tests/tests_fabric/plugins/environments/test_kubeflow.py b/tests/tests_fabric/plugins/environments/test_kubeflow.py index 3436adc9ce2aa..72fb527d3facb 100644 --- a/tests/tests_fabric/plugins/environments/test_kubeflow.py +++ b/tests/tests_fabric/plugins/environments/test_kubeflow.py @@ -61,14 +61,14 @@ def test_attributes_from_environment_variables(caplog): assert env.local_rank() == 0 assert env.node_rank() == 1 # setter should be no-op - with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): + with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.kubeflow"): env.set_global_rank(100) assert env.global_rank() == 1 assert "setting global rank is not allowed" in caplog.text caplog.clear() - with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): + with caplog.at_level(logging.DEBUG, 
logger="pytorch_lightning_enterprise.plugins.environments.kubeflow"): env.set_world_size(100) assert env.world_size() == 20 assert "setting world size is not allowed" in caplog.text diff --git a/tests/tests_fabric/plugins/environments/test_slurm.py b/tests/tests_fabric/plugins/environments/test_slurm.py index b907c287faa5f..742fc44855083 100644 --- a/tests/tests_fabric/plugins/environments/test_slurm.py +++ b/tests/tests_fabric/plugins/environments/test_slurm.py @@ -21,7 +21,6 @@ from lightning_utilities.test.warning import no_warning_call from lightning.fabric.plugins.environments import SLURMEnvironment -from lightning.fabric.utilities.warnings import PossibleUserWarning from tests_fabric.helpers.runif import RunIf @@ -72,14 +71,14 @@ def test_attributes_from_environment_variables(caplog): assert env.node_rank() == 3 assert env.job_name() == "JOB" # setter should be no-op - with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): + with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.slurm"): env.set_global_rank(100) assert env.global_rank() == 1 assert "setting global rank is not allowed" in caplog.text caplog.clear() - with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): + with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.slurm"): env.set_world_size(100) assert env.world_size() == 20 assert "setting world size is not allowed" in caplog.text @@ -135,18 +134,18 @@ def test_detect(): def test_srun_available_and_not_used(monkeypatch): """Test that a warning is emitted if Lightning suspects the user forgot to run their script with `srun`.""" monkeypatch.setattr(sys, "argv", ["train.py", "--lr", "0.01"]) - expected = "`srun` .* available .* but is not used. HINT: .* srun python train.py --lr 0.01" + expected = r"`srun` .* available .* but is not used. 
HINT: .* srun python\d* train.py --lr 0.01" # pretend `srun` is available with mock.patch("lightning.fabric.plugins.environments.slurm.shutil.which", return_value="/usr/bin/srun"): - with pytest.warns(PossibleUserWarning, match=expected): + with pytest.warns(UserWarning, match=expected): SLURMEnvironment() - with pytest.warns(PossibleUserWarning, match=expected): + with pytest.warns(UserWarning, match=expected): SLURMEnvironment.detect() # no warning if `srun` is unavailable - with no_warning_call(PossibleUserWarning, match=expected): + with no_warning_call(UserWarning, match=expected): SLURMEnvironment() assert not SLURMEnvironment.detect() @@ -176,7 +175,7 @@ def test_validate_user_settings(): # in interactive mode, validation is skipped because processes get launched by Fabric/Trainer, not SLURM with mock.patch( - "lightning.fabric.plugins.environments.slurm.SLURMEnvironment.job_name", return_value="interactive" + "pytorch_lightning_enterprise.plugins.environments.slurm.SLURMEnvironment.job_name", return_value="interactive" ): env = SLURMEnvironment() env.validate_settings(num_devices=4, num_nodes=1) # no error diff --git a/tests/tests_fabric/plugins/environments/test_torchelastic.py b/tests/tests_fabric/plugins/environments/test_torchelastic.py index 161d42894df30..cfd196c231036 100644 --- a/tests/tests_fabric/plugins/environments/test_torchelastic.py +++ b/tests/tests_fabric/plugins/environments/test_torchelastic.py @@ -58,14 +58,14 @@ def test_attributes_from_environment_variables(caplog): assert env.local_rank() == 2 assert env.node_rank() == 3 # setter should be no-op - with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): + with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.torchelastic"): env.set_global_rank(100) assert env.global_rank() == 1 assert "setting global rank is not allowed" in caplog.text caplog.clear() - with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): + with caplog.at_level(logging.DEBUG, logger="pytorch_lightning_enterprise.plugins.environments.torchelastic"): env.set_world_size(100) assert env.world_size() == 20 assert "setting world size is not allowed" in caplog.text diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py index 76fd18012ee94..0557f5620e8e1 100644 --- a/tests/tests_fabric/plugins/environments/test_xla.py +++ b/tests/tests_fabric/plugins/environments/test_xla.py @@ -100,8 +100,9 @@ def test_detect(monkeypatch): @mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True) +@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True) @mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) -@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) def test_world_size_from_xla_runtime_greater_2_1(xla_available): """Test that world_size uses torch_xla.runtime when XLA >= 2.1.""" env = XLAEnvironment() @@ -113,8 +114,9 @@ def test_world_size_from_xla_runtime_greater_2_1(xla_available): @mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True) +@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True) @mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) 
-@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) def test_global_rank_from_xla_runtime_greater_2_1(xla_available): """Test that global_rank uses torch_xla.runtime when XLA >= 2.1.""" env = XLAEnvironment() @@ -126,8 +128,9 @@ def test_global_rank_from_xla_runtime_greater_2_1(xla_available): @mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True) +@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True) @mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) -@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) def test_local_rank_from_xla_runtime_greater_2_1(xla_available): """Test that local_rank uses torch_xla.runtime when XLA >= 2.1.""" env = XLAEnvironment() @@ -139,8 +142,9 @@ def test_local_rank_from_xla_runtime_greater_2_1(xla_available): @mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_AVAILABLE", True) +@mock.patch("pytorch_lightning_enterprise.utils.imports._XLA_GREATER_EQUAL_2_1", True) @mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) -@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) def test_setters_readonly_when_xla_runtime_greater_2_1(xla_available): """Test that set_world_size and set_global_rank don't affect values when using XLA runtime >= 2.1.""" env = XLAEnvironment() diff --git a/tests/tests_fabric/plugins/precision/test_transformer_engine.py b/tests/tests_fabric/plugins/precision/test_transformer_engine.py index ed7c984b1ae64..22908e107131b 100644 --- a/tests/tests_fabric/plugins/precision/test_transformer_engine.py +++ b/tests/tests_fabric/plugins/precision/test_transformer_engine.py @@ -15,19 +15,22 @@ from unittest.mock import Mock import pytest +import pytorch_lightning_enterprise.utils.imports import torch import torch.distributed -import lightning.fabric from lightning.fabric.connector import _Connector from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision def test_transformer_engine_plugin(monkeypatch): - module = lightning.fabric.plugins.precision.transformer_engine + module = pytorch_lightning_enterprise.utils.imports if module._TRANSFORMER_ENGINE_AVAILABLE: pytest.skip("Assumes transformer_engine is unavailable") - monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True) + monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", True) + monkeypatch.setattr( + "pytorch_lightning_enterprise.plugins.precision.transformer_engine._TRANSFORMER_ENGINE_AVAILABLE", True + ) transformer_engine_mock = Mock() monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock) monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock()) @@ -118,8 +121,11 @@ class TELayerNormMock(Mock): ... 
def test_convert_module_handles_linear_without_bias(monkeypatch): - module = lightning.fabric.plugins.precision.transformer_engine # Set up mock transformer_engine - monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True) + module = pytorch_lightning_enterprise.utils.imports # Set up mock transformer_engine + monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", True) + monkeypatch.setattr( + "pytorch_lightning_enterprise.plugins.precision.transformer_engine._TRANSFORMER_ENGINE_AVAILABLE", True + ) transformer_engine_mock = Mock() monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock) diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py index 0194c7b87820a..022e1320068ff 100644 --- a/tests/tests_fabric/strategies/test_deepspeed.py +++ b/tests/tests_fabric/strategies/test_deepspeed.py @@ -75,7 +75,7 @@ def test_deepspeed_defaults(): strategy = DeepSpeedStrategy() assert strategy.config is not None assert isinstance(strategy.config["zero_optimization"], dict) - assert strategy._backward_sync_control is None + assert strategy.deepspeed_impl._backward_sync_control is None @RunIf(deepspeed=True) @@ -230,7 +230,7 @@ def test_deepspeed_save_checkpoint_exclude_frozen_parameters(exclude_frozen_para from deepspeed import DeepSpeedEngine strategy = DeepSpeedStrategy(exclude_frozen_parameters=exclude_frozen_parameters) - assert strategy.exclude_frozen_parameters is exclude_frozen_parameters + assert strategy.deepspeed_impl.exclude_frozen_parameters is exclude_frozen_parameters model = Mock(spec=DeepSpeedEngine, optimizer=None) model.modules.return_value = [model] @@ -275,7 +275,7 @@ def test_deepspeed_load_checkpoint_no_state(tmp_path): @RunIf(deepspeed=True) -@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True) +@mock.patch("pytorch_lightning_enterprise.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True) def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(_, tmp_path): """Test that the DeepSpeed strategy can only load one DeepSpeedEngine per checkpoint.""" from deepspeed import DeepSpeedEngine @@ -317,7 +317,7 @@ def test_deepspeed_load_checkpoint_client_state_missing(tmp_path): @RunIf(deepspeed=True) -@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True) +@mock.patch("pytorch_lightning_enterprise.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True) def test_deepspeed_load_checkpoint_state_updated_with_client_state(_, tmp_path): """Test that the DeepSpeed strategy properly updates the state variables and returns additional metadata.""" from deepspeed import DeepSpeedEngine @@ -342,7 +342,7 @@ def test_deepspeed_load_checkpoint_state_updated_with_client_state(_, tmp_path): @RunIf(deepspeed=True) @pytest.mark.parametrize("optimzer_state_requested", [True, False]) -@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True) +@mock.patch("pytorch_lightning_enterprise.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True) def test_deepspeed_load_checkpoint_optimzer_state_requested(_, optimzer_state_requested, tmp_path): """Test that the DeepSpeed strategy loads the optimizer state only when requested.""" from deepspeed import DeepSpeedEngine @@ -427,6 +427,7 @@ def test_validate_parallel_devices_indices(device_indices): """ accelerator = Mock(spec=CUDAAccelerator) + accelerator.name.return_value = "cuda" strategy = 
DeepSpeedStrategy( accelerator=accelerator, parallel_devices=[torch.device("cuda", i) for i in device_indices] ) diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 2a04c6b27dded..72adf161fb5ce 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -251,7 +251,7 @@ def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform): strategy = fabric._strategy assert isinstance(strategy, DeepSpeedStrategy) with mock.patch("platform.system", return_value=platform) as platform_mock: - strategy._init_deepspeed_distributed() + strategy.deepspeed_impl._init_deepspeed_distributed() deepspeed_dist_mock.assert_called() platform_mock.assert_called() if platform == "Windows": diff --git a/tests/tests_fabric/strategies/test_xla.py b/tests/tests_fabric/strategies/test_xla.py index 9ca21d8d8b894..1884208dc2126 100644 --- a/tests/tests_fabric/strategies/test_xla.py +++ b/tests/tests_fabric/strategies/test_xla.py @@ -194,14 +194,14 @@ def test_rank_properties_access(xla_available): strategy.cluster_environment = Mock() # we're in the main process, no processes have been launched yet - assert not strategy._launched + assert not strategy.xla_strategy_impl._launched assert strategy.global_rank == 0 assert strategy.local_rank == 0 assert strategy.node_rank == 0 assert strategy.world_size == 1 # simulate we're in a worker process - strategy._launched = True + strategy.xla_strategy_impl._launched = True assert strategy.global_rank == strategy.cluster_environment.global_rank() assert strategy.local_rank == strategy.cluster_environment.local_rank() assert strategy.node_rank == strategy.cluster_environment.node_rank() diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py index c2634283ad110..fa1021402590e 100644 --- a/tests/tests_fabric/strategies/test_xla_fsdp.py +++ b/tests/tests_fabric/strategies/test_xla_fsdp.py @@ -23,7 +23,6 @@ from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.plugins import XLAPrecision from lightning.fabric.strategies import XLAFSDPStrategy -from lightning.fabric.strategies.xla_fsdp import _activation_checkpointing_auto_wrapper, _XLAFSDPBackwardSyncControl from tests_fabric.helpers.runif import RunIf @@ -43,6 +42,7 @@ def test_xla_fsdp_setup_optimizer_validation(): def test_xla_fsdp_no_backward_sync(): """Test that the backward sync control calls `.no_sync()`, and only on a module wrapped in XlaFullyShardedDataParallel.""" + from pytorch_lightning_enterprise.strategies.xla.fsdp import _XLAFSDPBackwardSyncControl from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel strategy = XLAFSDPStrategy() @@ -75,41 +75,45 @@ def test_xla_fsdp_grad_clipping_value_error(): @mock.patch.dict(os.environ, os.environ.copy(), clear=True) +@mock.patch("pytorch_lightning_enterprise.strategies.xla.fsdp._XLA_AVAILABLE", True) def test_rank_properties_access(xla_available): """Test that the strategy returns the expected values depending on whether we're in the main process or not.""" strategy = XLAFSDPStrategy() strategy.cluster_environment = Mock() # we're in the main process, no processes have been launched yet - assert not strategy._launched + assert not strategy.xla_fsdp_impl._launched assert strategy.global_rank == 0 assert strategy.local_rank == 0 assert strategy.node_rank == 0 assert strategy.world_size == 1 # simulate we're 
in a worker process - strategy._launched = True + strategy.xla_fsdp_impl._launched = True assert strategy.global_rank == strategy.cluster_environment.global_rank() assert strategy.local_rank == strategy.cluster_environment.local_rank() assert strategy.node_rank == strategy.cluster_environment.node_rank() assert strategy.world_size == strategy.cluster_environment.world_size() +@mock.patch("pytorch_lightning_enterprise.strategies.xla.fsdp._XLA_AVAILABLE", True) def test_xla_fsdp_policy(xla_available): + from pytorch_lightning_enterprise.strategies.xla import fsdp as fsdp_module + strategy = XLAFSDPStrategy(foo=1) - assert strategy._fsdp_kwargs == {"foo": 1} + assert strategy.xla_fsdp_impl._fsdp_kwargs == {"foo": 1} strategy = XLAFSDPStrategy(auto_wrap_policy={torch.nn.Linear}) - kwargs = strategy._parse_fsdp_kwargs() + kwargs = strategy.xla_fsdp_impl._parse_fsdp_kwargs() assert set(kwargs) == {"auto_wrap_policy", "compute_dtype"} assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy" assert kwargs["compute_dtype"] is torch.float32 strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}) - _ = strategy._parse_fsdp_kwargs() - kwargs = strategy._parse_fsdp_kwargs() # ensure it's idempotent + _ = strategy.xla_fsdp_impl._parse_fsdp_kwargs() + kwargs = strategy.xla_fsdp_impl._parse_fsdp_kwargs() # ensure it's idempotent assert set(kwargs) == {"auto_wrapper_callable", "compute_dtype"} - assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper + assert kwargs["auto_wrapper_callable"].func is fsdp_module._activation_checkpointing_auto_wrapper assert kwargs["compute_dtype"] is torch.float32 strategy = XLAFSDPStrategy( @@ -118,17 +122,17 @@ def test_xla_fsdp_policy(xla_available): activation_checkpointing_policy={torch.nn.Linear}, precision=XLAPrecision("bf16-true"), ) - kwargs = strategy._parse_fsdp_kwargs() + kwargs = strategy.xla_fsdp_impl._parse_fsdp_kwargs() assert set(kwargs) == {"auto_wrap_policy", "auto_wrapper_callable", "compute_dtype"} assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy" - assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper + assert kwargs["auto_wrapper_callable"].func is fsdp_module._activation_checkpointing_auto_wrapper assert kwargs["compute_dtype"] is torch.bfloat16 strategy.teardown() strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}, auto_wrapper_callable="foo") with pytest.raises(ValueError, match="cannot set both"): - strategy._parse_fsdp_kwargs() + strategy.xla_fsdp_impl._parse_fsdp_kwargs() strategy = XLAFSDPStrategy(activation_checkpointing_policy="foo") with pytest.raises(TypeError, match="must be a set"): - strategy._parse_fsdp_kwargs() + strategy.xla_fsdp_impl._parse_fsdp_kwargs() diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 1074789e71055..faf22bf3a06b5 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -20,6 +20,7 @@ from unittest.mock import Mock import pytest +import pytorch_lightning_enterprise.utils.imports import torch import torch.distributed from lightning_utilities.test.warning import no_warning_call @@ -242,8 +243,10 @@ class TestStrategy(DDPStrategy): ), ], ) -@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._read_hosts", return_value=["node0", "node1"]) -@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0) 
+@mock.patch( + "pytorch_lightning_enterprise.plugins.environments.lsf.LSFEnvironment._read_hosts", return_value=["node0", "node1"] +) +@mock.patch("pytorch_lightning_enterprise.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0) def test_fallback_from_ddp_spawn_to_ddp_on_cluster(_, __, env_vars, expected_environment): with mock.patch.dict(os.environ, env_vars, clear=True): connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2) @@ -341,8 +344,8 @@ def test_cuda_accelerator_can_not_run_on_system(_): @pytest.mark.skipif(XLAAccelerator.is_available(), reason="test requires missing TPU") -@mock.patch("lightning.fabric.accelerators.xla._XLA_AVAILABLE", True) -@mock.patch("lightning.fabric.accelerators.xla._using_pjrt", return_value=True) +@mock.patch("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", True) +@mock.patch("pytorch_lightning_enterprise.accelerators.xla._using_pjrt", return_value=True) def test_tpu_accelerator_can_not_run_on_system(_): with pytest.raises(RuntimeError, match="XLAAccelerator` can not run on your system"): _Connector(accelerator="tpu", devices=8) @@ -888,7 +891,7 @@ def test_precision_selection_model_parallel(_, precision, raises): def test_bitsandbytes_precision_cuda_required(monkeypatch): - monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True) + monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_BITSANDBYTES_AVAILABLE", True) monkeypatch.setitem(sys.modules, "bitsandbytes", Mock()) with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"): _Connector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8")) diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index a175fa97fd444..785ec0a3ed2e9 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -45,7 +45,13 @@ def test_get_available_flops(xla_available): ): assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None - from torch_xla.experimental import tpu + # Import from the right module based on _XLA_GREATER_EQUAL_2_1 + from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1 + + if _XLA_GREATER_EQUAL_2_1: + from torch_xla._internal import tpu + else: + from torch_xla.experimental import tpu assert isinstance(tpu, Mock) diff --git a/tests/tests_pytorch/accelerators/test_xla.py b/tests/tests_pytorch/accelerators/test_xla.py index 5e56d5c585c88..52434ee3a19d5 100644 --- a/tests/tests_pytorch/accelerators/test_xla.py +++ b/tests/tests_pytorch/accelerators/test_xla.py @@ -22,7 +22,6 @@ from torch import nn from torch.utils.data import DataLoader -import lightning.fabric from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CPUAccelerator, XLAAccelerator @@ -312,7 +311,7 @@ def test_warning_if_tpus_not_used(tpu_available): ) @RunIf(min_python="3.9") # mocking issue def test_trainer_config_device_ids(devices, expected_device_ids, tpu_available, monkeypatch): - monkeypatch.setattr(lightning.fabric.accelerators.xla, "_using_pjrt", lambda: True) + monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._using_pjrt", lambda: True) mock = DeviceMock() monkeypatch.setattr(torch, "device", mock) diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index 878298c6bfd94..ddee9998f9f4a 100644 --- 
a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -22,6 +22,7 @@ from unittest.mock import Mock import pytest +import pytorch_lightning_enterprise.utils.imports import torch.distributed from tqdm import TMonitor @@ -220,14 +221,41 @@ def mps_count_1(monkeypatch): def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None: - monkeypatch.setattr(lightning.pytorch.strategies.xla, "_XLA_AVAILABLE", value) - monkeypatch.setattr(lightning.pytorch.strategies.single_xla, "_XLA_AVAILABLE", value) - monkeypatch.setattr(lightning.pytorch.plugins.precision.xla, "_XLA_AVAILABLE", value) - monkeypatch.setattr(lightning.pytorch.strategies.launchers.xla, "_XLA_AVAILABLE", value) + # First, mock torch_xla modules in sys.modules so imports succeed + monkeypatch.setitem(sys.modules, "torch_xla", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.core", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.core.functions", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.core.xla_env_vars", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.experimental.pjrt", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.experimental.tpu", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.distributed", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.distributed.parallel_loader", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.distributed.xla_multiprocessing", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.runtime", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.utils", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.utils.utils", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.debug", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla.debug.profiler", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla._internal", Mock()) + monkeypatch.setitem(sys.modules, "torch_xla._internal.tpu", Mock()) + + # Then patch the _XLA_AVAILABLE flags in various modules + monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_AVAILABLE", value) + monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_1", value) + monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_5", value) monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value) - monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value) - monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value) - monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value) + monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", value) + monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_5", value) + # Patch in the modules where they're used after import + monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", value) + monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_1", value) + monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_5", value) + monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_AVAILABLE", value) + 
monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", value) @pytest.fixture @@ -237,10 +265,14 @@ def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None: mock_xla_available(monkeypatch, value) - monkeypatch.setattr(lightning.pytorch.accelerators.xla.XLAAccelerator, "is_available", lambda: value) monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: value) - monkeypatch.setattr(lightning.pytorch.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8) monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8) + # Also mock the enterprise XLAAccelerator methods + import pytorch_lightning_enterprise.accelerators.xla + + monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "is_available", lambda: value) + monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8) + monkeypatch.setitem(sys.modules, "torch_xla", Mock()) monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock()) monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock()) diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py index 0b79b638534fa..a24d53af58fa7 100644 --- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py +++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import pytest +import pytorch_lightning_enterprise.utils.imports import torch.nn import lightning.fabric @@ -62,6 +63,7 @@ def test_fsdp_precision_plugin(): def test_bitsandbytes_precision_plugin(monkeypatch): monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True) + monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_BITSANDBYTES_AVAILABLE", True) bitsandbytes_mock = Mock() monkeypatch.setitem(sys.modules, "bitsandbytes", bitsandbytes_mock) @@ -107,7 +109,8 @@ def test_precision_plugin(): def test_transformer_engine_precision_plugin(monkeypatch): - monkeypatch.setattr(lightning.fabric.plugins.precision.transformer_engine, "_TRANSFORMER_ENGINE_AVAILABLE", True) + monkeypatch.setattr(lightning.fabric.utilities.imports, "_TRANSFORMER_ENGINE_AVAILABLE", True) + monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_TRANSFORMER_ENGINE_AVAILABLE", True) transformer_engine_mock = Mock() monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock) monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock()) diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index 98819eb080eb8..f5a0c41fc97a4 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -38,8 +38,8 @@ def mlflow_mock(monkeypatch): mlflow.tracking = mlflow_tracking mlflow.entities = mlflow_entities - monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", True) - monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_SYNCHRONOUS_AVAILABLE", True) + monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._MLFLOW_AVAILABLE", True) + monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._MLFLOW_SYNCHRONOUS_AVAILABLE", True) return mlflow @@ -86,7 +86,7 @@ class RunType: # to make isinstance checks 
pass wandb.sdk.lib = wandb_sdk_lib wandb.wandb_run = wandb_wandb_run - monkeypatch.setattr("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True) + monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._WANDB_AVAILABLE", True) return wandb @@ -109,7 +109,7 @@ def comet_mock(monkeypatch): comet.start = Mock(name="comet_ml.start", return_value=comet.Experiment()) comet.config = Mock() - monkeypatch.setattr("lightning.pytorch.loggers.comet._COMET_AVAILABLE", True) + monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._COMET_AVAILABLE", True) return comet @@ -153,5 +153,5 @@ def __setitem__(self, key, value): neptune.types = neptune_types neptune.utils = neptune_utils - monkeypatch.setattr("lightning.pytorch.loggers.neptune._NEPTUNE_AVAILABLE", True) + monkeypatch.setattr("pytorch_lightning_enterprise.utils.imports._NEPTUNE_AVAILABLE", True) return neptune diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index a9cf9af79185b..22b7632d5bea1 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -70,7 +70,7 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs): @mock.patch.dict(os.environ, {}) -@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) +@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock()) @pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES) def test_loggers_fit_test_all(logger_class, mlflow_mock, wandb_mock, comet_mock, neptune_mock, tmp_path, monkeypatch): """Verify that basic functionality of all loggers.""" @@ -335,7 +335,7 @@ def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, neptune_moc logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0}) -@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) +@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock()) def test_logger_default_name(mlflow_mock, monkeypatch, tmp_path): """Test that the default logger name is lightning_logs.""" # CSV diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index dae8f617b873e..4108d3c3b641c 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -89,10 +89,10 @@ def test_comet_experiment_is_still_alive_after_training_complete(comet_mock): logger.finalize("ended") # Assert that data was saved to comet.com - logger._experiment.flush.assert_called_once() + logger.logger_impl._experiment.flush.assert_called_once() # Assert that was not ended - logger._experiment.end.assert_not_called() + logger.logger_impl._experiment.end.assert_not_called() @mock.patch.dict(os.environ, {}) @@ -115,8 +115,8 @@ def test_comet_logger_experiment_name(comet_mock): experiment_config=comet_mock.ExperimentConfig(), ) # check that we saved "experiment name" in kwargs as new "name" arg - assert logger._kwargs["name"] == experiment_name - assert "experiment_name" not in logger._kwargs + assert logger.logger_impl._kwargs["name"] == experiment_name + assert "experiment_name" not in logger.logger_impl._kwargs # check that "experiment name" was passed to experiment config correctly assert call(experiment_name=experiment_name) not in comet_mock.ExperimentConfig.call_args_list @@ -130,10 +130,10 @@ def test_comet_version(comet_mock): experiment_name = "My Name" logger = CometLogger(api_key=api_key, name=experiment_name) - assert logger._experiment is not 
None + assert logger.logger_impl._experiment is not None _ = logger.version - logger._experiment.get_key.assert_called() + logger.logger_impl._experiment.get_key.assert_called() @mock.patch.dict(os.environ, {}) @@ -146,7 +146,7 @@ def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch): {"test": 1}, epoch=1, step=123, - prefix=logger._prefix, + prefix=logger.logger_impl._prefix, framework="pytorch-lightning", ) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index c7f9dbe1fe2c6..c45f5ee5b8295 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -16,13 +16,12 @@ from unittest.mock import MagicMock, Mock import pytest +from pytorch_lightning_enterprise.utils.imports import _MLFLOW_AVAILABLE from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers.mlflow import ( - _MLFLOW_AVAILABLE, MLFlowLogger, - _get_resolve_tags, ) @@ -30,13 +29,13 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r """Helper function to simulate mlflow client creating a new (or existing) experiment.""" run = MagicMock() run.info.run_id = run_id - logger._mlflow_client.get_experiment_by_name = MagicMock(return_value=experiment_name) - logger._mlflow_client.create_experiment = MagicMock(return_value=experiment_id) - logger._mlflow_client.create_run = MagicMock(return_value=run) + logger.logger_impl._mlflow_client.get_experiment_by_name = MagicMock(return_value=experiment_name) + logger.logger_impl._mlflow_client.create_experiment = MagicMock(return_value=experiment_id) + logger.logger_impl._mlflow_client.create_run = MagicMock(return_value=run) return logger -@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) +@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock()) def test_mlflow_logger_exists(mlflow_mock, tmp_path): """Test launching three independent loggers with either same or different experiment name.""" client = mlflow_mock.tracking.MlflowClient @@ -57,8 +56,8 @@ def test_mlflow_logger_exists(mlflow_mock, tmp_path): client.return_value.create_run = MagicMock(return_value=run1) logger = MLFlowLogger("test", save_dir=str(tmp_path)) - assert logger._experiment_id is None - assert logger._run_id is None + assert logger.logger_impl._experiment_id is None + assert logger.logger_impl._run_id is None _ = logger.experiment assert logger.experiment_id == "exp-id-1" assert logger.run_id == "run-id-1" @@ -96,13 +95,14 @@ def test_mlflow_run_name_setting(tmp_path): pytest.skip("test for explicit file creation requires mlflow dependency to be installed.") from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME + from pytorch_lightning_enterprise.loggers.mlflow import _get_resolve_tags resolve_tags = _get_resolve_tags() tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"}) # run_name is appended to tags logger = MLFlowLogger("test", run_name="run-name-1", save_dir=str(tmp_path)) - logger._mlflow_client = client = Mock() + logger.logger_impl._mlflow_client = client = Mock() logger = mock_mlflow_run_creation(logger, experiment_id="exp-id") _ = logger.experiment @@ -122,7 +122,7 @@ def test_mlflow_run_name_setting(tmp_path): client.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags) -@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) +@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock()) def 
test_mlflow_run_id_setting(mlflow_mock, tmp_path):
     """Test that the run_id argument uses the provided run_id."""
     client = mlflow_mock.tracking.MlflowClient
@@ -143,7 +143,7 @@ def test_mlflow_run_id_setting(mlflow_mock, tmp_path):
     client.reset_mock(return_value=True)


-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
 def test_mlflow_log_dir(mlflow_mock, tmp_path):
     """Test that the trainer saves checkpoints in the logger's save dir."""
     client = mlflow_mock.tracking.MlflowClient
@@ -211,8 +211,8 @@ def on_train_epoch_end(self, *args, **kwargs):
     assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"]


-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
-@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.utils.imports._MLFLOW_AVAILABLE", return_value=True)
 def test_mlflow_experiment_id_retrieved_once(_, mlflow_mock, tmp_path):
     """Test that the logger experiment_id retrieved only once."""
     logger = MLFlowLogger("test", save_dir=str(tmp_path))
@@ -222,7 +222,7 @@ def test_mlflow_experiment_id_retrieved_once(_, mlflow_mock, tmp_path):
     assert logger.experiment.get_experiment_by_name.call_count == 1


-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
 def test_mlflow_logger_with_unexpected_characters(mlflow_mock, tmp_path):
     """Test that the logger raises warning with special characters not accepted by MLFlow."""
     logger = MLFlowLogger("test", save_dir=str(tmp_path))
@@ -232,7 +232,7 @@ def test_mlflow_logger_with_unexpected_characters(mlflow_mock, tmp_path):
         logger.log_metrics(metrics)


-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
 def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
     """Test that the logger calls methods on the mlflow experiment correctly."""
     time = mlflow_mock.entities.time
@@ -242,7 +242,7 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
     time.return_value = 1

     logger = MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location")
-    logger._mlflow_client.get_experiment_by_name.return_value = None
+    logger.logger_impl._mlflow_client.get_experiment_by_name.return_value = None

     params = {"test": "test_param"}
     logger.log_hyperparams(params)
@@ -260,13 +260,13 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
     )
     metric.assert_called_with(key="some_metric", value=10, timestamp=1000, step=0)

-    logger._mlflow_client.create_experiment.assert_called_once_with(
+    logger.logger_impl._mlflow_client.create_experiment.assert_called_once_with(
         name="test", artifact_location="my_artifact_location"
     )


 @pytest.mark.parametrize("synchronous", [False, True])
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
 def test_mlflow_logger_experiment_calls_with_synchronous(mlflow_mock, tmp_path, synchronous):
     """Test that the logger calls methods on the mlflow experiment with the specified synchronous flag."""

@@ -302,8 +302,8 @@ def test_mlflow_logger_experiment_calls_with_synchronous(mlflow_mock, tmp_path,
     mlflow_client.create_experiment.assert_called_once_with(name="test", artifact_location="my_artifact_location")


-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
-@mock.patch.dict("lightning.pytorch.loggers.mlflow.__dict__", {"_MLFLOW_SYNCHRONOUS_AVAILABLE": False})
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.utils.imports._MLFLOW_SYNCHRONOUS_AVAILABLE", False)
 def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path):
     """Test that the logger does not support synchronous flag."""
     time = mlflow_mock.entities.time
@@ -315,7 +315,7 @@ def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path):
         MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location", synchronous=True)


-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
 def test_mlflow_logger_with_long_param_value(mlflow_mock, tmp_path):
     """Test that long parameter values are truncated to 250 characters."""

@@ -333,7 +333,7 @@ def _check_value_length(value, *args, **kwargs):
     logger.experiment.log_batch.assert_called_once()


-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
 def test_mlflow_logger_with_many_params(mlflow_mock, tmp_path):
     """Test that when logging more than 100 parameters, it will be split into batches of at most 100 parameters."""
     logger = MLFlowLogger("test", save_dir=str(tmp_path))
@@ -352,37 +352,37 @@ def test_mlflow_logger_with_many_params(mlflow_mock, tmp_path):
         ("finished", "FINISHED"),
     ],
 )
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
 def test_mlflow_logger_finalize(mlflow_mock, status, expected):
     logger = MLFlowLogger("test")

     # Pretend we are in a worker process and finalizing
     _ = logger.experiment
-    assert logger._initialized
+    assert logger.logger_impl._initialized

     logger.finalize(status)
     logger.experiment.set_terminated.assert_called_once_with(logger.run_id, expected)


-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
 def test_mlflow_logger_finalize_when_exception(mlflow_mock):
     logger = MLFlowLogger("test")

     # Pretend we are on the main process and failing
-    assert logger._mlflow_client
-    assert not logger._initialized
+    assert logger.logger_impl._mlflow_client
+    assert not logger.logger_impl._initialized
     logger.finalize("failed")
     logger.experiment.set_terminated.assert_not_called()

     # Pretend we are in a worker process and failing
     _ = logger.experiment
-    assert logger._initialized
+    assert logger.logger_impl._initialized
     logger.finalize("failed")
     logger.experiment.set_terminated.assert_called_once_with(logger.run_id, "FAILED")


 @pytest.mark.parametrize("log_model", ["all", True, False])
-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
 def test_mlflow_log_model(mlflow_mock, log_model, tmp_path):
     """Test that the logger creates the folders and files in the right place."""
     client = mlflow_mock.tracking.MlflowClient
@@ -420,7 +420,7 @@ def test_mlflow_log_model(mlflow_mock, log_model, tmp_path):
     assert not client.return_value.log_artifacts.called


-@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
 def test_set_tracking_uri(mlflow_mock):
     """Test that the tracking uri is set for logging artifacts to MLFlow server."""
     logger = MLFlowLogger(tracking_uri="the_tracking_uri")
diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py
index 572b85715da86..e50bcf967ef9a 100644
--- a/tests/tests_pytorch/loggers/test_neptune.py
+++ b/tests/tests_pytorch/loggers/test_neptune.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 import pickle
-from collections import namedtuple
 from unittest import mock
 from unittest.mock import MagicMock, call

@@ -45,10 +44,10 @@ def _fit_and_test(logger, model, tmp_path):
 def _get_logger_with_mocks(**kwargs):
     logger = NeptuneLogger(**kwargs)
     run_instance_mock = MagicMock()
-    logger._run_instance = run_instance_mock
-    logger._run_instance.__getitem__.return_value.fetch.return_value = "exp-name"
+    logger.logger_impl._run_instance = run_instance_mock
+    logger.logger_impl._run_instance.__getitem__.return_value.fetch.return_value = "exp-name"
     run_attr_mock = MagicMock()
-    logger._run_instance.__getitem__.return_value = run_attr_mock
+    logger.logger_impl._run_instance.__getitem__.return_value = run_attr_mock
     return logger, run_instance_mock, run_attr_mock


@@ -60,7 +59,7 @@ def test_neptune_online(neptune_mock):
     logger = NeptuneLogger(api_key="test", project="project")
     created_run_mock = logger.run

-    assert logger._run_instance == created_run_mock
+    assert logger.logger_impl._run_instance == created_run_mock
     created_run_mock.exists.assert_called_once_with("sys/id")
     assert logger.name == "Run test name"
     assert logger.version == "TEST-1"
@@ -79,8 +78,8 @@ def test_neptune_offline(neptune_mock):
     logger.experiment["foo"] = "bar"

     created_run_mock.exists.assert_called_once_with("sys/id")
-    assert logger._run_short_id == "OFFLINE"
-    assert logger._run_name == "offline-name"
+    assert logger.logger_impl._run_short_id == "OFFLINE"
+    assert logger.logger_impl._run_name == "offline-name"


 def test_online_with_custom_run(neptune_mock):
@@ -91,8 +90,8 @@ def test_online_with_custom_run(neptune_mock):
     neptune_mock.init_run.reset_mock()

     logger = NeptuneLogger(run=created_run)
-    assert logger._run_instance == created_run
-    assert logger._run_instance == created_run
+    assert logger.logger_impl._run_instance == created_run
+    assert logger.logger_impl._run_instance == created_run
     assert logger.version == "TEST-1"
     assert neptune_mock.init_run.call_count == 0

@@ -275,38 +274,6 @@ def test_save_dir(neptune_mock):
     assert logger.save_dir == os.path.join(os.getcwd(), ".neptune")


-def test_get_full_model_name():
-    SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
-    test_input_data = [
-        ("key", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
-        (
-            "key/in/parts",
-            os.path.join("foo", "bar", "key/in/parts.ext"),
-            SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
-        ),
-        ("key", os.path.join("../foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("../foo", "bar"))),
-        ("key", os.path.join("foo", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("./foo", "bar/../"))),
-    ]
-
-    for expected_model_name, model_path, checkpoint in test_input_data:
-        assert NeptuneLogger._get_full_model_name(model_path, checkpoint) == expected_model_name
-
-
-def test_get_full_model_names_from_exp_structure():
-    input_dict = {
-        "foo": {
-            "bar": {
-                "lvl1_1": {"lvl2": {"lvl3_1": "some non important value", "lvl3_2": "some non important value"}},
-                "lvl1_2": "some non important value",
-            },
-            "other_non_important": {"val100": 100},
-        },
-        "other_non_important": {"val42": 42},
-    }
-    expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"}
-    assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys
-
-
 def test_inactive_run(neptune_mock, tmp_path, monkeypatch):
     from neptune.exceptions import InactiveRunException

diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py
index e9b9e9a8090b0..0409cbd7b8f2c 100644
--- a/tests/tests_pytorch/loggers/test_wandb.py
+++ b/tests/tests_pytorch/loggers/test_wandb.py
@@ -25,7 +25,6 @@
 from lightning.pytorch.cli import LightningCLI
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
-from lightning.pytorch.utilities.exceptions import MisconfigurationException


 def test_wandb_project_name(wandb_mock):
@@ -101,7 +100,7 @@ def test_wandb_logger_init(wandb_mock):
     _ = logger.experiment

     # verify default resume value
-    assert logger._wandb_init["resume"] == "allow"
+    assert logger.logger_impl._wandb_init["resume"] == "allow"

     logger.log_metrics({"acc": 1.0}, step=3)
     wandb_mock.init.assert_called_once()
@@ -145,9 +144,9 @@ def test_wandb_logger_sync_tensorboard_log_metrics(wandb_mock):

 def test_wandb_logger_init_before_spawn(wandb_mock):
     logger = WandbLogger()
-    assert logger._experiment is None
-    logger.__getstate__()
-    assert logger._experiment is not None
+    assert logger.logger_impl._experiment is None
+    logger.logger_impl.__getstate__()
+    assert logger.logger_impl._experiment is not None


 def test_wandb_logger_experiment_called_first(wandb_mock, tmp_path):
@@ -626,7 +625,7 @@ def test_wandb_log_media(wandb_mock, tmp_path):

 def test_wandb_logger_offline_log_model(wandb_mock, tmp_path):
     """Test that log_model=True raises an error in offline mode."""
-    with pytest.raises(MisconfigurationException, match="checkpoints cannot be uploaded in offline mode"):
+    with pytest.raises(ValueError, match="checkpoints cannot be uploaded in offline mode"):
         _ = WandbLogger(save_dir=tmp_path, offline=True, log_model=True)


diff --git a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
index da4ce6b89aaab..86eb93c4542a0 100644
--- a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
+++ b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py
@@ -19,7 +19,7 @@


 def test_invalid_precision_with_deepspeed_precision():
-    with pytest.raises(ValueError, match="is not supported. `precision` must be one of"):
+    with pytest.raises(ValueError, match=r"is not supported in DeepSpeed. `precision` must be one of"):
         DeepSpeedPrecision(precision="64-true")


diff --git a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
index a9967280e3f23..6055f9ecd6e60 100644
--- a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
+++ b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
@@ -16,16 +16,16 @@
 from unittest.mock import ANY, Mock

 import pytest
+import pytorch_lightning_enterprise.utils.imports
 import torch

-import lightning.fabric
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.plugins import TransformerEnginePrecision
 from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector


 def test_transformer_engine_precision_plugin(monkeypatch):
-    module = lightning.fabric.plugins.precision.transformer_engine
+    module = pytorch_lightning_enterprise.plugins.precision.transformer_engine
     if module._TRANSFORMER_ENGINE_AVAILABLE:
         pytest.skip("Assumes transformer_engine is unavailable")
     monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
@@ -44,7 +44,7 @@ def test_transformer_engine_precision_plugin(monkeypatch):


 def test_configure_model(monkeypatch):
-    module = lightning.fabric.plugins.precision.transformer_engine
+    module = pytorch_lightning_enterprise.plugins.precision.transformer_engine
     if module._TRANSFORMER_ENGINE_AVAILABLE:
         pytest.skip("Assumes transformer_engine is unavailable")
     monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py
index 047e2d24f77a3..4f22a602b28aa 100644
--- a/tests/tests_pytorch/strategies/test_deepspeed.py
+++ b/tests/tests_pytorch/strategies/test_deepspeed.py
@@ -153,7 +153,7 @@ def test_deepspeed_precision_choice(cuda_count_1, tmp_path):
 def test_deepspeed_with_invalid_config_path():
     """Test to ensure if we pass an invalid config path we throw an exception."""
     with pytest.raises(
-        MisconfigurationException, match="You passed in a path to a DeepSpeed config but the path does not exist"
+        FileNotFoundError, match="You passed in a path to a DeepSpeed config but the path does not exist"
     ):
         DeepSpeedStrategy(config="invalid_path.json")

@@ -1004,7 +1004,7 @@ def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmp_path,
     strategy = trainer.strategy
     assert isinstance(strategy, DeepSpeedStrategy)
     with mock.patch("platform.system", return_value=platform) as mock_platform:
-        strategy._init_deepspeed_distributed()
+        strategy.deepspeed_strategy_impl._init_deepspeed_distributed()
         mock_deepspeed_distributed.assert_called()
         mock_platform.assert_called()
         if platform == "Windows":
@@ -1267,6 +1267,7 @@ def test_validate_parallel_devices_indices(device_indices):
     """

     accelerator = Mock(spec=CUDAAccelerator)
+    accelerator.name.return_value = "cuda"
     strategy = DeepSpeedStrategy(
         accelerator=accelerator, parallel_devices=[torch.device("cuda", i) for i in device_indices]
     )
diff --git a/tests/tests_pytorch/strategies/test_xla.py b/tests/tests_pytorch/strategies/test_xla.py
index 3fde2600c9483..0fd2afa98690b 100644
--- a/tests/tests_pytorch/strategies/test_xla.py
+++ b/tests/tests_pytorch/strategies/test_xla.py
@@ -50,14 +50,14 @@ def test_rank_properties_access(xla_available):
     strategy.cluster_environment = Mock()

     # we're in the main process, no processes have been launched yet
-    assert not strategy._launched
+    assert not strategy.xla_strategy_impl._launched
     assert strategy.global_rank == 0
     assert strategy.local_rank == 0
     assert strategy.node_rank == 0
     assert strategy.world_size == 1

     # simulate we're in a worker process
-    strategy._launched = True
+    strategy.xla_strategy_impl._launched = True
     assert strategy.global_rank == strategy.cluster_environment.global_rank()
     assert strategy.local_rank == strategy.cluster_environment.local_rank()
     assert strategy.node_rank == strategy.cluster_environment.node_rank()
diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py
index 41ca1a779f8a5..8bb696a659469 100644
--- a/tests/tests_pytorch/utilities/migration/test_utils.py
+++ b/tests/tests_pytorch/utilities/migration/test_utils.py
@@ -101,7 +101,10 @@ def test_patch_legacy_imports_unified(pl_version):
     assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), (
         f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}"
     )
-    assert not any(key.startswith("pytorch_lightning") for key in sys.modules), (
+    assert not any(
+        key.startswith("pytorch_lightning") and not key.startswith("pytorch_lightning_enterprise")
+        for key in sys.modules
+    ), (
         "Should not import standalone package, all imports should be redirected to the unified package;\n"
         f" environment: {_list_sys_modules('pytorch_lightning')}"
     )