
Commit 9a35627

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent e6889ed commit 9a35627

6 files changed (+67 -75 lines)


examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/fp8_training_handler.py

Lines changed: 4 additions & 4 deletions
@@ -2,7 +2,7 @@
 import logging
 import operator
 from dataclasses import dataclass
-from typing import Dict, List, Union
+from typing import Union
 
 import torch
 import torch.nn as nn
@@ -46,7 +46,7 @@ class FP8Config:
 class Float8TrainingHandler:
     """Handler for configuring models for FP8 training using torchao."""
 
-    def __init__(self, args: FP8Config, model_path: str, parallel_dims: Dict[str, bool]):
+    def __init__(self, args: FP8Config, model_path: str, parallel_dims: dict[str, bool]):
         """Initializes the handler for FP8 training and configuration.
 
         Args:
@@ -164,7 +164,7 @@ def convert_to_float8_training(self, model: nn.Module, module_filter_fn: callabl
             f"Swapped to Float8Linear layers with enable_fsdp_float8_all_gather={self.config.enable_fsdp_float8_all_gather}"
         )
 
-    def precompute_float8_dynamic_scale_for_fsdp(self, model: Union[nn.Module, List[nn.Module]]):
+    def precompute_float8_dynamic_scale_for_fsdp(self, model: Union[nn.Module, list[nn.Module]]):
         if not self.enable_fp8 or not self.precompute_scale:
             return
 
@@ -174,7 +174,7 @@ def precompute_float8_dynamic_scale_for_fsdp(self, model: Union[nn.Module, List[
         for m in models:
             precompute_float8_dynamic_scale_for_fsdp(m)
 
-    def sync_float8_amax_and_scale_history(self, model: Union[nn.Module, List[nn.Module]]):
+    def sync_float8_amax_and_scale_history(self, model: Union[nn.Module, list[nn.Module]]):
         if not self.enable_fp8 or not self.delayed_scaling:
             return
 
examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/fsdp2_handler.py

Lines changed: 4 additions & 4 deletions
@@ -75,13 +75,13 @@ def wrap_model(self, model: nn.Module):
         dp_mesh = self.device_mesh["data_parallel"]
         assert dp_mesh.size() > 1, "FSDP requires at least two devices."
 
-        fsdp_policy = dict(
-            mesh=dp_mesh,
-            mp_policy=self.MixedPrecisionPolicy(
+        fsdp_policy = {
+            "mesh": dp_mesh,
+            "mp_policy": self.MixedPrecisionPolicy(
                 param_dtype=torch.bfloat16,
                 reduce_dtype=torch.float32,
             ),
-        )
+        }
         if self.args.enable_cpu_offload:
             fsdp_policy["offload_policy"] = self.CPUOffloadPolicy()
 
examples/pytorch/custom_handler_fp8_fsdp1n2_compile/tests/test_fp8_training_handler.py

Lines changed: 11 additions & 13 deletions
@@ -1,6 +1,7 @@
 import unittest
 from unittest.mock import patch
 
+import pytest
 import torch.nn as nn
 from handlers.fp8_training_handler import Float8TrainingHandler, FP8Config
 from lightning.pytorch.demos import Transformer
@@ -38,33 +39,30 @@ def setUp(self):
     @patch("handlers.fp8_training_handler.is_sm89_or_later", return_value=True)
     def test_handler_initialization(self, mock_sm89):
         handler = Float8TrainingHandler(self.args, self.model_path, self.parallel_dims)
-        self.assertTrue(handler.enable_fp8)
-        self.assertFalse(handler.compile)
-        self.assertIsNotNone(handler.args)
-        self.assertIsNotNone(handler.parallel_dims)
+        assert handler.enable_fp8
+        assert not handler.compile
+        assert handler.args is not None
+        assert handler.parallel_dims is not None
 
     @patch("handlers.fp8_training_handler.is_sm89_or_later", return_value=True)
     def test_compile_flag(self, mock_sm89):
         self.args.enable_torch_compile = True
         handler = Float8TrainingHandler(self.args, self.model_path, self.parallel_dims)
-        self.assertTrue(handler.compile)
+        assert handler.compile
 
     @patch("handlers.fp8_training_handler.is_sm89_or_later", return_value=False)
     def test_handler_disabled_on_unsupported_hardware(self, mock_sm89):
         # Assert that the RuntimeError is raised
-        with self.assertRaises(RuntimeError) as context:
+        with pytest.raises(RuntimeError) as context:
             Float8TrainingHandler(self.args, self.model_path, self.parallel_dims)
 
         # Check that the error message matches the expected text
-        self.assertIn(
-            "Float8Linear operation is not supported on the current hardware.",
-            str(context.exception),
-        )
+        assert "Float8Linear operation is not supported on the current hardware." in str(context.exception)
 
     def test_handler_disabled_when_fp8_not_enabled(self):
         self.args.enable_fp8 = False
         handler = Float8TrainingHandler(self.args, self.model_path, self.parallel_dims)
-        self.assertFalse(handler.enable_fp8)
+        assert not handler.enable_fp8
 
     @patch("handlers.fp8_training_handler.is_sm89_or_later", return_value=True)
     def test_convert_to_float8_training(self, mock_sm89):
@@ -75,9 +73,9 @@ def test_convert_to_float8_training(self, mock_sm89):
         print(self.model)
         for module_name, module in self.model.named_modules():
             if any(proj in module_name for proj in ["w1", "w2", "w3"]):  # Float8Linear
-                self.assertIsInstance(module, Float8Linear, f"{module_name} should be Float8Linear")
+                assert isinstance(module, Float8Linear), f"{module_name} should be Float8Linear"
             elif isinstance(module, nn.Linear):
-                self.assertNotIsInstance(module, Float8Linear, f"{module_name} should not be Float8Linear")
+                assert not isinstance(module, Float8Linear), f"{module_name} should not be Float8Linear"
 
     @patch("handlers.fp8_training_handler.is_sm89_or_later", return_value=True)
     def test_precompute_float8_dynamic_scale_for_fsdp(self, mock_sm89):

examples/pytorch/custom_handler_fp8_fsdp1n2_compile/tests/test_fsdp2_handler.py

Lines changed: 12 additions & 11 deletions
@@ -1,6 +1,7 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
+import pytest
 import torch.nn as nn
 from handlers.fsdp2_handler import FSDP2Config, FSDP2Handler
 
@@ -37,15 +38,15 @@ def setUp(self):
 
         class ModelWrapper(nn.Module):
             def __init__(self, model):
-                super(ModelWrapper, self).__init__()
+                super().__init__()
                 self.model = model  # The wrapped Transformer model
 
             def forward(self, *args, **kwargs):
                 return self.model(*args, **kwargs)
 
         class InnerModel(nn.Module):
            def __init__(self, num_layers, input_size, hidden_size):
-                super(InnerModel, self).__init__()
+                super().__init__()
                # Initialize a ModuleList to store the layers
                self.layers = nn.ModuleList()
                for _ in range(num_layers):
@@ -77,23 +78,23 @@ def test_wrap_model(self, mock_checkpoint_wrapper_func, mock_fully_shard_func):
         wrapped_model = handler.wrap_model(self.model)
 
         # Ensure fully_shard and checkpoint_wrapper are called
-        self.assertTrue(mock_fully_shard_func.called, "fully_shard was not called")
-        self.assertTrue(mock_checkpoint_wrapper_func.called, "checkpoint_wrapper was not called")
+        assert mock_fully_shard_func.called, "fully_shard was not called"
+        assert mock_checkpoint_wrapper_func.called, "checkpoint_wrapper was not called"
 
         # Verify that the model's layers have been wrapped
-        self.assertIsNotNone(wrapped_model, "wrapped_model is None")
+        assert wrapped_model is not None, "wrapped_model is None"
         mock_fully_shard_func.assert_called()
 
         # Ensure that checkpoint_wrapper is called for each layer
-        self.assertEqual(mock_checkpoint_wrapper_func.call_count, len(self.model.model.layers))
+        assert mock_checkpoint_wrapper_func.call_count == len(self.model.model.layers)
         # Ensure that fully_shard is called for each layer + full module
-        self.assertEqual(mock_fully_shard_func.call_count, len(self.model.model.layers) + 1)
+        assert mock_fully_shard_func.call_count == len(self.model.model.layers) + 1
 
     def test_wrap_model_with_single_device(self):
         # Simulate single device
         self.device_mesh["data_parallel"].size.return_value = 1
         handler = FSDP2Handler(self.args, self.device_mesh)
-        with self.assertRaises(AssertionError):
+        with pytest.raises(AssertionError):
             handler.wrap_model(self.model)
 
     @patch("torch.distributed._composable.fsdp.fully_shard", side_effect=mock_fully_shard)
@@ -103,8 +104,8 @@ def test_enable_cpu_offload(self, mock_fully_shard_func):
         handler.wrap_model(self.model)
         # Check if CPUOffloadPolicy is used
         args, kwargs = mock_fully_shard_func.call_args
-        self.assertIn("offload_policy", kwargs)
-        self.assertIsNotNone(kwargs["offload_policy"])
+        assert "offload_policy" in kwargs
+        assert kwargs["offload_policy"] is not None
 
     @patch("torch.distributed._composable.fsdp.fully_shard", side_effect=mock_fully_shard)
     @patch(
@@ -116,4 +117,4 @@ def test_diable_gradient_checkpointing(self, mock_checkpoint_wrapper_func, mock_
         handler = FSDP2Handler(self.args, self.device_mesh)
         handler.wrap_model(self.model)
         # Check if gradient checkpointing is disabled
-        self.assertFalse(mock_checkpoint_wrapper_func.called, "Error: checkpoint_wrapper was unexpectedly called.")
+        assert not mock_checkpoint_wrapper_func.called, "Error: checkpoint_wrapper was unexpectedly called."

examples/pytorch/custom_handler_fp8_fsdp1n2_compile/tests/test_torch_compile_handler.py

Lines changed: 3 additions & 11 deletions
@@ -43,11 +43,7 @@ def test_compile_transformer_encoder_layers(self, mock_compile):
         handler.compile_model(self.model)
 
         # Ensure torch.compile was called with the correct layer
-        self.assertEqual(
-            mock_compile.call_count,
-            self.num_layers,
-            f"Expected mock_compile to be called {self.num_layers} times",
-        )
+        assert mock_compile.call_count == self.num_layers, f"Expected mock_compile to be called {self.num_layers} times"
 
     def test_compile_disabled(self):
         handler = TorchCompileHandler(False, self.model_path)
@@ -74,9 +70,5 @@ def forward(self, x):
         handler.compile_model(model)
 
         # LlamaMLP inside NestedModel should be compiled
-        self.assertTrue(mock_compile.called)
-        self.assertEqual(
-            mock_compile.call_count,
-            self.num_layers,
-            f"Expected mock_compile to be called {self.num_layers} times",
-        )
+        assert mock_compile.called
+        assert mock_compile.call_count == self.num_layers, f"Expected mock_compile to be called {self.num_layers} times"

examples/pytorch/custom_handler_fp8_fsdp1n2_compile/train.py

Lines changed: 33 additions & 32 deletions
@@ -1,20 +1,21 @@
 import argparse
-from dataclasses import dataclass
 import logging
+from dataclasses import dataclass
 
-import torch.distributed as dist
 import lightning as L
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from lightning.pytorch.demos import Transformer, WikiText2
+from lightning.pytorch.demos import WikiText2
 from lightning.pytorch.strategies import FSDPStrategy, ModelParallelStrategy
 from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
 from torch.utils.data import DataLoader
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 log = logging.getLogger(__name__)
 
+
 @dataclass
 class Args:
     vocab_size: int = 32000
@@ -24,56 +25,56 @@ class Args:
     enable_gradient_checkpointing: bool = False
     enable_fsdp2: bool = False
 
+
 class SimpleLayer(nn.Module):
     def __init__(self, hidden_size):
-        super(SimpleLayer, self).__init__()
+        super().__init__()
         self.linear = nn.Linear(hidden_size, hidden_size)
         self.activation = nn.ReLU()
 
     def forward(self, x):
         print(f"Input shape before Linear: {x.shape}")
         x = self.linear(x)
         print(f"Output shape after Linear: {x.shape}")
-        x = self.activation(x)
-        return x
+        return self.activation(x)
+
 
 class InnerModel(nn.Module):
     def __init__(self, num_layers, hidden_size, vocab_size=32000):
-        super(InnerModel, self).__init__()
+        super().__init__()
         # Embedding layer
         self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size)
         # Initialize a ModuleList to store the intermediate layers
         self.layers = nn.ModuleList([SimpleLayer(hidden_size) for _ in range(num_layers)])
         self.lm_head = nn.Linear(hidden_size, vocab_size)
 
-
     def forward(self, x):
         x = self.embedding(x)
         # Pass the input through each layer sequentially
         for layer in self.layers:
             x = layer(x)
-        x = self.lm_head(x)
-        return x
+        return self.lm_head(x)
 
 
 class ModelWrapper(nn.Module):
     def __init__(self, model):
-        super(ModelWrapper, self).__init__()
+        super().__init__()
         self.model = model  # The wrapped Transformer model
 
     def forward(self, *args, **kwargs):
         return self.model(*args, **kwargs)
 
 
 class LanguageModel(L.LightningModule):
-    def __init__(self,
-                 vocab_size=32000,
-                 enable_fp8 = False,
-                 enable_fsdp2 = False,
-                 enable_torch_compile = False,
-                 enable_gradient_checkpointing = False,
-                 enable_cpu_offload = False
-                 ):
+    def __init__(
+        self,
+        vocab_size=32000,
+        enable_fp8=False,
+        enable_fsdp2=False,
+        enable_torch_compile=False,
+        enable_gradient_checkpointing=False,
+        enable_cpu_offload=False,
+    ):
         super().__init__()
         self.model = None
         self.vocab_size = vocab_size
@@ -83,15 +84,14 @@ def __init__(self,
         self.enable_gradient_checkpointing = enable_gradient_checkpointing
         self.enable_cpu_offload = enable_cpu_offload
         self.model_path = "dummy"  # placeholder
-        self.parallel_dims = {
-            "dp_shard_enabled": True if torch.cuda.device_count() > 1 else False
-        }  # only used for FP8 training
+        self.parallel_dims = {"dp_shard_enabled": torch.cuda.device_count() > 1}  # only used for FP8 training
 
     def log_model_stage(self, stage: str):
-        """
-        Logs the current state of the model with a description of the stage.
+        """Logs the current state of the model with a description of the stage.
+
         Args:
             stage (str): Description of the current model stage.
+
         """
         log.warning(f"Model at stage: {stage}\n{self.model}")
 
