
Commit 1f7eaba

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 781b1c8 commit 1f7eaba

7 files changed: +37 additions, −42 deletions

src/lightning/pytorch/utilities/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -25,6 +25,8 @@
 )
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.enums import GradClipAlgorithmType
+from lightning.pytorch.utilities.fp8_training_handler import Float8TrainingHandler, FP8Config
+from lightning.pytorch.utilities.fsdp2_handler import FSDP2Config, FSDP2Handler
 from lightning.pytorch.utilities.grads import grad_norm
 from lightning.pytorch.utilities.parameter_tying import find_shared_parameters, set_shared_parameters
 from lightning.pytorch.utilities.parsing import AttributeDict, is_picklable
@@ -34,8 +36,6 @@
     rank_zero_only,
     rank_zero_warn,
 )
-from lightning.pytorch.utilities.fp8_training_handler import FP8Config, Float8TrainingHandler
-from lightning.pytorch.utilities.fsdp2_handler import FSDP2Config, FSDP2Handler
 from lightning.pytorch.utilities.torch_compile_handler import TorchCompileHandler

 __all__ = [
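
Note: the two hunks above only move the new handler imports into sorted position; nothing is added or removed functionally. A minimal, hypothetical check (not part of this commit) that the classes remain importable from the package:

    from lightning.pytorch.utilities import FP8Config, FSDP2Config, FSDP2Handler, Float8TrainingHandler

    # The module-level imports re-export the handlers, so this works
    # regardless of whether the names are also listed in __all__.
    print(Float8TrainingHandler.__module__)  # lightning.pytorch.utilities.fp8_training_handler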

src/lightning/pytorch/utilities/fp8_training_handler.py

Lines changed: 8 additions & 10 deletions
@@ -1,8 +1,8 @@
 # the script is modified based on https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py
 import logging
-from typing import Dict, List, Union
-from dataclasses import dataclass
 import operator
+from dataclasses import dataclass
+from typing import Dict, List, Union

 import torch
 import torch.nn as nn
@@ -44,13 +44,10 @@ class FP8Config:


 class Float8TrainingHandler:
-    """
-    Handler for configuring models for FP8 training using torchao.
-    """
+    """Handler for configuring models for FP8 training using torchao."""

     def __init__(self, args: FP8Config, model_path: str, parallel_dims: Dict[str, bool]):
-        """
-        Initializes the handler for FP8 training and configuration.
+        """Initializes the handler for FP8 training and configuration.

         Args:
             args (FP8Config): Configuration object for FP8 training, including settings for scaling, amax initialization, and torch compile.
@@ -74,6 +71,7 @@ def __init__(self, args: FP8Config, model_path: str, parallel_dims: Dict[str, bo

             parallel_dims = {"dp_shard_enabled": False}
             handler = Float8TrainingHandler(fp8_config, "path/to/model", parallel_dims)
+
         """
         self.model_path = model_path
         self.args = args
@@ -132,14 +130,14 @@ def __init__(self, args: FP8Config, model_path: str, parallel_dims: Dict[str, bo
         log.info("Float8 training active")

     def convert_to_float8_training(self, model: nn.Module, module_filter_fn: callable = None):
-        """
-        Converts the linear layers of `model` to `Float8Linear` based on a module filter function.
-        Mutates the model in place.
+        """Converts the linear layers of `model` to `Float8Linear` based on a module filter function. Mutates the model
+        in place.

         Args:
             model (nn.Module): The model whose layers should be converted.
             module_filter_fn (callable, optional): A function to filter which modules should be replaced.
                 Defaults to a model-specific filter based on `model_path`.
+
         """
         if not self.enable_fp8:
             log.warning("FP8 is disabled, so layers will not be replaced.")

src/lightning/pytorch/utilities/fsdp2_handler.py

Lines changed: 6 additions & 6 deletions
@@ -1,10 +1,10 @@
 import logging
+import operator
+from dataclasses import dataclass
 from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn
-import operator
-from dataclasses import dataclass
 from lightning_utilities.core.imports import compare_version

 if TYPE_CHECKING:
@@ -20,8 +20,7 @@ class FSDP2Config:


 class FSDP2Handler:
-    """
-    Handler for wrapping the model layers with FSDP2.
+    """Handler for wrapping the model layers with FSDP2.

     Args:
         args (FSDP2Config): Configuration for FSDP2, including options for CPU offload and gradient checkpointing.
@@ -30,6 +29,7 @@ class FSDP2Handler:
     Attributes:
         args (FSDP2Config): Stores the FSDP2 configuration.
         device_mesh (DeviceMesh): Stores the device mesh configuration.
+
     """

     def __init__(self, args: FSDP2Config, device_mesh: "DeviceMesh"):
@@ -63,14 +63,14 @@ def __init__(self, args: FSDP2Config, device_mesh: "DeviceMesh"):
             raise

     def wrap_model(self, model: nn.Module):
-        """
-        Wraps the model layers with FSDP configurations.
+        """Wraps the model layers with FSDP configurations.

         Args:
             model (nn.Module): The model to wrap.

         Returns:
             nn.Module: The wrapped model.
+
         """
         dp_mesh = self.device_mesh["data_parallel"]
         assert dp_mesh.size() > 1, "FSDP requires at least two devices."
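
Based on the docstring and the `wrap_model` assertion above, the handler expects a device mesh with a "data_parallel" dimension of size greater than one. A hypothetical sketch, assuming a multi-rank process group is already initialized and that FSDP2Config's remaining fields have defaults:

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    from lightning.pytorch.demos import Transformer
    from lightning.pytorch.utilities.fsdp2_handler import FSDP2Config, FSDP2Handler

    # Hypothetical sketch; requires torch.distributed to be initialized with >= 2 ranks.
    mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("data_parallel",))

    config = FSDP2Config(enable_gradient_checkpointing=True)  # field name taken from the test later in this diff
    handler = FSDP2Handler(config, mesh)
    wrapped = handler.wrap_model(Transformer())  # returns the FSDP2-wrapped model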

src/lightning/pytorch/utilities/torch_compile_handler.py

Lines changed: 10 additions & 10 deletions
@@ -1,22 +1,22 @@
 import logging
+import operator
+
 import torch
 import torch.nn as nn
-import operator
 from lightning_utilities.core.imports import compare_version

-
 log = logging.getLogger(__name__)


 class TorchCompileHandler:
-    """
-    Handler for compiling specific layers of the model using torch.compile.
+    """Handler for compiling specific layers of the model using torch.compile.

     Args:
         enable_compile (bool): Whether to enable compilation.
         model_path (str): Path to the model, used to determine default compilable layers.
         compile_layers (List[str], optional): List of layer class names to compile. If None, defaults are used.
         compile_args (dict, optional): Additional arguments to pass to torch.compile.
+
     """

     # Default mapping of model names to compilable layer class names
@@ -54,23 +54,23 @@ def __init__(
         )

     def _get_default_compile_layers(self):
-        """
-        Determines the default layers to compile based on the model name.
+        """Determines the default layers to compile based on the model name.

         Returns:
             List[str]: List of layer class names to compile.
+
         """
         for model_name, layers in self.DEFAULT_COMPILABLE_LAYERS.items():
             if model_name in self.model_path:
                 return layers
         return []

     def compile_model(self, model: nn.Module):
-        """
-        Compiles specified layers in the model.
+        """Compiles specified layers in the model.

         Args:
             model (nn.Module): The model to compile.
+
         """
         if not self.enable_compile:
             return
@@ -84,11 +84,11 @@ def compile_model(self, model: nn.Module):
         self._compile_layers(model)

     def _compile_layers(self, module: nn.Module):
-        """
-        Recursively compiles specified layers in the module.
+        """Recursively compiles specified layers in the module.

         Args:
             module (nn.Module): The module to process.
+
         """
         for name, child in module.named_children():
             child_class_name = type(child).__name__
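
The constructor arguments listed in the docstring above suggest the following usage. This is a hypothetical sketch in which the model path, layer list, and torch.compile arguments are illustrative choices, not values taken from this commit:

    import torch.nn as nn

    from lightning.pytorch.utilities.torch_compile_handler import TorchCompileHandler

    # Hypothetical sketch: compile only the Linear submodules of a toy model.
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16))

    handler = TorchCompileHandler(
        enable_compile=True,
        model_path="my_custom_model",              # assumed name; not in the default layer mapping
        compile_layers=["Linear"],                 # compile modules whose class name matches
        compile_args={"mode": "reduce-overhead"},  # forwarded to torch.compile
    )
    handler.compile_model(model)  # appears to replace matching submodules in place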

tests/tests_pytorch/utilities/test_fp8_training_handler.py

Lines changed: 2 additions & 4 deletions
@@ -2,14 +2,12 @@
 from unittest.mock import patch

 import torch.nn as nn
-from torchao.float8 import Float8Linear
 from lightning.pytorch.demos import Transformer
-
-from lightning.pytorch.utilities.fp8_training_handler import FP8Config, Float8TrainingHandler
+from lightning.pytorch.utilities.fp8_training_handler import Float8TrainingHandler, FP8Config
+from torchao.float8 import Float8Linear


 class TestFloat8TrainingHandler(unittest.TestCase):
-
     def setUp(self):
         self.args = FP8Config(
             enable_fp8=True,

tests/tests_pytorch/utilities/test_fsdp2_handler.py

Lines changed: 6 additions & 6 deletions
@@ -2,29 +2,29 @@
 from unittest.mock import MagicMock, patch

 import torch.nn as nn
-from lightning.pytorch.demos import Transformer
 from lightning.pytorch.utilities.fsdp2_handler import FSDP2Config, FSDP2Handler


 # Define mock functions
 def mock_fully_shard(module, **kwargs):
-    """
-    Mock for torch.distributed._composable.fsdp.fully_shard.
+    """Mock for torch.distributed._composable.fsdp.fully_shard.
+
     Returns the module unchanged to simulate sharding without actual processing.
+
     """
     return module


 def mock_checkpoint_wrapper(module):
-    """
-    Mock for torch.distributed.algorithms._checkpoint.checkpoint_wrapper.
+    """Mock for torch.distributed.algorithms._checkpoint.checkpoint_wrapper.
+
     Returns the module unchanged to simulate checkpoint wrapping without actual processing.
+
     """
     return module


 class TestFSDP2Handler(unittest.TestCase):
-
     def setUp(self):
         self.args = FSDP2Config(
             enable_gradient_checkpointing=True,

tests/tests_pytorch/utilities/test_torch_compile_handler.py

Lines changed: 3 additions & 4 deletions
@@ -5,20 +5,19 @@

 import torch.nn as nn
 from lightning.pytorch.demos import Transformer
-
 from lightning.pytorch.utilities.torch_compile_handler import TorchCompileHandler


 def mock_torch_compile(module, **kwargs):
-    """
-    Mock function for torch.compile that returns the module unchanged.
+    """Mock function for torch.compile that returns the module unchanged.
+
     This avoids actual compilation during testing.
+
     """
     return module


 class TestTorchCompileHandler(unittest.TestCase):
-
     def setUp(self):
         self.enable_compile = True
         self.model_path = "test_custom_transformer_model"
