Skip to content

Commit a230cb5

Browse files
feginmori360
authored and committed
[Module][2/2] Convert remaining nn.Module classes to Module protocol -- (pytorch#2608)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * __->__ pytorch#2608 **Motivation** The previous PR, pytorch#2565, doesn't fully convert all the nn.Module classes to Module. This PR converts the leftovers. **Design Summary** 1. Introduce create_class_module so that users can easily create a Module class for an nn.Module class. For example, users can do ` Conv2d = create_class_module(nn.Conv2d)` to get a new Conv2d Module class that inherits from both nn.Conv2d and Module. 2. Introduce a verify_module_protocol method on BaseModel to check that all sub-nn.Modules are Module instances. 3. Create container Module classes for nn.ModuleDict, nn.ModuleList, and nn.Sequential.
1 parent 43f1283 commit a230cb5

File tree

19 files changed

+434
-127
lines changed

19 files changed

+434
-127
lines changed

tests/unit_tests/test_activation_checkpoint.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,33 @@
77
import unittest
88

99
import torch
10-
import torch.nn as nn
1110

1211
from torch.utils.flop_counter import FlopCounterMode
1312
from torchtitan.config import ActivationCheckpointConfig as ACConfig
1413
from torchtitan.distributed.activation_checkpoint import apply_ac
14+
from torchtitan.models.common.linear import Linear
15+
from torchtitan.protocols.module import Module, ModuleDict
1516

1617

17-
class ToyModule(nn.Module):
18+
class ToyModule(Module):
1819
def __init__(self):
1920
super().__init__()
20-
self.layers = nn.ModuleDict({"0": TransformerBlock()})
21+
self.layers = ModuleDict({"0": TransformerBlock()})
2122

2223
def forward(self, x):
2324
return self.layers["0"](x)
2425

2526

26-
class TransformerBlock(nn.Module):
27+
class TransformerBlock(Module):
2728
def __init__(self):
2829
super().__init__()
29-
self.moe = nn.Module()
30-
self.moe.router = nn.Module()
31-
self.moe.router.gate = nn.Linear(512, 512, bias=False)
32-
self.attention = nn.Module()
33-
self.attention.wq = nn.Linear(512, 512, bias=False)
34-
self.output = nn.Linear(512, 1024, bias=False)
30+
linear_config = Linear.Config(bias=False)
31+
self.moe = Module()
32+
self.moe.router = Module()
33+
self.moe.router.gate = linear_config.build(in_features=512, out_features=512)
34+
self.attention = Module()
35+
self.attention.wq = linear_config.build(in_features=512, out_features=512)
36+
self.output = linear_config.build(in_features=512, out_features=1024)
3537

3638
def forward(self, x):
3739
gate_out = self.moe.router.gate(x)

tests/unit_tests/test_compile_moe.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,19 @@
77
import unittest
88

99
import torch
10-
import torch.nn as nn
1110

1211
from torchtitan.config import CompileConfig
12+
from torchtitan.models.common.linear import Linear
1313
from torchtitan.models.llama4.parallelize import apply_compile
14+
from torchtitan.protocols.module import Module, ModuleDict
1415

1516

16-
class TransformerBlock(nn.Module):
17+
class TransformerBlock(Module):
1718
def __init__(self, dim=512):
1819
super().__init__()
19-
self.attention = nn.Linear(dim, dim, bias=False)
20-
self.mlp = nn.Linear(dim, dim, bias=False)
20+
linear_config = Linear.Config(bias=False)
21+
self.attention = linear_config.build(in_features=dim, out_features=dim)
22+
self.mlp = linear_config.build(in_features=dim, out_features=dim)
2123
self.moe_enabled = False
2224

2325
def forward(self, x):
@@ -26,10 +28,10 @@ def forward(self, x):
2628
return x
2729

2830

29-
class TinyModel(nn.Module):
31+
class TinyModel(Module):
3032
def __init__(self, num_layers=2, dim=512):
3133
super().__init__()
32-
self.layers = nn.ModuleDict(
34+
self.layers = ModuleDict(
3335
{str(i): TransformerBlock(dim) for i in range(num_layers)}
3436
)
3537

tests/unit_tests/test_module.py

Lines changed: 177 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import unittest
8+
from dataclasses import dataclass
89

910
import torch
1011
import torch.nn as nn
1112

12-
from torchtitan.protocols.module import Module
13+
from torchtitan.models.common.linear import Linear
14+
from torchtitan.protocols.module import Module, ModuleDict, ModuleList, Sequential
1315

1416

1517
class TestModuleInitWeights(unittest.TestCase):
@@ -36,7 +38,9 @@ def test_init_weights_implemented(self):
3638
class GoodModule(Module):
3739
def __init__(self):
3840
super().__init__()
39-
self.linear = nn.Linear(4, 4)
41+
self.linear = Linear.Config(bias=True).build(
42+
in_features=4, out_features=4
43+
)
4044

4145
def init_weights(self, **kwargs):
4246
nn.init.zeros_(self.linear.weight)
@@ -110,11 +114,13 @@ def __init__(self, num_embeddings, embedding_dim):
110114
def test_module_hierarchy_is_flat(self):
111115
"""Diamond embedding adds no extra layer to the module tree."""
112116

113-
class Model(nn.Module):
117+
class Model(Module):
114118
def __init__(self):
115119
super().__init__()
116120
self.embed = TestDiamondInheritance.TestEmbedding(100, 32)
117-
self.linear = nn.Linear(32, 16)
121+
self.linear = Linear.Config(bias=True).build(
122+
in_features=32, out_features=16
123+
)
118124

119125
model = Model()
120126
param_names = {name for name, _ in model.named_parameters()}
@@ -138,5 +144,172 @@ def counting_init(self, *args, **kwargs):
138144
nn.Module.__init__ = orig_init
139145

140146

147+
class TestFromNnModule(unittest.TestCase):
148+
"""Tests for Module.from_nn_module utility."""
149+
150+
def test_is_subclass(self):
151+
"""Created class is subclass of both original and Module."""
152+
Conv2d = Module.from_nn_module(nn.Conv2d)
153+
self.assertTrue(issubclass(Conv2d, nn.Conv2d))
154+
self.assertTrue(issubclass(Conv2d, Module))
155+
156+
def test_isinstance(self):
157+
"""Instance satisfies isinstance checks for both original and Module."""
158+
Conv2d = Module.from_nn_module(nn.Conv2d)
159+
m = Conv2d(3, 16, 3)
160+
self.assertIsInstance(m, nn.Conv2d)
161+
self.assertIsInstance(m, Module)
162+
163+
def test_init_weights_calls_reset_parameters(self):
164+
"""For classes with reset_parameters, init_weights delegates to it."""
165+
LayerNorm = Module.from_nn_module(nn.LayerNorm)
166+
m = LayerNorm(32)
167+
# Manually set weight to zeros, then init_weights should reset
168+
nn.init.zeros_(m.weight)
169+
m.init_weights()
170+
# After reset_parameters, weight should be ones for LayerNorm
171+
self.assertTrue(torch.allclose(m.weight, torch.ones(32)))
172+
173+
def test_init_weights_noop_for_parameterless(self):
174+
"""For classes without reset_parameters, init_weights is a no-op."""
175+
GELU = Module.from_nn_module(nn.GELU)
176+
m = GELU()
177+
m.init_weights() # should not raise
178+
179+
def test_cache(self):
180+
"""Repeated calls return the same class object."""
181+
cls1 = Module.from_nn_module(nn.Conv2d)
182+
cls2 = Module.from_nn_module(nn.Conv2d)
183+
self.assertIs(cls1, cls2)
184+
185+
def test_forward_unchanged(self):
186+
"""Forward output is identical to original class."""
187+
LayerNorm = Module.from_nn_module(nn.LayerNorm)
188+
torch.manual_seed(42)
189+
orig = nn.LayerNorm(16)
190+
wrapped = LayerNorm(16)
191+
# Copy weights
192+
wrapped.load_state_dict(orig.state_dict())
193+
x = torch.randn(2, 16)
194+
torch.testing.assert_close(orig(x), wrapped(x))
195+
196+
def test_state_dict_unchanged(self):
197+
"""state_dict keys and values match the original class."""
198+
Conv2d = Module.from_nn_module(nn.Conv2d)
199+
orig = nn.Conv2d(3, 16, 3)
200+
wrapped = Conv2d(3, 16, 3)
201+
wrapped.load_state_dict(orig.state_dict())
202+
for key in orig.state_dict():
203+
self.assertIn(key, wrapped.state_dict())
204+
torch.testing.assert_close(
205+
orig.state_dict()[key], wrapped.state_dict()[key]
206+
)
207+
208+
209+
class TestContainerInitWeights(unittest.TestCase):
210+
"""Tests for ModuleList, ModuleDict, Sequential init_weights."""
211+
212+
def test_module_list_init_weights(self):
213+
"""ModuleList.init_weights calls init_weights on each child."""
214+
LayerNorm = Module.from_nn_module(nn.LayerNorm)
215+
norms = ModuleList([LayerNorm(8) for _ in range(3)])
216+
for n in norms:
217+
nn.init.zeros_(n.weight)
218+
norms.init_weights()
219+
for n in norms:
220+
self.assertTrue(torch.allclose(n.weight, torch.ones(8)))
221+
222+
def test_module_dict_init_weights(self):
223+
"""ModuleDict.init_weights calls init_weights on each child."""
224+
LayerNorm = Module.from_nn_module(nn.LayerNorm)
225+
norms = ModuleDict({"a": LayerNorm(8), "b": LayerNorm(8)})
226+
for n in norms.values():
227+
nn.init.zeros_(n.weight)
228+
norms.init_weights()
229+
for n in norms.values():
230+
self.assertTrue(torch.allclose(n.weight, torch.ones(8)))
231+
232+
def test_sequential_init_weights(self):
233+
"""Sequential.init_weights calls init_weights on each child."""
234+
linear = Linear.Config(bias=False).build(in_features=4, out_features=4)
235+
GELU = Module.from_nn_module(nn.GELU)
236+
seq = Sequential(linear, GELU())
237+
seq.init_weights() # should not raise
238+
239+
def test_containers_are_module(self):
240+
"""Container instances satisfy Module protocol."""
241+
self.assertIsInstance(ModuleList(), Module)
242+
self.assertIsInstance(ModuleDict(), Module)
243+
self.assertIsInstance(Sequential(), Module)
244+
245+
246+
class TestVerifyModuleProtocol(unittest.TestCase):
247+
"""Tests for BaseModel.verify_module_protocol."""
248+
249+
def test_passes_for_all_module(self):
250+
"""No error when all submodules are Module instances."""
251+
from torchtitan.protocols.model import BaseModel
252+
253+
class GoodModel(BaseModel):
254+
@dataclass(kw_only=True, slots=True)
255+
class Config(BaseModel.Config):
256+
def update_from_config(self, *, trainer_config, **kwargs):
257+
pass
258+
259+
def get_nparams_and_flops(self, model, seq_len):
260+
return (0, 0)
261+
262+
def __init__(self):
263+
super().__init__()
264+
self.linear = Linear.Config().build(in_features=4, out_features=4)
265+
266+
model = GoodModel()
267+
model.verify_module_protocol() # should not raise
268+
269+
def test_default_raises_for_plain_nn_module(self):
270+
"""Default verify_module_protocol raises when plain nn.Module child exists."""
271+
from torchtitan.protocols.model import BaseModel
272+
273+
class BadModel(BaseModel):
274+
@dataclass(kw_only=True, slots=True)
275+
class Config(BaseModel.Config):
276+
def update_from_config(self, *, trainer_config, **kwargs):
277+
pass
278+
279+
def get_nparams_and_flops(self, model, seq_len):
280+
return (0, 0)
281+
282+
def __init__(self):
283+
super().__init__()
284+
self.plain = nn.Linear(4, 4)
285+
286+
model = BadModel()
287+
with self.assertRaises(RuntimeError):
288+
model.verify_module_protocol()
289+
290+
def test_override_skips_verification(self):
291+
"""Subclass can override verify_module_protocol to skip verification."""
292+
from torchtitan.protocols.model import BaseModel
293+
294+
class ThirdPartyModel(BaseModel):
295+
@dataclass(kw_only=True, slots=True)
296+
class Config(BaseModel.Config):
297+
def update_from_config(self, *, trainer_config, **kwargs):
298+
pass
299+
300+
def get_nparams_and_flops(self, model, seq_len):
301+
return (0, 0)
302+
303+
def __init__(self):
304+
super().__init__()
305+
self.plain = nn.Linear(4, 4) # third-party module
306+
307+
def verify_module_protocol(self) -> None:
308+
pass # skip for third-party internals
309+
310+
model = ThirdPartyModel()
311+
model.verify_module_protocol() # should not raise
312+
313+
141314
if __name__ == "__main__":
142315
unittest.main()

tests/unit_tests/test_train_spec.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torchtitan.components.loss import build_cross_entropy_loss
1313
from torchtitan.components.optimizer import OptimizersContainer
1414
from torchtitan.distributed.parallel_dims import ParallelDims
15+
from torchtitan.models.common.linear import Linear
1516
from torchtitan.models.llama3 import model_registry, parallelize_llama
1617
from torchtitan.protocols import BaseModel
1718
from torchtitan.protocols.model_spec import ModelSpec
@@ -30,13 +31,15 @@ def get_nparams_and_flops(self, model, seq_len):
3031

3132
def __init__(self, config: Config):
3233
super().__init__()
33-
self.linear = nn.Linear(config.hidden, config.hidden)
34+
self.linear = Linear.Config().build(
35+
in_features=config.hidden, out_features=config.hidden
36+
)
3437

3538
def forward(self, x: torch.Tensor) -> torch.Tensor:
3639
return self.linear(x)
3740

38-
def init_weights(self, buffer_device: torch.device | None = None) -> None:
39-
nn.init.normal_(self.linear.weight, mean=0.0, std=0.02)
41+
def init_weights(self, buffer_device: torch.device | None = None, **kwargs) -> None:
42+
self.linear.init_weights()
4043

4144

4245
def fake_post_optimizer_build_fn(

torchtitan/distributed/pipeline_parallel.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from torchtitan.protocols.model import BaseModel
3636
from torchtitan.protocols.model_converter import ModelConvertersContainer
3737
from torchtitan.protocols.model_spec import ParallelizeFunction
38+
from torchtitan.protocols.module import ModuleDict, ModuleList
3839
from torchtitan.tools.logging import logger
3940

4041
__all__ = [
@@ -437,7 +438,7 @@ def _build_stage_from_modules(
437438
indices_to_keep = {
438439
int(idx) for idx in layers_to_keep if idx.isdigit()
439440
}
440-
new_layers = nn.ModuleList(
441+
new_layers = ModuleList(
441442
[
442443
layer
443444
for i, layer in enumerate(module_value)
@@ -448,9 +449,9 @@ def _build_stage_from_modules(
448449
else:
449450
# No layers from this structure needed, set to empty structure
450451
if isinstance(module_value, nn.ModuleDict):
451-
setattr(model, module_name, nn.ModuleDict())
452+
setattr(model, module_name, ModuleDict())
452453
elif isinstance(module_value, nn.ModuleList):
453-
setattr(model, module_name, nn.ModuleList())
454+
setattr(model, module_name, ModuleList())
454455
# Handle simple module attributes (e.g., "linear", "norm")
455456
elif module_name not in modules_to_keep:
456457
# Replace with None

torchtitan/experiments/ft/diloco/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from torchtitan.distributed.pipeline_parallel import generate_llm_fqn_per_model_part
1010
from torchtitan.experiments.ft.config import FaultTolerance as FTConfig
11+
from torchtitan.protocols.module import ModuleList
1112
from torchtitan.tools.logging import logger
1213

1314

@@ -72,7 +73,7 @@ def _build_fragment_from_modules(
7273
indices_to_keep = {
7374
int(idx) for idx in layers_to_keep if idx.isdigit()
7475
}
75-
new_layers = nn.ModuleList(
76+
new_layers = ModuleList(
7677
[
7778
layer
7879
for i, layer in enumerate(module_value)

torchtitan/experiments/ft/trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ def __init__(self, config: Config):
129129
)
130130
model_converters.convert(model)
131131

132+
# Verify all submodules satisfy the Module protocol
133+
model.verify_module_protocol()
134+
132135
# metrics logging (FT addition: ft_enable, ft_replica_id)
133136
self.metrics_processor = config.metrics.build(
134137
parallel_dims=parallel_dims,

0 commit comments

Comments
 (0)