@@ -129,7 +129,7 @@ def configure_fsdp2(self):
 
     def configure_fp8(self):
         # Setup fp8 training, if enable_fp8 is false, it will create a fake handler
-        from handlers.fp8_training_handler import FP8Config, Float8TrainingHandler
+        from handlers.fp8_training_handler import Float8TrainingHandler, FP8Config
 
         fp8_config = FP8Config(
             enable_fp8=self.enable_fp8,
@@ -207,13 +207,14 @@ def train(args):
     dataset = WikiText2()
     train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)
 
-    model = LanguageModel(vocab_size=args.vocab_size,
-                          enable_fp8 = args.enable_fp8,
-                          enable_fsdp2 = args.enable_fsdp2,
-                          enable_torch_compile = args.enable_torch_compile,
-                          enable_gradient_checkpointing = args.enable_gradient_checkpointing,
-                          enable_cpu_offload = args.enable_cpu_offload,
-                          )
+    model = LanguageModel(
+        vocab_size=args.vocab_size,
+        enable_fp8=args.enable_fp8,
+        enable_fsdp2=args.enable_fsdp2,
+        enable_torch_compile=args.enable_torch_compile,
+        enable_gradient_checkpointing=args.enable_gradient_checkpointing,
+        enable_cpu_offload=args.enable_cpu_offload,
+    )
 
     if args.enable_fsdp2:
         strategy = ModelParallelStrategy(
