From c6695f8a2ec5f8500425b08cc0881a8cb1b58498 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Fri, 22 Nov 2024 16:09:37 +0000 Subject: [PATCH 01/13] Minimal transformer examples --- examples/fabric/fp8_fsdp2_compile/README.md | 0 examples/fabric/fp8_fsdp2_compile/train.py | 106 +++++++++++++++++++ examples/pytorch/fp8_fsdp2_compile/README.md | 0 examples/pytorch/fp8_fsdp2_compile/train.py | 94 ++++++++++++++++ 4 files changed, 200 insertions(+) create mode 100644 examples/fabric/fp8_fsdp2_compile/README.md create mode 100644 examples/fabric/fp8_fsdp2_compile/train.py create mode 100644 examples/pytorch/fp8_fsdp2_compile/README.md create mode 100644 examples/pytorch/fp8_fsdp2_compile/train.py diff --git a/examples/fabric/fp8_fsdp2_compile/README.md b/examples/fabric/fp8_fsdp2_compile/README.md new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/examples/fabric/fp8_fsdp2_compile/train.py b/examples/fabric/fp8_fsdp2_compile/train.py new file mode 100644 index 0000000000000..329c31f3da3ef --- /dev/null +++ b/examples/fabric/fp8_fsdp2_compile/train.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.distributed._composable.fsdp.fully_shard import fully_shard +from torch.distributed.device_mesh import DeviceMesh + +from torchao.float8 import convert_to_float8_training, Float8LinearConfig + +import lightning as L +from lightning.fabric.strategies import ModelParallelStrategy +from lightning.pytorch.demos import Transformer, WikiText2 + +from tqdm import tqdm + + +def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + float8_config = Float8LinearConfig( + # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly + pad_inner_dim=True, + ) + + def module_filter_fn(mod: torch.nn.Module, fqn: str): + # we skip the decoder because it typically vocabulary size + # is not divisible by 16 as required by float8 + if fqn == "decoder": + return False + return True + + convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) + + for module in model.modules(): + if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)): + fully_shard(module, mesh=device_mesh) + + fully_shard(model, mesh=device_mesh) + + model = torch.compile(model) + + return model + + +def train(): + L.seed_everything(42) + + batch_size = 8 + micro_batch_size = 1 + + dataset = WikiText2() + dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size) + + with torch.device("meta"): + model = Transformer( + vocab_size=dataset.vocab_size, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + strategy = ModelParallelStrategy( + data_parallel_size=4, + tensor_parallel_size=1, + parallelize_fn=configure_model + ) + + fabric = L.Fabric(precision="bf16-true", strategy=strategy) + fabric.launch() + + model = fabric.setup(model) + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + optimizer = fabric.setup_optimizers(optimizer) + + dataloader = fabric.setup_dataloaders(dataloader) + + iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader) + + for i, batch in iterable: + input, target = batch + + is_accumulating = i % (batch_size // micro_batch_size) != 0 + + with fabric.no_backward_sync(model, enabled=is_accumulating): + output = model(input, target) + loss = F.nll_loss(output, target.view(-1)) + fabric.backward(loss) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=1.0) + optimizer.step() + optimizer.zero_grad() + + if fabric.is_global_zero: + iterable.set_postfix_str(f"train_loss={loss.item():.2f}") + + if i // (batch_size // micro_batch_size) > 100: + break + + fabric.print(torch.cuda.memory_summary()) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision('high') + + train() diff --git a/examples/pytorch/fp8_fsdp2_compile/README.md b/examples/pytorch/fp8_fsdp2_compile/README.md new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/examples/pytorch/fp8_fsdp2_compile/train.py b/examples/pytorch/fp8_fsdp2_compile/train.py new file mode 100644 index 0000000000000..ed1184e2d59b8 --- /dev/null +++ b/examples/pytorch/fp8_fsdp2_compile/train.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.distributed._composable.fsdp.fully_shard import fully_shard + +from torchao.float8 import convert_to_float8_training, Float8LinearConfig + +import lightning as L +from lightning.pytorch.strategies import ModelParallelStrategy +from lightning.pytorch.demos import Transformer, WikiText2 + + +class LanguageModel(L.LightningModule): + def __init__(self, vocab_size): + super().__init__() + self.vocab_size = vocab_size + self.model = None + + def configure_model(self): + if self.model is not None: + return + + with torch.device("meta"): + model = Transformer( + vocab_size=self.vocab_size, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + float8_config = Float8LinearConfig( + # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly + pad_inner_dim=True, + ) + + def module_filter_fn(mod: torch.nn.Module, fqn: str): + # we skip the decoder because it typically vocabulary size + # is not divisible by 16 as required by float8 + if fqn == "decoder": + return False + return True + + convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) + + for module in model.modules(): + if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)): + fully_shard(module, mesh=self.device_mesh) + + fully_shard(model, mesh=self.device_mesh) + + self.model = torch.compile(model) + + def training_step(self, batch): + input, target = batch + output = self.model(input, target) + loss = F.nll_loss(output, target.view(-1)) + self.log("train_loss", loss, prog_bar=True) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-4) + + +def train(): + L.seed_everything(42) + + dataset = WikiText2() + train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1) + + model = LanguageModel(vocab_size=dataset.vocab_size) + + mp_strategy = ModelParallelStrategy( + data_parallel_size=4, + tensor_parallel_size=1, + ) + + trainer = L.Trainer( + strategy=mp_strategy, + max_steps=100, + precision="bf16-true", + accumulate_grad_batches=8 + ) + + trainer.fit(model, train_dataloader) + + trainer.print(torch.cuda.memory_summary()) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision('high') + + train() From fcf115a449fcd917c484e90ba88463968f04193a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 16:13:35 +0000 Subject: [PATCH 02/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/fabric/fp8_fsdp2_compile/train.py | 21 +++++++-------------- examples/pytorch/fp8_fsdp2_compile/train.py | 21 +++++++-------------- 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/examples/fabric/fp8_fsdp2_compile/train.py b/examples/fabric/fp8_fsdp2_compile/train.py index 329c31f3da3ef..96ccca324d86f 100644 --- a/examples/fabric/fp8_fsdp2_compile/train.py +++ b/examples/fabric/fp8_fsdp2_compile/train.py @@ -1,16 +1,13 @@ +import lightning as L import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import DataLoader -from torch.distributed._composable.fsdp.fully_shard import fully_shard -from torch.distributed.device_mesh import DeviceMesh - -from torchao.float8 import convert_to_float8_training, Float8LinearConfig - -import lightning as L from lightning.fabric.strategies import ModelParallelStrategy from lightning.pytorch.demos import Transformer, WikiText2 - +from torch.distributed._composable.fsdp.fully_shard import fully_shard +from torch.distributed.device_mesh import DeviceMesh +from torch.utils.data import DataLoader +from torchao.float8 import Float8LinearConfig, convert_to_float8_training from tqdm import tqdm @@ -58,11 +55,7 @@ def train(): nhead=32, ) - strategy = ModelParallelStrategy( - data_parallel_size=4, - tensor_parallel_size=1, - parallelize_fn=configure_model - ) + strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=configure_model) fabric = L.Fabric(precision="bf16-true", strategy=strategy) fabric.launch() @@ -101,6 +94,6 @@ def train(): if __name__ == "__main__": - torch.set_float32_matmul_precision('high') + torch.set_float32_matmul_precision("high") train() diff --git a/examples/pytorch/fp8_fsdp2_compile/train.py b/examples/pytorch/fp8_fsdp2_compile/train.py index ed1184e2d59b8..cd243dfb52976 100644 --- a/examples/pytorch/fp8_fsdp2_compile/train.py +++ b/examples/pytorch/fp8_fsdp2_compile/train.py @@ -1,14 +1,12 @@ +import lightning as L import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import DataLoader -from torch.distributed._composable.fsdp.fully_shard import fully_shard - -from torchao.float8 import convert_to_float8_training, Float8LinearConfig - -import lightning as L -from lightning.pytorch.strategies import ModelParallelStrategy from lightning.pytorch.demos import Transformer, WikiText2 +from lightning.pytorch.strategies import ModelParallelStrategy +from torch.distributed._composable.fsdp.fully_shard import fully_shard +from torch.utils.data import DataLoader +from torchao.float8 import Float8LinearConfig, convert_to_float8_training class LanguageModel(L.LightningModule): @@ -76,12 +74,7 @@ def train(): tensor_parallel_size=1, ) - trainer = L.Trainer( - strategy=mp_strategy, - max_steps=100, - precision="bf16-true", - accumulate_grad_batches=8 - ) + trainer = L.Trainer(strategy=mp_strategy, max_steps=100, precision="bf16-true", accumulate_grad_batches=8) trainer.fit(model, train_dataloader) @@ -89,6 +82,6 @@ def train(): if __name__ == "__main__": - torch.set_float32_matmul_precision('high') + torch.set_float32_matmul_precision("high") train() From 8be123a34eaeeaa1f878707e941314f4883cdb64 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Mon, 25 Nov 2024 22:33:15 +0100 Subject: [PATCH 03/13] Add tests for compile after fsdp2/tp --- .../fabric/fp8_fsdp2_compile/requirements.txt | 1 + examples/fabric/fp8_fsdp2_compile/train.py | 17 ++++--- .../fp8_fsdp2_compile/requirements.txt | 1 + examples/pytorch/fp8_fsdp2_compile/train.py | 6 +-- .../test_model_parallel_integration.py | 32 ++++++++++-- .../test_model_parallel_integration.py | 50 +++++++++++++++---- 6 files changed, 82 insertions(+), 25 deletions(-) create mode 100644 examples/fabric/fp8_fsdp2_compile/requirements.txt create mode 100644 examples/pytorch/fp8_fsdp2_compile/requirements.txt diff --git a/examples/fabric/fp8_fsdp2_compile/requirements.txt b/examples/fabric/fp8_fsdp2_compile/requirements.txt new file mode 100644 index 0000000000000..ce00e191aa9c1 --- /dev/null +++ b/examples/fabric/fp8_fsdp2_compile/requirements.txt @@ -0,0 +1 @@ +torchao>=0.7.0 diff --git a/examples/fabric/fp8_fsdp2_compile/train.py b/examples/fabric/fp8_fsdp2_compile/train.py index 96ccca324d86f..ba88603268945 100644 --- a/examples/fabric/fp8_fsdp2_compile/train.py +++ b/examples/fabric/fp8_fsdp2_compile/train.py @@ -13,16 +13,14 @@ def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: float8_config = Float8LinearConfig( - # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly + # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly # noqa pad_inner_dim=True, ) def module_filter_fn(mod: torch.nn.Module, fqn: str): # we skip the decoder because it typically vocabulary size # is not divisible by 16 as required by float8 - if fqn == "decoder": - return False - return True + return fqn != "decoder" convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) @@ -32,9 +30,7 @@ def module_filter_fn(mod: torch.nn.Module, fqn: str): fully_shard(model, mesh=device_mesh) - model = torch.compile(model) - - return model + return torch.compile(model) def train(): @@ -43,6 +39,8 @@ def train(): batch_size = 8 micro_batch_size = 1 + max_steps = 100 + dataset = WikiText2() dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size) @@ -69,6 +67,8 @@ def train(): iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader) + steps = 0 + for i, batch in iterable: input, target = batch @@ -83,11 +83,12 @@ def train(): fabric.clip_gradients(model, optimizer, max_norm=1.0) optimizer.step() optimizer.zero_grad() + steps += 1 if fabric.is_global_zero: iterable.set_postfix_str(f"train_loss={loss.item():.2f}") - if i // (batch_size // micro_batch_size) > 100: + if steps == max_steps: break fabric.print(torch.cuda.memory_summary()) diff --git a/examples/pytorch/fp8_fsdp2_compile/requirements.txt b/examples/pytorch/fp8_fsdp2_compile/requirements.txt new file mode 100644 index 0000000000000..ce00e191aa9c1 --- /dev/null +++ b/examples/pytorch/fp8_fsdp2_compile/requirements.txt @@ -0,0 +1 @@ +torchao>=0.7.0 diff --git a/examples/pytorch/fp8_fsdp2_compile/train.py b/examples/pytorch/fp8_fsdp2_compile/train.py index cd243dfb52976..6c7be98ee7dbd 100644 --- a/examples/pytorch/fp8_fsdp2_compile/train.py +++ b/examples/pytorch/fp8_fsdp2_compile/train.py @@ -29,16 +29,14 @@ def configure_model(self): ) float8_config = Float8LinearConfig( - # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly + # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly # noqa pad_inner_dim=True, ) def module_filter_fn(mod: torch.nn.Module, fqn: str): # we skip the decoder because it typically vocabulary size # is not divisible by 16 as required by float8 - if fqn == "decoder": - return False - return True + return fqn != "decoder" convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index dfbdb16b10060..0fc4caaf449a5 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -116,11 +116,28 @@ def test_setup_device_mesh(): assert fabric.strategy.device_mesh.size(1) == 4 +def _parallelize_with_compile(parallelize): + def fn(model, device_mesh): + model = parallelize(model, device_mesh) + return torch.compile(model) + + return fn + + @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2) -def test_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_tensor_parallel(compile): from torch.distributed._tensor import DTensor - strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_tp) + parallelize = _parallelize_feed_forward_tp + + if compile: + parallelize = _parallelize_with_compile(parallelize) + + strategy = ModelParallelStrategy(parallelize_fn=parallelize) fabric = Fabric(accelerator="auto", devices=2, strategy=strategy) fabric.launch() @@ -161,9 +178,18 @@ def test_tensor_parallel(): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_fsdp2_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_fsdp2_tensor_parallel(compile): from torch.distributed._tensor import DTensor + parallelize = _parallelize_feed_forward_fsdp2_tp + + if compile: + parallelize = _parallelize_with_compile(parallelize) + strategy = ModelParallelStrategy( parallelize_fn=_parallelize_feed_forward_fsdp2_tp, data_parallel_size=2, diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 57d273917573a..0228d952f9cbd 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -78,10 +78,19 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh): return model +def _parallelize_with_compile(parallelize): + def fn(model, device_mesh): + model = parallelize(model, device_mesh) + return torch.compile(model) + + return fn + + class TemplateModel(LightningModule): - def __init__(self): + def __init__(self, compile=False): super().__init__() self.model = FeedForward() + self._compile = compile def training_step(self, batch): output = self.model(batch) @@ -98,17 +107,26 @@ def configure_optimizers(self): class FSDP2Model(TemplateModel): def configure_model(self): - _parallelize_feed_forward_fsdp2(self.model, device_mesh=self.device_mesh) + parallelize = _parallelize_feed_forward_fsdp2_tp + if self._compile: + parallelize = _parallelize_with_compile(parallelize) + parallelize(self.model, device_mesh=self.device_mesh) class TensorParallelModel(TemplateModel): def configure_model(self): - _parallelize_feed_forward_tp(self.model, device_mesh=self.device_mesh) + parallelize = _parallelize_feed_forward_tp + if self._compile: + parallelize = _parallelize_with_compile(parallelize) + parallelize(self.model, device_mesh=self.device_mesh) class FSDP2TensorParallelModel(TemplateModel): def configure_model(self): - _parallelize_feed_forward_fsdp2_tp(self.model, device_mesh=self.device_mesh) + parallelize = _parallelize_feed_forward_fsdp2_tp + if self._compile: + parallelize = _parallelize_with_compile(parallelize) + parallelize(self.model, device_mesh=self.device_mesh) @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) @@ -169,7 +187,11 @@ def configure_model(self): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2) -def test_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_tensor_parallel(compile): from torch.distributed._tensor import DTensor class Model(TensorParallelModel): @@ -204,13 +226,17 @@ def training_step(self, batch): seed_everything(0) with trainer.init_module(empty_init=True): - model = Model() + model = Model(compile=compile) trainer.fit(model) @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_fsdp2_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_fsdp2_tensor_parallel(compile): from torch.distributed._tensor import DTensor class Model(FSDP2TensorParallelModel): @@ -261,7 +287,7 @@ def training_step(self, batch): seed_everything(0) with trainer.init_module(empty_init=True): - model = Model() + model = Model(compile=compile) trainer.fit(model) @@ -306,7 +332,11 @@ def training_step(self, batch): pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)), ], ) -def test_module_init_context(precision, expected_dtype, tmp_path): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_module_init_context(compile, precision, expected_dtype, tmp_path): """Test that the module under the init-context gets moved to the right device and dtype.""" class Model(FSDP2Model): @@ -329,7 +359,7 @@ def _run_setup_assertions(empty_init, expected_device): logger=False, ) with trainer.init_module(empty_init=empty_init): - model = Model() + model = Model(compile=compile) # The model is on the CPU/meta-device until after `ModelParallelStrategy.setup()` assert model.model.w1.weight.device == expected_device From 27c36d9b7f5c1ace154e165f7868aeb123c37fdc Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Mon, 25 Nov 2024 23:26:03 +0100 Subject: [PATCH 04/13] Add README's --- examples/fabric/fp8_fsdp2_compile/README.md | 39 ++++++++++++++++++++ examples/pytorch/fp8_fsdp2_compile/README.md | 39 ++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/examples/fabric/fp8_fsdp2_compile/README.md b/examples/fabric/fp8_fsdp2_compile/README.md index e69de29bb2d1d..e980d759bb3ff 100644 --- a/examples/fabric/fp8_fsdp2_compile/README.md +++ b/examples/fabric/fp8_fsdp2_compile/README.md @@ -0,0 +1,39 @@ +## Distributed, Low-Precision Transformer Example + +This example shows how to use `ModelParallelStrategy` in `Fabric` to train a Transformer model minimizing memory usage, maximizing throughput, and distributing load across multiple GPUs. + +### Training Large Models and Memory Requirements + +One of the main challenges when training large models, like large language models (LLMs), is dealing with their memory footprint. LLMs can be so large that weights, activations, gradients and optimizer state don't fit a single GPU, so that they need to be distributed across multiple GPUs, and across multiple machines. There are multiple ways of distributing computations, among which fully-sharded data parallelism (FSDP) and tensor parallelism (TP). + +An additional way of reducing memory requirements is representing floating point numbers in weights and activations in low numerical precision, such as 16-bit (`bfloat16`), or 8-bit (`fp8`). This leads to savings in memory usage, as well as memory bandwidth usage (fewer bytes transferred from device memory to GPU cores in unit time). + +Roughly, reducing precision to `fp8` for linear layers can lead to 2x reduction in memory requirements and 1.6x improvement in throughput. Support for `fp8` weights and activations requires recent GPUs - Hopper, Ada Lovelace and above (e.g. H100, L4, L40). + +The introduction of tensor subclasses in PyTorch brought two new APIs that can be used to achieve memory savings and distributed training (as well as inference) in combination: + +- [torch ao](https://github.com/pytorch/ao) to execute linear layers in low numerical precision (`fp8` and other quantized formats) +- [dtensors](https://pytorch.org/docs/stable/distributed.tensor.html) to distribute models across GPUs, by combining TP and FSDP (referred to FSDP2 in PyTorch) + +Notably, `torch ao` introduces quantization and dequantization operations in the model that may result in slow-downs if not optimized. Using `torch.compile` after `torch ao` recovers performance by generating optimized kernels for those operations. + +### Vanilla Transformer Example + +This example shows how to train a vanilla Transformer model using `fp8` precision and the FSDP2 distributed strategy, and then optimize the resulting model through `torch.compile`. + +Specifically, we employ the `ModelParallelStrategy`, and use the `configure_model` hook to distribute the model using the PyTorch DTensor API. +In the same hook we also pass the model through the `torch ao` API (prior to FSDP2), as well as `torch.compile` (after FSDP2). + +The resulting code follows the PyTorch API closely, while also taking advantage of the rest of PyTorch Lightning. + +To execute the code directly just run: + +```bash +python train.py +``` + +### A Note on torch.compile + +Note that PyTorch Lightning also supports calling `torch.compile` on a `LightningModule` and passing it to the `Trainer`. + +While this works for simple cases, in order to get the most out of the combination of the latest distributed, quantization, and compile PyTorch API's, we recommend invoking `torch.compile` at the end of the `configure_model` hook, as shown in this example. diff --git a/examples/pytorch/fp8_fsdp2_compile/README.md b/examples/pytorch/fp8_fsdp2_compile/README.md index e69de29bb2d1d..6c5e12d14da4a 100644 --- a/examples/pytorch/fp8_fsdp2_compile/README.md +++ b/examples/pytorch/fp8_fsdp2_compile/README.md @@ -0,0 +1,39 @@ +## Distributed, Low-Precision Transformer Example + +This example shows how to use `ModelParallelStrategy` in `Fabric` to train a Transformer model minimizing memory usage, maximizing throughput, and distributing load across multiple GPUs. + +### Training Large Models and Memory Requirements + +One of the main challenges when training large models, like large language models (LLMs), is dealing with their memory footprint. LLMs can be so large that weights, activations, gradients and optimizer state don't fit a single GPU, so that they need to be distributed across multiple GPUs, and across multiple machines. There are multiple ways of distributing computations, among which fully-sharded data parallelism (FSDP) and tensor parallelism (TP). + +An additional way of reducing memory requirements is representing floating point numbers in weights and activations in low numerical precision, such as 16-bit (`bfloat16`), or 8-bit (`fp8`). This leads to savings in memory usage, as well as memory bandwidth usage (fewer bytes transferred from device memory to GPU cores in unit time). + +Roughly, reducing precision to `fp8` for linear layers can lead to 2x reduction in memory requirements and 1.6x improvement in throughput. Support for `fp8` weights and activations requires recent GPUs - Hopper, Ada Lovelace and above (e.g. H100, L4, L40). + +The introduction of tensor subclasses in PyTorch brought two new APIs that can be used to achieve memory savings and distributed training (as well as inference) in combination: + +- [torch ao](https://github.com/pytorch/ao) to execute linear layers in low numerical precision (`fp8` and other quantized formats) +- [dtensors](https://pytorch.org/docs/stable/distributed.tensor.html) to distribute models across GPUs, by combining TP and FSDP (referred to FSDP2 in PyTorch) + +Notably, `torch ao` introduces quantization and dequantization operations in the model that may result in slow-downs if not optimized. Using `torch.compile` after `torch ao` recovers performance by generating optimized kernels for those operations. + +### Vanilla Transformer Example + +This example shows how to train a vanilla Transformer model using `fp8` precision and the FSDP2 distributed strategy, and then optimize the resulting model through `torch.compile`. + +Specifically, we employ the `ModelParallelStrategy`, which accepts a `parallelize_fn` to distribute the model using the PyTorch DTensor API. +We use the same function to also pass the model through the `torch ao` API (prior to FSDP2), as well as `torch.compile` (after FSDP2). + +The resulting code follows the PyTorch API closely, while also taking advantage of the rest of Lightning Fabric. + +To execute the code directly just run: + +```bash +python train.py +``` + +### A Note on torch.compile + +Note that Fabric also supports calling `torch.compile` on a model and passing it to `fabric.setup_model` or `fabric.setup_model_and_optimizers`. + +While this works well, in order to get the most out of the combination of the latest distributed, quantization, and compile PyTorch API's, we recommend invoking `torch.compile` as part of the `parallelize_fn` argument of `ModelParallelStrategy`, as shown in this example. From ca3c4cf1e4b7185e475dd64602dfe8634c27c6b5 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 26 Nov 2024 00:25:37 +0100 Subject: [PATCH 05/13] Add docs --- docs/source-fabric/advanced/compile.rst | 103 ++++++++++++++++++++ docs/source-pytorch/advanced/compile.rst | 116 ++++++++++++++++++++++- 2 files changed, 217 insertions(+), 2 deletions(-) diff --git a/docs/source-fabric/advanced/compile.rst b/docs/source-fabric/advanced/compile.rst index 17ba6e4ca9dc8..801bf52ac8b46 100644 --- a/docs/source-fabric/advanced/compile.rst +++ b/docs/source-fabric/advanced/compile.rst @@ -118,6 +118,109 @@ always exclude the first call to ``forward()`` from your measurements, since it ---- +********************************* +Apply torch.compile to your model +********************************* + +:func:`torch.compile` can also be invoked as part of the :func:`parallelize_fn` argument of :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy`. + +This is particularly handy when :func:`torch.compile` is used in combination with the `torch.distributed.tensor` API. + +Here is an example: + +.. code-block:: python + import lightning as L + import torch + import torch.nn as nn + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer + from lightning.fabric.strategies.model_parallel import ModelParallelStrategy + from torch.distributed._composable.fsdp.fully_shard import fully_shard + from torch.distributed.device_mesh import DeviceMesh + + def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + for module in model.modules(): + if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)): + fully_shard(module, mesh=device_mesh) + + fully_shard(model, mesh=device_mesh) + + return torch.compile(model) + + def train(): + L.seed_everything(42) + + with torch.device("meta"): + model = Transformer( + vocab_size=50257, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=parallelize) + + fabric = L.Fabric(precision="bf16-true", strategy=strategy) + fabric.launch() + + model = fabric.setup(model) + +The advantage here is that `parallelize` is called when sharding the model, +so :func:`torch.compile` is guaranteed to run on model shards and capture distributed operations. + +Also, when using other libraries like `torch ao ` +that need to be applied in a similar fashion, it's easy to reason about the sequence of calls +needed to achieve the equivalent of `compile(distributed(quantized(model)))`: + +.. code-block:: python + import lightning as L + import torch + import torch.nn as nn + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer + from torch.distributed._composable.fsdp.fully_shard import fully_shard + from torch.distributed.device_mesh import DeviceMesh + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + float8_config = Float8LinearConfig( + pad_inner_dim=True, + ) + + def module_filter_fn(mod: torch.nn.Module, fqn: str): + return fqn != "decoder" + + convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) + + for module in model.modules(): + if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)): + fully_shard(module, mesh=device_mesh) + + fully_shard(model, mesh=device_mesh) + + return torch.compile(model) + + def train(): + L.seed_everything(42) + + with torch.device("meta"): + model = Transformer( + vocab_size=50257, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=parallelize) + + fabric = L.Fabric(precision="bf16-true", strategy=strategy) + fabric.launch() + + model = fabric.setup(model) + +---- ****************** Avoid graph breaks diff --git a/docs/source-pytorch/advanced/compile.rst b/docs/source-pytorch/advanced/compile.rst index d5bd333c041b3..c5a7362184470 100644 --- a/docs/source-pytorch/advanced/compile.rst +++ b/docs/source-pytorch/advanced/compile.rst @@ -138,6 +138,118 @@ always exclude the first call to ``forward()``/``*_step()`` from your measuremen ---- +************************************** +Apply torch.compile in configure_model +************************************** + +:func:`torch.compile` can also be invoked as part of the :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook. + +This is particularly handy when :func:`torch.compile` is used in combination with :class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy`. + +Here is an example: + +.. code-block:: python + import lightning as L + import torch + import torch.nn as nn + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer + from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy + from torch.distributed.device_mesh import DeviceMesh + from torch.distributed._composable.fsdp.fully_shard import fully_shard + + class LanguageModel(L.LightningModule): + def __init__(self, vocab_size): + super().__init__() + self.vocab_size = vocab_size + self.model = None + + def configure_model(self): + if self.model is not None: + return + + with torch.device("meta"): + model = Transformer( + vocab_size=self.vocab_size, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + for module in model.modules(): + if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)): + fully_shard(module, mesh=self.device_mesh) + + fully_shard(model, mesh=self.device_mesh) + + self.model = torch.compile(model) + + def training_step(self, batch): + input, target = batch + output = self.model(input, target) + loss = F.nll_loss(output, target.view(-1)) + self.log("train_loss", loss) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-4) + +The advantage here is that `configure_model` is called when sharding the model, +so :func:`torch.compile` is guaranteed to run on model shards and capture distributed operations. + +Also, when using other libraries like `torch ao ` +that need to be applied in a similar fashion, it's easy to reason about the sequence of calls +needed to achieve the equivalent of `compile(distributed(quantized(model)))`: + +.. code-block:: python + import lightning as L + import torch + import torch.nn as nn + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer + from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy + from torch.distributed._composable.fsdp.fully_shard import fully_shard + from torch.distributed.device_mesh import DeviceMesh + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + class LanguageModel(L.LightningModule): + def __init__(self, vocab_size): + super().__init__() + self.vocab_size = vocab_size + self.model = None + + def configure_model(self): + if self.model is not None: + return + + with torch.device("meta"): + model = Transformer( + vocab_size=self.vocab_size, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + float8_config = Float8LinearConfig( + pad_inner_dim=True, + ) + + def module_filter_fn(mod: torch.nn.Module, fqn: str): + return fqn != "decoder" + + convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) + + for module in model.modules(): + if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)): + fully_shard(module, mesh=self.device_mesh) + + fully_shard(model, mesh=self.device_mesh) + + self.model = torch.compile(model) + +---- ****************** Avoid graph breaks @@ -253,8 +365,8 @@ Limitations There are a few limitations you should be aware of when using ``torch.compile`` **in conjunction with the Trainer**: -* The Trainer currently does not reapply ``torch.compile`` over DDP/FSDP, meaning distributed operations can't benefit from speed ups at the moment. - This limitation will be lifted in the future. +* The Trainer currently does not reapply ``torch.compile`` over :class:`~lightning.pytorch.strategies.DDPStrategy` and :class:`~lightning.pytorch.strategies.FSDPStrategy`, meaning distributed operations can't benefit from speed ups at the moment. + This limitation can be avoided by using :class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy`, as described in `Apply torch.compile in configure_model`_ above. * In some cases, using ``self.log()`` in your LightningModule will cause compilation errors. Until addressed, you can work around these issues by applying ``torch.compile`` to the submodule(s) of your LightningModule rather than to the entire LightningModule at once. From 55a6fde8b596c7d2543df6c67c15a1bee9116265 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 26 Nov 2024 00:36:00 +0100 Subject: [PATCH 06/13] Rename folder, add cross-reference --- docs/source-fabric/advanced/compile.rst | 3 ++- docs/source-pytorch/advanced/compile.rst | 2 ++ .../README.md | 0 .../requirements.txt | 0 .../train.py | 0 .../README.md | 0 .../requirements.txt | 0 .../train.py | 0 8 files changed, 4 insertions(+), 1 deletion(-) rename examples/fabric/{fp8_fsdp2_compile => fp8_distributed_transformer}/README.md (100%) rename examples/fabric/{fp8_fsdp2_compile => fp8_distributed_transformer}/requirements.txt (100%) rename examples/fabric/{fp8_fsdp2_compile => fp8_distributed_transformer}/train.py (100%) rename examples/pytorch/{fp8_fsdp2_compile => fp8_distributed_transformer}/README.md (100%) rename examples/pytorch/{fp8_fsdp2_compile => fp8_distributed_transformer}/requirements.txt (100%) rename examples/pytorch/{fp8_fsdp2_compile => fp8_distributed_transformer}/train.py (100%) diff --git a/docs/source-fabric/advanced/compile.rst b/docs/source-fabric/advanced/compile.rst index 801bf52ac8b46..077a0419e698b 100644 --- a/docs/source-fabric/advanced/compile.rst +++ b/docs/source-fabric/advanced/compile.rst @@ -115,7 +115,6 @@ always exclude the first call to ``forward()`` from your measurements, since it Compile median time: 0.0185 seconds Speedup: 1.4x - ---- ********************************* @@ -220,6 +219,8 @@ needed to achieve the equivalent of `compile(distributed(quantized(model)))`: model = fabric.setup(model) +For a full example, see our `FP8 Distributed Transformer example `_. + ---- ****************** diff --git a/docs/source-pytorch/advanced/compile.rst b/docs/source-pytorch/advanced/compile.rst index c5a7362184470..3328ad9e8e114 100644 --- a/docs/source-pytorch/advanced/compile.rst +++ b/docs/source-pytorch/advanced/compile.rst @@ -249,6 +249,8 @@ needed to achieve the equivalent of `compile(distributed(quantized(model)))`: self.model = torch.compile(model) +For a full example, see our `FP8 Distributed Transformer example `_. + ---- ****************** diff --git a/examples/fabric/fp8_fsdp2_compile/README.md b/examples/fabric/fp8_distributed_transformer/README.md similarity index 100% rename from examples/fabric/fp8_fsdp2_compile/README.md rename to examples/fabric/fp8_distributed_transformer/README.md diff --git a/examples/fabric/fp8_fsdp2_compile/requirements.txt b/examples/fabric/fp8_distributed_transformer/requirements.txt similarity index 100% rename from examples/fabric/fp8_fsdp2_compile/requirements.txt rename to examples/fabric/fp8_distributed_transformer/requirements.txt diff --git a/examples/fabric/fp8_fsdp2_compile/train.py b/examples/fabric/fp8_distributed_transformer/train.py similarity index 100% rename from examples/fabric/fp8_fsdp2_compile/train.py rename to examples/fabric/fp8_distributed_transformer/train.py diff --git a/examples/pytorch/fp8_fsdp2_compile/README.md b/examples/pytorch/fp8_distributed_transformer/README.md similarity index 100% rename from examples/pytorch/fp8_fsdp2_compile/README.md rename to examples/pytorch/fp8_distributed_transformer/README.md diff --git a/examples/pytorch/fp8_fsdp2_compile/requirements.txt b/examples/pytorch/fp8_distributed_transformer/requirements.txt similarity index 100% rename from examples/pytorch/fp8_fsdp2_compile/requirements.txt rename to examples/pytorch/fp8_distributed_transformer/requirements.txt diff --git a/examples/pytorch/fp8_fsdp2_compile/train.py b/examples/pytorch/fp8_distributed_transformer/train.py similarity index 100% rename from examples/pytorch/fp8_fsdp2_compile/train.py rename to examples/pytorch/fp8_distributed_transformer/train.py From 06c6af62cb1e0b356ab98d1f0262bf11fe3ec2b5 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 26 Nov 2024 00:37:22 +0100 Subject: [PATCH 07/13] Fix link --- docs/source-fabric/advanced/compile.rst | 2 +- docs/source-pytorch/advanced/compile.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source-fabric/advanced/compile.rst b/docs/source-fabric/advanced/compile.rst index 077a0419e698b..84ad15fec082c 100644 --- a/docs/source-fabric/advanced/compile.rst +++ b/docs/source-fabric/advanced/compile.rst @@ -168,7 +168,7 @@ Here is an example: The advantage here is that `parallelize` is called when sharding the model, so :func:`torch.compile` is guaranteed to run on model shards and capture distributed operations. -Also, when using other libraries like `torch ao ` +Also, when using other libraries like `torch ao `_ that need to be applied in a similar fashion, it's easy to reason about the sequence of calls needed to achieve the equivalent of `compile(distributed(quantized(model)))`: diff --git a/docs/source-pytorch/advanced/compile.rst b/docs/source-pytorch/advanced/compile.rst index 3328ad9e8e114..51ea84cd4e057 100644 --- a/docs/source-pytorch/advanced/compile.rst +++ b/docs/source-pytorch/advanced/compile.rst @@ -198,7 +198,7 @@ Here is an example: The advantage here is that `configure_model` is called when sharding the model, so :func:`torch.compile` is guaranteed to run on model shards and capture distributed operations. -Also, when using other libraries like `torch ao ` +Also, when using other libraries like `torch ao `_ that need to be applied in a similar fashion, it's easy to reason about the sequence of calls needed to achieve the equivalent of `compile(distributed(quantized(model)))`: From acb31252c6dffb269648784afbc697b837fc5c14 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 26 Nov 2024 01:13:04 +0100 Subject: [PATCH 08/13] Newline after code-block directive --- docs/source-fabric/advanced/compile.rst | 2 ++ docs/source-pytorch/advanced/compile.rst | 2 ++ 2 files changed, 4 insertions(+) diff --git a/docs/source-fabric/advanced/compile.rst b/docs/source-fabric/advanced/compile.rst index 84ad15fec082c..45929fda7a6db 100644 --- a/docs/source-fabric/advanced/compile.rst +++ b/docs/source-fabric/advanced/compile.rst @@ -128,6 +128,7 @@ This is particularly handy when :func:`torch.compile` is used in combination wit Here is an example: .. code-block:: python + import lightning as L import torch import torch.nn as nn @@ -173,6 +174,7 @@ that need to be applied in a similar fashion, it's easy to reason about the sequ needed to achieve the equivalent of `compile(distributed(quantized(model)))`: .. code-block:: python + import lightning as L import torch import torch.nn as nn diff --git a/docs/source-pytorch/advanced/compile.rst b/docs/source-pytorch/advanced/compile.rst index 51ea84cd4e057..16fe91ca282df 100644 --- a/docs/source-pytorch/advanced/compile.rst +++ b/docs/source-pytorch/advanced/compile.rst @@ -149,6 +149,7 @@ This is particularly handy when :func:`torch.compile` is used in combination wit Here is an example: .. code-block:: python + import lightning as L import torch import torch.nn as nn @@ -203,6 +204,7 @@ that need to be applied in a similar fashion, it's easy to reason about the sequ needed to achieve the equivalent of `compile(distributed(quantized(model)))`: .. code-block:: python + import lightning as L import torch import torch.nn as nn From 1646eca0c126322f64900ef39f23d4d7cb3679f4 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 26 Nov 2024 01:18:43 +0100 Subject: [PATCH 09/13] Update section name --- docs/source-fabric/advanced/compile.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source-fabric/advanced/compile.rst b/docs/source-fabric/advanced/compile.rst index 45929fda7a6db..20c8d111977c9 100644 --- a/docs/source-fabric/advanced/compile.rst +++ b/docs/source-fabric/advanced/compile.rst @@ -117,9 +117,9 @@ always exclude the first call to ``forward()`` from your measurements, since it ---- -********************************* -Apply torch.compile to your model -********************************* +********************************************** +Apply torch.compile with ModelParallelStrategy +********************************************** :func:`torch.compile` can also be invoked as part of the :func:`parallelize_fn` argument of :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy`. From 3d6cea6a603a631080ec7c670a281c5c967e586e Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 26 Nov 2024 01:28:45 +0100 Subject: [PATCH 10/13] Fix reference --- docs/source-fabric/advanced/compile.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-fabric/advanced/compile.rst b/docs/source-fabric/advanced/compile.rst index 20c8d111977c9..df79454f67a6f 100644 --- a/docs/source-fabric/advanced/compile.rst +++ b/docs/source-fabric/advanced/compile.rst @@ -121,7 +121,7 @@ always exclude the first call to ``forward()`` from your measurements, since it Apply torch.compile with ModelParallelStrategy ********************************************** -:func:`torch.compile` can also be invoked as part of the :func:`parallelize_fn` argument of :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy`. +:func:`torch.compile` can also be invoked as part of the `parallelize_fn` argument of :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy`. This is particularly handy when :func:`torch.compile` is used in combination with the `torch.distributed.tensor` API. From 4eaac170b6f0a0bdbaac67fdaf8ecfbcc962fb9a Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 26 Nov 2024 09:10:29 +0000 Subject: [PATCH 11/13] Half standalone tests batch size --- tests/run_standalone_tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/run_standalone_tests.sh b/tests/run_standalone_tests.sh index 9aa54f7350607..75a52e16c57dc 100755 --- a/tests/run_standalone_tests.sh +++ b/tests/run_standalone_tests.sh @@ -17,7 +17,7 @@ set -e # Batch size for testing: Determines how many standalone test invocations run in parallel # It can be set through the env variable PL_STANDALONE_TESTS_BATCH_SIZE and defaults to 6 if not set -test_batch_size="${PL_STANDALONE_TESTS_BATCH_SIZE:-6}" +test_batch_size="${PL_STANDALONE_TESTS_BATCH_SIZE:-3}" source="${PL_STANDALONE_TESTS_SOURCE:-"lightning"}" # this is the directory where the tests are located test_dir=$1 # parse the first argument From a8310cf69602e0f533fb7c821a408f0c848ecb42 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 26 Nov 2024 12:37:24 +0000 Subject: [PATCH 12/13] Fix integration tests --- tests/tests_fabric/conftest.py | 1 + .../test_model_parallel_integration.py | 32 ++++++++++++------- .../test_model_parallel_integration.py | 25 +++++++++------ 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 446994167d0a1..5fdc61a08955b 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -69,6 +69,7 @@ def restore_env_variables(): "OMP_NUM_THREADS", # set by our launchers # set by torchdynamo "TRITON_CACHE_DIR", + "TORCHINDUCTOR_CACHE_DIR", } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index 0fc4caaf449a5..9f7e3e67c5a52 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -29,6 +29,13 @@ from tests_fabric.helpers.runif import RunIf +@pytest.fixture +def distributed(): + yield + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + class FeedForward(nn.Module): def __init__(self): super().__init__() @@ -81,7 +88,7 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_setup_device_mesh(): +def test_setup_device_mesh(distributed): from torch.distributed.device_mesh import DeviceMesh for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)): @@ -129,7 +136,7 @@ def fn(model, device_mesh): "compile", [True, False], ) -def test_tensor_parallel(compile): +def test_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor parallelize = _parallelize_feed_forward_tp @@ -182,7 +189,7 @@ def test_tensor_parallel(compile): "compile", [True, False], ) -def test_fsdp2_tensor_parallel(compile): +def test_fsdp2_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor parallelize = _parallelize_feed_forward_fsdp2_tp @@ -264,6 +271,7 @@ def _train(fabric, model=None, optimizer=None): @RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True) +@pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.parametrize( "precision", [ @@ -271,7 +279,7 @@ def _train(fabric, model=None, optimizer=None): pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)), ], ) -def test_train_save_load(precision, tmp_path): +def test_train_save_load(distributed, precision, tmp_path): """Test 2D-parallel training, saving and loading precision settings.""" strategy = ModelParallelStrategy( _parallelize_feed_forward_fsdp2_tp, @@ -329,7 +337,7 @@ def test_train_save_load(precision, tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_save_full_state_dict(tmp_path): +def test_save_full_state_dict(distributed, tmp_path): """Test that ModelParallelStrategy saves the full state into a single file with `save_distributed_checkpoint=False`.""" from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict @@ -430,7 +438,7 @@ def test_save_full_state_dict(tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_load_full_state_dict_into_sharded_model(tmp_path): +def test_load_full_state_dict_into_sharded_model(distributed, tmp_path): """Test that the strategy can load a full-state checkpoint into a distributed model.""" fabric = Fabric(accelerator="cuda", devices=1) fabric.seed_everything(0) @@ -476,7 +484,7 @@ def test_load_full_state_dict_into_sharded_model(tmp_path): @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("move_to_device", [True, False]) @mock.patch("lightning.fabric.wrappers._FabricModule") -def test_setup_module_move_to_device(fabric_module_mock, move_to_device): +def test_setup_module_move_to_device(fabric_module_mock, move_to_device, distributed): """Test that `move_to_device` does nothing, ModelParallel decides which device parameters get moved to which device (sharding).""" from torch.distributed._tensor import DTensor @@ -508,7 +516,7 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device): pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)), ], ) -def test_module_init_context(precision, expected_dtype): +def test_module_init_context(distributed, precision, expected_dtype): """Test that the module under the init-context gets moved to the right device and dtype.""" strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_fsdp2) fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision=precision) @@ -531,7 +539,7 @@ def _run_setup_assertions(empty_init, expected_device): @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_save_filter(tmp_path): +def test_save_filter(distributed, tmp_path): strategy = ModelParallelStrategy( parallelize_fn=_parallelize_feed_forward_fsdp2, save_distributed_checkpoint=False, @@ -584,7 +592,7 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh): "val", ], ) -def test_clip_gradients(clip_type, precision): +def test_clip_gradients(distributed, clip_type, precision): strategy = ModelParallelStrategy(_parallelize_single_linear_tp_fsdp2) fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy) fabric.launch() @@ -626,7 +634,7 @@ def test_clip_gradients(clip_type, precision): @RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True) -def test_save_sharded_and_consolidate_and_load(tmp_path): +def test_save_sharded_and_consolidate_and_load(distributed, tmp_path): """Test the consolidation of a distributed (DTensor) checkpoint into a single file.""" strategy = ModelParallelStrategy( _parallelize_feed_forward_fsdp2_tp, @@ -683,7 +691,7 @@ def test_save_sharded_and_consolidate_and_load(tmp_path): @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_load_raw_module_state(): +def test_load_raw_module_state(distributed): from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 0228d952f9cbd..015b2d417a240 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -86,6 +86,13 @@ def fn(model, device_mesh): return fn +@pytest.fixture +def distributed(): + yield + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + class TemplateModel(LightningModule): def __init__(self, compile=False): super().__init__() @@ -130,7 +137,7 @@ def configure_model(self): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_setup_device_mesh(): +def test_setup_device_mesh(distributed): from torch.distributed.device_mesh import DeviceMesh for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)): @@ -191,7 +198,7 @@ def configure_model(self): "compile", [True, False], ) -def test_tensor_parallel(compile): +def test_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor class Model(TensorParallelModel): @@ -236,7 +243,7 @@ def training_step(self, batch): "compile", [True, False], ) -def test_fsdp2_tensor_parallel(compile): +def test_fsdp2_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor class Model(FSDP2TensorParallelModel): @@ -293,7 +300,7 @@ def training_step(self, batch): @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_modules_without_parameters(tmp_path): +def test_modules_without_parameters(distributed, tmp_path): """Test that TorchMetrics get moved to the device despite not having any parameters.""" class MetricsModel(TensorParallelModel): @@ -336,7 +343,7 @@ def training_step(self, batch): "compile", [True, False], ) -def test_module_init_context(compile, precision, expected_dtype, tmp_path): +def test_module_init_context(distributed, compile, precision, expected_dtype, tmp_path): """Test that the module under the init-context gets moved to the right device and dtype.""" class Model(FSDP2Model): @@ -375,7 +382,7 @@ def _run_setup_assertions(empty_init, expected_device): @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("save_distributed_checkpoint", [True, False]) -def test_strategy_state_dict(tmp_path, save_distributed_checkpoint): +def test_strategy_state_dict(distributed, tmp_path, save_distributed_checkpoint): """Test that the strategy returns the correct state dict of the LightningModule.""" model = FSDP2Model() correct_state_dict = model.state_dict() # State dict before wrapping @@ -408,7 +415,7 @@ def test_strategy_state_dict(tmp_path, save_distributed_checkpoint): @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) -def test_load_full_state_checkpoint_into_regular_model(tmp_path): +def test_load_full_state_checkpoint_into_regular_model(distributed, tmp_path): """Test that a full-state checkpoint saved from a distributed model can be loaded back into a regular model.""" # Save a regular full-state checkpoint from a distributed model @@ -450,7 +457,7 @@ def test_load_full_state_checkpoint_into_regular_model(tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) -def test_load_standard_checkpoint_into_distributed_model(tmp_path): +def test_load_standard_checkpoint_into_distributed_model(distributed, tmp_path): """Test that a regular checkpoint (weights and optimizer states) can be loaded into a distributed model.""" # Save a regular DDP checkpoint @@ -491,7 +498,7 @@ def test_load_standard_checkpoint_into_distributed_model(tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_save_load_sharded_state_dict(tmp_path): +def test_save_load_sharded_state_dict(distributed, tmp_path): """Test saving and loading with the distributed state dict format.""" class CheckpointModel(FSDP2Model): From 71a4b97b2a37c4fed48ed410638a265e9a0103f4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Nov 2024 12:38:25 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../tests_fabric/strategies/test_model_parallel_integration.py | 2 +- .../tests_pytorch/strategies/test_model_parallel_integration.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index 9f7e3e67c5a52..b04a29b691529 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -29,7 +29,7 @@ from tests_fabric.helpers.runif import RunIf -@pytest.fixture +@pytest.fixture() def distributed(): yield if torch.distributed.is_initialized(): diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 015b2d417a240..9dcbcc802834b 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -86,7 +86,7 @@ def fn(model, device_mesh): return fn -@pytest.fixture +@pytest.fixture() def distributed(): yield if torch.distributed.is_initialized():