diff --git a/docs/source-fabric/advanced/compile.rst b/docs/source-fabric/advanced/compile.rst index 17ba6e4ca9dc8..df79454f67a6f 100644 --- a/docs/source-fabric/advanced/compile.rst +++ b/docs/source-fabric/advanced/compile.rst @@ -115,9 +115,115 @@ always exclude the first call to ``forward()`` from your measurements, since it Compile median time: 0.0185 seconds Speedup: 1.4x - ---- +********************************************** +Apply torch.compile with ModelParallelStrategy +********************************************** + +:func:`torch.compile` can also be invoked as part of the `parallelize_fn` argument of :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy`. + +This is particularly handy when :func:`torch.compile` is used in combination with the `torch.distributed.tensor` API. + +Here is an example: + +.. code-block:: python + + import lightning as L + import torch + import torch.nn as nn + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer + from lightning.fabric.strategies.model_parallel import ModelParallelStrategy + from torch.distributed._composable.fsdp.fully_shard import fully_shard + from torch.distributed.device_mesh import DeviceMesh + + def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + for module in model.modules(): + if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)): + fully_shard(module, mesh=device_mesh) + + fully_shard(model, mesh=device_mesh) + + return torch.compile(model) + + def train(): + L.seed_everything(42) + + with torch.device("meta"): + model = Transformer( + vocab_size=50257, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=parallelize) + + fabric = L.Fabric(precision="bf16-true", strategy=strategy) + fabric.launch() + + model = fabric.setup(model) + +The advantage here is that `parallelize` is called when sharding the model, +so :func:`torch.compile` is guaranteed to run on model shards and capture distributed operations. + +Also, when using other libraries like `torch ao `_ +that need to be applied in a similar fashion, it's easy to reason about the sequence of calls +needed to achieve the equivalent of `compile(distributed(quantized(model)))`: + +.. 
code-block:: python + + import lightning as L + import torch + import torch.nn as nn + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer + from torch.distributed._composable.fsdp.fully_shard import fully_shard + from torch.distributed.device_mesh import DeviceMesh + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + float8_config = Float8LinearConfig( + pad_inner_dim=True, + ) + + def module_filter_fn(mod: torch.nn.Module, fqn: str): + return fqn != "decoder" + + convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) + + for module in model.modules(): + if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)): + fully_shard(module, mesh=device_mesh) + + fully_shard(model, mesh=device_mesh) + + return torch.compile(model) + + def train(): + L.seed_everything(42) + + with torch.device("meta"): + model = Transformer( + vocab_size=50257, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=parallelize) + + fabric = L.Fabric(precision="bf16-true", strategy=strategy) + fabric.launch() + + model = fabric.setup(model) + +For a full example, see our `FP8 Distributed Transformer example `_. + +---- ****************** Avoid graph breaks diff --git a/docs/source-pytorch/advanced/compile.rst b/docs/source-pytorch/advanced/compile.rst index d5bd333c041b3..16fe91ca282df 100644 --- a/docs/source-pytorch/advanced/compile.rst +++ b/docs/source-pytorch/advanced/compile.rst @@ -138,6 +138,122 @@ always exclude the first call to ``forward()``/``*_step()`` from your measuremen ---- +************************************** +Apply torch.compile in configure_model +************************************** + +:func:`torch.compile` can also be invoked as part of the :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook. + +This is particularly handy when :func:`torch.compile` is used in combination with :class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy`. + +Here is an example: + +.. 
code-block:: python + + import lightning as L + import torch + import torch.nn as nn + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer + from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy + from torch.distributed.device_mesh import DeviceMesh + from torch.distributed._composable.fsdp.fully_shard import fully_shard + + class LanguageModel(L.LightningModule): + def __init__(self, vocab_size): + super().__init__() + self.vocab_size = vocab_size + self.model = None + + def configure_model(self): + if self.model is not None: + return + + with torch.device("meta"): + model = Transformer( + vocab_size=self.vocab_size, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + for module in model.modules(): + if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)): + fully_shard(module, mesh=self.device_mesh) + + fully_shard(model, mesh=self.device_mesh) + + self.model = torch.compile(model) + + def training_step(self, batch): + input, target = batch + output = self.model(input, target) + loss = F.nll_loss(output, target.view(-1)) + self.log("train_loss", loss) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-4) + +The advantage here is that `configure_model` is called when sharding the model, +so :func:`torch.compile` is guaranteed to run on model shards and capture distributed operations. + +Also, when using other libraries like `torch ao `_ +that need to be applied in a similar fashion, it's easy to reason about the sequence of calls +needed to achieve the equivalent of `compile(distributed(quantized(model)))`: + +.. code-block:: python + + import lightning as L + import torch + import torch.nn as nn + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer + from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy + from torch.distributed._composable.fsdp.fully_shard import fully_shard + from torch.distributed.device_mesh import DeviceMesh + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + class LanguageModel(L.LightningModule): + def __init__(self, vocab_size): + super().__init__() + self.vocab_size = vocab_size + self.model = None + + def configure_model(self): + if self.model is not None: + return + + with torch.device("meta"): + model = Transformer( + vocab_size=self.vocab_size, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + float8_config = Float8LinearConfig( + pad_inner_dim=True, + ) + + def module_filter_fn(mod: torch.nn.Module, fqn: str): + return fqn != "decoder" + + convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) + + for module in model.modules(): + if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)): + fully_shard(module, mesh=self.device_mesh) + + fully_shard(model, mesh=self.device_mesh) + + self.model = torch.compile(model) + +For a full example, see our `FP8 Distributed Transformer example `_. + +---- ****************** Avoid graph breaks @@ -253,8 +369,8 @@ Limitations There are a few limitations you should be aware of when using ``torch.compile`` **in conjunction with the Trainer**: -* The Trainer currently does not reapply ``torch.compile`` over DDP/FSDP, meaning distributed operations can't benefit from speed ups at the moment. - This limitation will be lifted in the future. 
+* The Trainer currently does not reapply ``torch.compile`` over :class:`~lightning.pytorch.strategies.DDPStrategy` and :class:`~lightning.pytorch.strategies.FSDPStrategy`, meaning distributed operations can't benefit from speed-ups at the moment.
+  This limitation can be avoided by using :class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy`, as described in `Apply torch.compile in configure_model`_ above.
 * In some cases, using ``self.log()`` in your LightningModule will cause compilation errors.
   Until addressed, you can work around these issues by applying ``torch.compile`` to the submodule(s) of your LightningModule rather than to the entire LightningModule at once.
diff --git a/examples/fabric/fp8_distributed_transformer/README.md b/examples/fabric/fp8_distributed_transformer/README.md
new file mode 100644
index 0000000000000..e980d759bb3ff
--- /dev/null
+++ b/examples/fabric/fp8_distributed_transformer/README.md
@@ -0,0 +1,39 @@
+## Distributed, Low-Precision Transformer Example
+
+This example shows how to use `ModelParallelStrategy` in `Fabric` to train a Transformer model while minimizing memory usage, maximizing throughput, and distributing load across multiple GPUs.
+
+### Training Large Models and Memory Requirements
+
+One of the main challenges when training large models, like large language models (LLMs), is dealing with their memory footprint. LLMs can be so large that weights, activations, gradients and optimizer state don't fit on a single GPU, so they need to be distributed across multiple GPUs, and possibly across multiple machines. There are multiple ways of distributing computations, among which are fully-sharded data parallelism (FSDP) and tensor parallelism (TP).
+
+An additional way of reducing memory requirements is to represent floating point numbers in weights and activations in low numerical precision, such as 16-bit (`bfloat16`) or 8-bit (`fp8`). This saves memory and also reduces memory bandwidth usage (fewer bytes transferred from device memory to GPU cores per unit time).
+
+Roughly, reducing precision to `fp8` for linear layers can lead to a 2x reduction in memory requirements and a 1.6x improvement in throughput. Support for `fp8` weights and activations requires recent GPUs: Hopper, Ada Lovelace and above (e.g. H100, L4, L40).
+
+The introduction of tensor subclasses in PyTorch brought two new APIs that can be combined to achieve memory savings and distributed training (as well as inference):
+
+- [torch ao](https://github.com/pytorch/ao) to execute linear layers in low numerical precision (`fp8` and other quantized formats)
+- [dtensors](https://pytorch.org/docs/stable/distributed.tensor.html) to distribute models across GPUs, by combining TP and FSDP (referred to as FSDP2 in PyTorch)
+
+Notably, `torch ao` introduces quantization and dequantization operations in the model that may result in slow-downs if not optimized. Using `torch.compile` after `torch ao` recovers performance by generating optimized kernels for those operations.
+
+### Vanilla Transformer Example
+
+This example shows how to train a vanilla Transformer model using `fp8` precision and the FSDP2 distributed strategy, and then optimize the resulting model through `torch.compile`.
+
+Specifically, we employ the `ModelParallelStrategy`, which accepts a `parallelize_fn` used to distribute the model with the PyTorch DTensor API.
+In the same function we also pass the model through the `torch ao` API (prior to FSDP2), as well as `torch.compile` (after FSDP2), as sketched below.
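+
+For orientation, here is a condensed sketch of the `configure_model` function defined in `train.py`, which is passed to the strategy as `parallelize_fn`. It is meant to show the ordering of the three steps; the float8 filter is inlined here for brevity, see `train.py` for the full version and the training loop.
+
+```python
+import torch
+import torch.nn as nn
+from torch.distributed._composable.fsdp.fully_shard import fully_shard
+from torch.distributed.device_mesh import DeviceMesh
+from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+
+
+def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+    # 1. swap eligible nn.Linear layers for their fp8 training counterparts (torch ao);
+    #    the decoder is skipped because its vocabulary size is typically not divisible by 16, as float8 requires
+    config = Float8LinearConfig(pad_inner_dim=True)
+    convert_to_float8_training(model, config=config, module_filter_fn=lambda mod, fqn: fqn != "decoder")
+
+    # 2. shard each transformer block, then the root module (FSDP2, built on DTensor)
+    for module in model.modules():
+        if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
+            fully_shard(module, mesh=device_mesh)
+    fully_shard(model, mesh=device_mesh)
+
+    # 3. compile the sharded, quantized model so optimized kernels are generated for the fp8 ops
+    return torch.compile(model)
+```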
+
+The resulting code follows the PyTorch API closely, while also taking advantage of the rest of Lightning Fabric.
+
+To execute the code directly just run:
+
+```bash
+python train.py
+```
+
+### A Note on torch.compile
+
+Note that Fabric also supports calling `torch.compile` on a model and passing it to `fabric.setup`.
+
+While this works well, in order to get the most out of the combination of the latest distributed, quantization, and compile PyTorch APIs, we recommend invoking `torch.compile` as part of the `parallelize_fn` argument of `ModelParallelStrategy`, as shown in this example.
diff --git a/examples/fabric/fp8_distributed_transformer/requirements.txt b/examples/fabric/fp8_distributed_transformer/requirements.txt
new file mode 100644
index 0000000000000..ce00e191aa9c1
--- /dev/null
+++ b/examples/fabric/fp8_distributed_transformer/requirements.txt
@@ -0,0 +1 @@
+torchao>=0.7.0
diff --git a/examples/fabric/fp8_distributed_transformer/train.py b/examples/fabric/fp8_distributed_transformer/train.py
new file mode 100644
index 0000000000000..ba88603268945
--- /dev/null
+++ b/examples/fabric/fp8_distributed_transformer/train.py
@@ -0,0 +1,100 @@
+import lightning as L
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from lightning.fabric.strategies import ModelParallelStrategy
+from lightning.pytorch.demos import Transformer, WikiText2
+from torch.distributed._composable.fsdp.fully_shard import fully_shard
+from torch.distributed.device_mesh import DeviceMesh
+from torch.utils.data import DataLoader
+from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+from tqdm import tqdm
+
+
+def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+    float8_config = Float8LinearConfig(
+        # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly # noqa
+        pad_inner_dim=True,
+    )
+
+    def module_filter_fn(mod: torch.nn.Module, fqn: str):
+        # we skip the decoder because typically its vocabulary size
+        # is not divisible by 16, as required by float8
+        return fqn != "decoder"
+
+    convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)
+
+    for module in model.modules():
+        if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
+            fully_shard(module, mesh=device_mesh)
+
+    fully_shard(model, mesh=device_mesh)
+
+    return torch.compile(model)
+
+
+def train():
+    L.seed_everything(42)
+
+    batch_size = 8
+    micro_batch_size = 1
+
+    max_steps = 100
+
+    dataset = WikiText2()
+    dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size)
+
+    with torch.device("meta"):
+        model = Transformer(
+            vocab_size=dataset.vocab_size,
+            nlayers=16,
+            nhid=4096,
+            ninp=1024,
+            nhead=32,
+        )
+
+    strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=configure_model)
+
+    fabric = L.Fabric(precision="bf16-true", strategy=strategy)
+    fabric.launch()
+
+    model = fabric.setup(model)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+    optimizer = fabric.setup_optimizers(optimizer)
+
+    dataloader = fabric.setup_dataloaders(dataloader)
+
+    iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader)
+
+    steps = 0
+
+    for i, batch in iterable:
+        input, target = batch
+
+        is_accumulating = i % (batch_size // micro_batch_size) != 0
+
+        with fabric.no_backward_sync(model, enabled=is_accumulating):
+            output = model(input, target)
+            loss = F.nll_loss(output, target.view(-1))
+            fabric.backward(loss)
+
+        if not is_accumulating:
+            fabric.clip_gradients(model, optimizer, max_norm=1.0)
+            optimizer.step()
+            optimizer.zero_grad()
+            steps += 1
+
+        if fabric.is_global_zero:
+            iterable.set_postfix_str(f"train_loss={loss.item():.2f}")
+
+        if steps == max_steps:
+            break
+
+    fabric.print(torch.cuda.memory_summary())
+
+
+if __name__ == "__main__":
+    torch.set_float32_matmul_precision("high")
+
+    train()
diff --git a/examples/pytorch/fp8_distributed_transformer/README.md b/examples/pytorch/fp8_distributed_transformer/README.md
new file mode 100644
index 0000000000000..6c5e12d14da4a
--- /dev/null
+++ b/examples/pytorch/fp8_distributed_transformer/README.md
@@ -0,0 +1,39 @@
+## Distributed, Low-Precision Transformer Example
+
+This example shows how to use `ModelParallelStrategy` with the PyTorch Lightning `Trainer` to train a Transformer model while minimizing memory usage, maximizing throughput, and distributing load across multiple GPUs.
+
+### Training Large Models and Memory Requirements
+
+One of the main challenges when training large models, like large language models (LLMs), is dealing with their memory footprint. LLMs can be so large that weights, activations, gradients and optimizer state don't fit on a single GPU, so they need to be distributed across multiple GPUs, and possibly across multiple machines. There are multiple ways of distributing computations, among which are fully-sharded data parallelism (FSDP) and tensor parallelism (TP).
+
+An additional way of reducing memory requirements is to represent floating point numbers in weights and activations in low numerical precision, such as 16-bit (`bfloat16`) or 8-bit (`fp8`). This saves memory and also reduces memory bandwidth usage (fewer bytes transferred from device memory to GPU cores per unit time).
+
+Roughly, reducing precision to `fp8` for linear layers can lead to a 2x reduction in memory requirements and a 1.6x improvement in throughput. Support for `fp8` weights and activations requires recent GPUs: Hopper, Ada Lovelace and above (e.g. H100, L4, L40).
+
+The introduction of tensor subclasses in PyTorch brought two new APIs that can be combined to achieve memory savings and distributed training (as well as inference):
+
+- [torch ao](https://github.com/pytorch/ao) to execute linear layers in low numerical precision (`fp8` and other quantized formats)
+- [dtensors](https://pytorch.org/docs/stable/distributed.tensor.html) to distribute models across GPUs, by combining TP and FSDP (referred to as FSDP2 in PyTorch)
+
+Notably, `torch ao` introduces quantization and dequantization operations in the model that may result in slow-downs if not optimized. Using `torch.compile` after `torch ao` recovers performance by generating optimized kernels for those operations.
+
+### Vanilla Transformer Example
+
+This example shows how to train a vanilla Transformer model using `fp8` precision and the FSDP2 distributed strategy, and then optimize the resulting model through `torch.compile`.
+
+Specifically, we employ the `ModelParallelStrategy` and use the `configure_model` hook of the `LightningModule` to distribute the model with the PyTorch DTensor API.
+In the same hook we also pass the model through the `torch ao` API (prior to FSDP2), as well as `torch.compile` (after FSDP2), as sketched below.
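+
+For orientation, here is a condensed sketch of the `configure_model` hook defined in `train.py`. It is meant to show the ordering of the three steps; the float8 filter is inlined here for brevity, and `training_step` and `configure_optimizers` are unchanged from `train.py`.
+
+```python
+import lightning as L
+import torch
+import torch.nn as nn
+from lightning.pytorch.demos import Transformer
+from torch.distributed._composable.fsdp.fully_shard import fully_shard
+from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+
+
+class LanguageModel(L.LightningModule):
+    def __init__(self, vocab_size):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.model = None
+
+    def configure_model(self):
+        if self.model is not None:  # the hook may be called more than once, keep it idempotent
+            return
+
+        # build the model on the meta device so no memory is allocated up front;
+        # weights are materialized when the strategy sets up the sharded model
+        with torch.device("meta"):
+            model = Transformer(vocab_size=self.vocab_size, nlayers=16, nhid=4096, ninp=1024, nhead=32)
+
+        # 1. swap eligible nn.Linear layers for their fp8 training counterparts (torch ao);
+        #    the decoder is skipped because its vocabulary size is typically not divisible by 16, as float8 requires
+        config = Float8LinearConfig(pad_inner_dim=True)
+        convert_to_float8_training(model, config=config, module_filter_fn=lambda mod, fqn: fqn != "decoder")
+
+        # 2. shard each transformer block, then the root module (FSDP2, built on DTensor)
+        for module in model.modules():
+            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
+                fully_shard(module, mesh=self.device_mesh)
+        fully_shard(model, mesh=self.device_mesh)
+
+        # 3. compile the sharded, quantized model so optimized kernels are generated for the fp8 ops
+        self.model = torch.compile(model)
+```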
+
+The resulting code follows the PyTorch API closely, while also taking advantage of the rest of PyTorch Lightning.
+
+To execute the code directly just run:
+
+```bash
+python train.py
+```
+
+### A Note on torch.compile
+
+Note that PyTorch Lightning also supports calling `torch.compile` on a `LightningModule` and passing it to the `Trainer`.
+
+While this works for simple cases, in order to get the most out of the combination of the latest distributed, quantization, and compile PyTorch APIs, we recommend invoking `torch.compile` at the end of the `configure_model` hook, as shown in this example.
diff --git a/examples/pytorch/fp8_distributed_transformer/requirements.txt b/examples/pytorch/fp8_distributed_transformer/requirements.txt
new file mode 100644
index 0000000000000..ce00e191aa9c1
--- /dev/null
+++ b/examples/pytorch/fp8_distributed_transformer/requirements.txt
@@ -0,0 +1 @@
+torchao>=0.7.0
diff --git a/examples/pytorch/fp8_distributed_transformer/train.py b/examples/pytorch/fp8_distributed_transformer/train.py
new file mode 100644
index 0000000000000..6c7be98ee7dbd
--- /dev/null
+++ b/examples/pytorch/fp8_distributed_transformer/train.py
@@ -0,0 +1,85 @@
+import lightning as L
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from lightning.pytorch.demos import Transformer, WikiText2
+from lightning.pytorch.strategies import ModelParallelStrategy
+from torch.distributed._composable.fsdp.fully_shard import fully_shard
+from torch.utils.data import DataLoader
+from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+
+
+class LanguageModel(L.LightningModule):
+    def __init__(self, vocab_size):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.model = None
+
+    def configure_model(self):
+        if self.model is not None:
+            return
+
+        with torch.device("meta"):
+            model = Transformer(
+                vocab_size=self.vocab_size,
+                nlayers=16,
+                nhid=4096,
+                ninp=1024,
+                nhead=32,
+            )
+
+        float8_config = Float8LinearConfig(
+            # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly # noqa
+            pad_inner_dim=True,
+        )
+
+        def module_filter_fn(mod: torch.nn.Module, fqn: str):
+            # we skip the decoder because typically its vocabulary size
+            # is not divisible by 16, as required by float8
+            return fqn != "decoder"
+
+        convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)
+
+        for module in model.modules():
+            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
+                fully_shard(module, mesh=self.device_mesh)
+
+        fully_shard(model, mesh=self.device_mesh)
+
+        self.model = torch.compile(model)
+
+    def training_step(self, batch):
+        input, target = batch
+        output = self.model(input, target)
+        loss = F.nll_loss(output, target.view(-1))
+        self.log("train_loss", loss, prog_bar=True)
+        return loss
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=1e-4)
+
+
+def train():
+    L.seed_everything(42)
+
+    dataset = WikiText2()
+    train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)
+
+    model = LanguageModel(vocab_size=dataset.vocab_size)
+
+    mp_strategy = ModelParallelStrategy(
+        data_parallel_size=4,
+        tensor_parallel_size=1,
+    )
+
+    trainer = L.Trainer(strategy=mp_strategy, max_steps=100, precision="bf16-true", accumulate_grad_batches=8)
+
+    trainer.fit(model, train_dataloader)
+
+    trainer.print(torch.cuda.memory_summary())
+
+
+if __name__ == "__main__":
+    torch.set_float32_matmul_precision("high")
+
+    train()
diff --git a/tests/run_standalone_tests.sh
b/tests/run_standalone_tests.sh index 9aa54f7350607..75a52e16c57dc 100755 --- a/tests/run_standalone_tests.sh +++ b/tests/run_standalone_tests.sh @@ -17,7 +17,7 @@ set -e # Batch size for testing: Determines how many standalone test invocations run in parallel # It can be set through the env variable PL_STANDALONE_TESTS_BATCH_SIZE and defaults to 6 if not set -test_batch_size="${PL_STANDALONE_TESTS_BATCH_SIZE:-6}" +test_batch_size="${PL_STANDALONE_TESTS_BATCH_SIZE:-3}" source="${PL_STANDALONE_TESTS_SOURCE:-"lightning"}" # this is the directory where the tests are located test_dir=$1 # parse the first argument diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 446994167d0a1..5fdc61a08955b 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -69,6 +69,7 @@ def restore_env_variables(): "OMP_NUM_THREADS", # set by our launchers # set by torchdynamo "TRITON_CACHE_DIR", + "TORCHINDUCTOR_CACHE_DIR", } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index dfbdb16b10060..b04a29b691529 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -29,6 +29,13 @@ from tests_fabric.helpers.runif import RunIf +@pytest.fixture() +def distributed(): + yield + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + class FeedForward(nn.Module): def __init__(self): super().__init__() @@ -81,7 +88,7 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_setup_device_mesh(): +def test_setup_device_mesh(distributed): from torch.distributed.device_mesh import DeviceMesh for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)): @@ -116,11 +123,28 @@ def test_setup_device_mesh(): assert fabric.strategy.device_mesh.size(1) == 4 +def _parallelize_with_compile(parallelize): + def fn(model, device_mesh): + model = parallelize(model, device_mesh) + return torch.compile(model) + + return fn + + @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2) -def test_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor - strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_tp) + parallelize = _parallelize_feed_forward_tp + + if compile: + parallelize = _parallelize_with_compile(parallelize) + + strategy = ModelParallelStrategy(parallelize_fn=parallelize) fabric = Fabric(accelerator="auto", devices=2, strategy=strategy) fabric.launch() @@ -161,9 +185,18 @@ def test_tensor_parallel(): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_fsdp2_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_fsdp2_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor + parallelize = _parallelize_feed_forward_fsdp2_tp + + if compile: + parallelize = _parallelize_with_compile(parallelize) + strategy = ModelParallelStrategy( parallelize_fn=_parallelize_feed_forward_fsdp2_tp, data_parallel_size=2, @@ -238,6 +271,7 @@ def _train(fabric, model=None, optimizer=None): @RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True) 
+@pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.parametrize( "precision", [ @@ -245,7 +279,7 @@ def _train(fabric, model=None, optimizer=None): pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)), ], ) -def test_train_save_load(precision, tmp_path): +def test_train_save_load(distributed, precision, tmp_path): """Test 2D-parallel training, saving and loading precision settings.""" strategy = ModelParallelStrategy( _parallelize_feed_forward_fsdp2_tp, @@ -303,7 +337,7 @@ def test_train_save_load(precision, tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_save_full_state_dict(tmp_path): +def test_save_full_state_dict(distributed, tmp_path): """Test that ModelParallelStrategy saves the full state into a single file with `save_distributed_checkpoint=False`.""" from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict @@ -404,7 +438,7 @@ def test_save_full_state_dict(tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_load_full_state_dict_into_sharded_model(tmp_path): +def test_load_full_state_dict_into_sharded_model(distributed, tmp_path): """Test that the strategy can load a full-state checkpoint into a distributed model.""" fabric = Fabric(accelerator="cuda", devices=1) fabric.seed_everything(0) @@ -450,7 +484,7 @@ def test_load_full_state_dict_into_sharded_model(tmp_path): @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("move_to_device", [True, False]) @mock.patch("lightning.fabric.wrappers._FabricModule") -def test_setup_module_move_to_device(fabric_module_mock, move_to_device): +def test_setup_module_move_to_device(fabric_module_mock, move_to_device, distributed): """Test that `move_to_device` does nothing, ModelParallel decides which device parameters get moved to which device (sharding).""" from torch.distributed._tensor import DTensor @@ -482,7 +516,7 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device): pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)), ], ) -def test_module_init_context(precision, expected_dtype): +def test_module_init_context(distributed, precision, expected_dtype): """Test that the module under the init-context gets moved to the right device and dtype.""" strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_fsdp2) fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision=precision) @@ -505,7 +539,7 @@ def _run_setup_assertions(empty_init, expected_device): @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_save_filter(tmp_path): +def test_save_filter(distributed, tmp_path): strategy = ModelParallelStrategy( parallelize_fn=_parallelize_feed_forward_fsdp2, save_distributed_checkpoint=False, @@ -558,7 +592,7 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh): "val", ], ) -def test_clip_gradients(clip_type, precision): +def test_clip_gradients(distributed, clip_type, precision): strategy = ModelParallelStrategy(_parallelize_single_linear_tp_fsdp2) fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy) fabric.launch() @@ -600,7 +634,7 @@ def test_clip_gradients(clip_type, precision): @RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True) -def test_save_sharded_and_consolidate_and_load(tmp_path): +def test_save_sharded_and_consolidate_and_load(distributed, tmp_path): """Test 
the consolidation of a distributed (DTensor) checkpoint into a single file.""" strategy = ModelParallelStrategy( _parallelize_feed_forward_fsdp2_tp, @@ -657,7 +691,7 @@ def test_save_sharded_and_consolidate_and_load(tmp_path): @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_load_raw_module_state(): +def test_load_raw_module_state(distributed): from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 57d273917573a..9dcbcc802834b 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -78,10 +78,26 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh): return model +def _parallelize_with_compile(parallelize): + def fn(model, device_mesh): + model = parallelize(model, device_mesh) + return torch.compile(model) + + return fn + + +@pytest.fixture() +def distributed(): + yield + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + class TemplateModel(LightningModule): - def __init__(self): + def __init__(self, compile=False): super().__init__() self.model = FeedForward() + self._compile = compile def training_step(self, batch): output = self.model(batch) @@ -98,21 +114,30 @@ def configure_optimizers(self): class FSDP2Model(TemplateModel): def configure_model(self): - _parallelize_feed_forward_fsdp2(self.model, device_mesh=self.device_mesh) + parallelize = _parallelize_feed_forward_fsdp2_tp + if self._compile: + parallelize = _parallelize_with_compile(parallelize) + parallelize(self.model, device_mesh=self.device_mesh) class TensorParallelModel(TemplateModel): def configure_model(self): - _parallelize_feed_forward_tp(self.model, device_mesh=self.device_mesh) + parallelize = _parallelize_feed_forward_tp + if self._compile: + parallelize = _parallelize_with_compile(parallelize) + parallelize(self.model, device_mesh=self.device_mesh) class FSDP2TensorParallelModel(TemplateModel): def configure_model(self): - _parallelize_feed_forward_fsdp2_tp(self.model, device_mesh=self.device_mesh) + parallelize = _parallelize_feed_forward_fsdp2_tp + if self._compile: + parallelize = _parallelize_with_compile(parallelize) + parallelize(self.model, device_mesh=self.device_mesh) @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_setup_device_mesh(): +def test_setup_device_mesh(distributed): from torch.distributed.device_mesh import DeviceMesh for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)): @@ -169,7 +194,11 @@ def configure_model(self): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2) -def test_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor class Model(TensorParallelModel): @@ -204,13 +233,17 @@ def training_step(self, batch): seed_everything(0) with trainer.init_module(empty_init=True): - model = Model() + model = Model(compile=compile) trainer.fit(model) @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_fsdp2_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_fsdp2_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor class Model(FSDP2TensorParallelModel): @@ -261,13 +294,13 @@ def 
training_step(self, batch): seed_everything(0) with trainer.init_module(empty_init=True): - model = Model() + model = Model(compile=compile) trainer.fit(model) @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_modules_without_parameters(tmp_path): +def test_modules_without_parameters(distributed, tmp_path): """Test that TorchMetrics get moved to the device despite not having any parameters.""" class MetricsModel(TensorParallelModel): @@ -306,7 +339,11 @@ def training_step(self, batch): pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)), ], ) -def test_module_init_context(precision, expected_dtype, tmp_path): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_module_init_context(distributed, compile, precision, expected_dtype, tmp_path): """Test that the module under the init-context gets moved to the right device and dtype.""" class Model(FSDP2Model): @@ -329,7 +366,7 @@ def _run_setup_assertions(empty_init, expected_device): logger=False, ) with trainer.init_module(empty_init=empty_init): - model = Model() + model = Model(compile=compile) # The model is on the CPU/meta-device until after `ModelParallelStrategy.setup()` assert model.model.w1.weight.device == expected_device @@ -345,7 +382,7 @@ def _run_setup_assertions(empty_init, expected_device): @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("save_distributed_checkpoint", [True, False]) -def test_strategy_state_dict(tmp_path, save_distributed_checkpoint): +def test_strategy_state_dict(distributed, tmp_path, save_distributed_checkpoint): """Test that the strategy returns the correct state dict of the LightningModule.""" model = FSDP2Model() correct_state_dict = model.state_dict() # State dict before wrapping @@ -378,7 +415,7 @@ def test_strategy_state_dict(tmp_path, save_distributed_checkpoint): @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) -def test_load_full_state_checkpoint_into_regular_model(tmp_path): +def test_load_full_state_checkpoint_into_regular_model(distributed, tmp_path): """Test that a full-state checkpoint saved from a distributed model can be loaded back into a regular model.""" # Save a regular full-state checkpoint from a distributed model @@ -420,7 +457,7 @@ def test_load_full_state_checkpoint_into_regular_model(tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) -def test_load_standard_checkpoint_into_distributed_model(tmp_path): +def test_load_standard_checkpoint_into_distributed_model(distributed, tmp_path): """Test that a regular checkpoint (weights and optimizer states) can be loaded into a distributed model.""" # Save a regular DDP checkpoint @@ -461,7 +498,7 @@ def test_load_standard_checkpoint_into_distributed_model(tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_save_load_sharded_state_dict(tmp_path): +def test_save_load_sharded_state_dict(distributed, tmp_path): """Test saving and loading with the distributed state dict format.""" class CheckpointModel(FSDP2Model):