
Commit 3775340

Split Precision.init_context (#18734)
1 parent 0e04760 commit 3775340

28 files changed: +88 -87 lines changed
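
The diffs shown below all follow the same pattern: the single `Precision.init_context()` hook is split into `tensor_init_context()` (device/dtype control for plain tensor creation) and `module_init_context()` (the same, plus any module-level setup). A minimal sketch of the before/after interface shape, with illustrative no-op bodies rather than the library code:

from contextlib import nullcontext
from typing import ContextManager


# Before this commit: one hook covered tensor and module creation alike.
class PrecisionBefore:
    def init_context(self) -> ContextManager:
        return nullcontext()


# After this commit: two hooks, so callers that only create tensors can
# skip any module-specific patching a plugin might do.
class PrecisionAfter:
    def tensor_init_context(self) -> ContextManager:
        # controls device/dtype for plain tensor creation
        return nullcontext()

    def module_init_context(self) -> ContextManager:
        # defaults to the tensor behaviour; plugins can layer module-level
        # changes (e.g. swapping layer classes) on top
        return self.tensor_init_context()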

src/lightning/fabric/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for meta-device initialization with `Fabric.init_module(empty_init=True)` in FSDP ([#18122](https://github.com/Lightning-AI/lightning/pull/18122))


-- Added `lightning.fabric.plugins.Precision.init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
+- Added `lightning.fabric.plugins.Precision.module_init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))


 - `lightning.fabric.strategies.Strategy.tensor_init_context()` context manager to instantiate tensors efficiently directly on device and dtype ([#17607](https://github.com/Lightning-AI/lightning/pull/17607))
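
For orientation, these plugin hooks sit behind the user-facing `Fabric.init_module()` context manager referenced in the changelog entry above. A rough usage sketch, not part of this commit; the toy model and the CPU/`bf16-true` settings are illustrative:

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", precision="bf16-true")
fabric.launch()

# init_module() enters the strategy's module_init_context(), which after
# this commit calls Precision.module_init_context() instead of init_context().
with fabric.init_module():
    model = torch.nn.Linear(4, 4)

print(model.weight.dtype)  # expected torch.bfloat16 under "bf16-true"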

src/lightning/fabric/plugins/precision/bitsandbytes.py

Lines changed: 5 additions & 2 deletions
@@ -116,7 +116,10 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
                 m.compute_type_is_set = False
         return module

-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
+        return _DtypeContextManager(self.dtype)
+
+    def module_init_context(self) -> ContextManager:
         if self.ignore_modules:
             # cannot patch the Linear class if the user wants to skip some submodules
             raise RuntimeError(
@@ -125,7 +128,7 @@ def init_context(self) -> ContextManager:
                 " may initialize the layers on-device, defeating the purpose of quantization. You can remove"
                 " `ignore_modules` or remove the `init_module` context manager."
             )
-        dtype_ctx = _DtypeContextManager(self.dtype)
+        dtype_ctx = self.tensor_init_context()
         # TODO: this could also support replacing `Embedding` and `Conv1D`
         context_manager = _ClassReplacementContextManager({"torch.nn.Linear": self._linear_cls})
         stack = ExitStack()
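
The bitsandbytes plugin is the case that motivates the split: `tensor_init_context()` only pins the dtype, while `module_init_context()` additionally swaps `torch.nn.Linear` for the quantized class. A generic sketch of that composition pattern; the helpers below are stand-ins for Lightning's private `_DtypeContextManager` and `_ClassReplacementContextManager`, not the real implementations:

import contextlib
import torch


@contextlib.contextmanager
def dtype_context(dtype: torch.dtype):
    # stand-in: temporarily change the default dtype used for new tensors
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


@contextlib.contextmanager
def replace_linear(replacement: type):
    # stand-in: temporarily patch torch.nn.Linear with a replacement class
    original = torch.nn.Linear
    torch.nn.Linear = replacement
    try:
        yield
    finally:
        torch.nn.Linear = original


def module_init_context(dtype: torch.dtype, linear_cls: type) -> contextlib.ExitStack:
    # mirror the ExitStack composition in the diff above: dtype first,
    # then the class replacement, returned as a single context manager
    stack = contextlib.ExitStack()
    stack.enter_context(dtype_context(dtype))
    stack.enter_context(replace_linear(linear_cls))
    return stack

Entering the returned stack keeps both patches active until the with-block exits, which is exactly why the module hook cannot simply reuse the tensor hook here.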

src/lightning/fabric/plugins/precision/deepspeed.py

Lines changed: 4 additions & 1 deletion
@@ -66,11 +66,14 @@ def convert_module(self, module: Module) -> Module:
             return module.to(dtype=self._desired_dtype)
         return module

-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
         if "true" not in self.precision:
             return nullcontext()
         return _DtypeContextManager(self._desired_dtype)

+    def module_init_context(self) -> ContextManager:
+        return self.tensor_init_context()
+
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)

src/lightning/fabric/plugins/precision/double.py

Lines changed: 5 additions & 2 deletions
@@ -30,11 +30,14 @@ class DoublePrecision(Precision):
     def convert_module(self, module: Module) -> Module:
         return module.double()

-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(torch.double)

+    def module_init_context(self) -> ContextManager:
+        return self.tensor_init_context()
+
     def forward_context(self) -> ContextManager:
-        return _DtypeContextManager(torch.double)
+        return self.tensor_init_context()

     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.double)

src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 5 additions & 2 deletions
@@ -103,13 +103,16 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
             buffer_dtype=buffer_dtype,
         )

-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
+        return _DtypeContextManager(self._desired_input_dtype)
+
+    def module_init_context(self) -> ContextManager:
         return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)

     def forward_context(self) -> ContextManager:
         if "mixed" in self.precision:
             return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
-        return _DtypeContextManager(self._desired_input_dtype)
+        return self.tensor_init_context()

     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
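
This is the one plugin where the two hooks intentionally return different contexts: plain tensors follow `_desired_input_dtype`, while module parameters follow `mixed_precision_config.param_dtype or torch.float32`, which under a mixed setting can be a wider dtype. A small stand-alone illustration of the effect of the default dtype on each kind of creation; the specific dtypes are chosen for illustration only:

import torch

# tensor_init_context(): plain tensors pick up the compute/input dtype
torch.set_default_dtype(torch.bfloat16)
x = torch.zeros(8)                  # bfloat16

# module_init_context(): parameters pick up the (possibly wider) param dtype
torch.set_default_dtype(torch.float32)
layer = torch.nn.Linear(8, 8)       # float32 parameters

print(x.dtype, layer.weight.dtype)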

src/lightning/fabric/plugins/precision/half.py

Lines changed: 5 additions & 2 deletions
@@ -39,11 +39,14 @@ def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> None:
     def convert_module(self, module: Module) -> Module:
         return module.to(dtype=self._desired_input_dtype)

-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self._desired_input_dtype)

+    def module_init_context(self) -> ContextManager:
+        return self.tensor_init_context()
+
     def forward_context(self) -> ContextManager:
-        return _DtypeContextManager(self._desired_input_dtype)
+        return self.tensor_init_context()

     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

src/lightning/fabric/plugins/precision/precision.py

Lines changed: 5 additions & 1 deletion
@@ -53,7 +53,11 @@ def convert_module(self, module: Module) -> Module:
         """
         return module

-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
+        """Controls how tensors get created (device, dtype)."""
+        return nullcontext()
+
+    def module_init_context(self) -> ContextManager:
         """Instantiate module parameters or tensors in the precision type this plugin handles.

         This is optional and depends on the precision limitations during optimization.
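
With the base class defining both hooks as shown above, a downstream plugin only overrides what it actually changes. A hypothetical example (not part of this commit) sketched against the public `lightning.fabric.plugins.Precision` base class named in the changelog:

from typing import ContextManager

import torch
from lightning.fabric.plugins import Precision


class MetaInitPrecision(Precision):
    """Hypothetical plugin: create module parameters on the meta device, to be
    materialized later. Requires torch >= 2.0, where torch.device works as a
    context manager."""

    def module_init_context(self) -> ContextManager:
        return torch.device("meta")

    # tensor_init_context() is inherited and stays a no-op (nullcontext),
    # so plain tensors are still created normally.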

src/lightning/fabric/plugins/precision/transformer_engine.py

Lines changed: 5 additions & 2 deletions
@@ -93,8 +93,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
             module = module.to(dtype=self.dtype)
         return module

-    def init_context(self) -> ContextManager:
-        dtype_ctx = _DtypeContextManager(self.dtype)
+    def tensor_init_context(self) -> ContextManager:
+        return _DtypeContextManager(self.dtype)
+
+    def module_init_context(self) -> ContextManager:
+        dtype_ctx = self.tensor_init_context()
         stack = ExitStack()
         if self.replace_layers:
             import transformer_engine.pytorch as te

src/lightning/fabric/strategies/fsdp.py

Lines changed: 3 additions & 2 deletions
@@ -329,16 +329,17 @@ def module_to_device(self, module: Module) -> None:
         pass

     def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
-        precision_init_ctx = self.precision.init_context()
+        precision_init_ctx = self.precision.module_init_context()
         module_sharded_ctx = self.module_sharded_context()
+        empty_ctx = _EmptyInit(enabled=bool(empty_init))
         stack = ExitStack()
         if _TORCH_GREATER_EQUAL_2_1 and empty_init:
             # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
             # 1) materialize module 2) call `reset_parameters()` 3) shard the module.
             # These operations are applied to each submodule 'bottom up' in the module hierarchy.
             stack.enter_context(torch.device("meta"))
         elif _TORCH_GREATER_EQUAL_1_13:
-            stack.enter_context(_EmptyInit(enabled=bool(empty_init)))
+            stack.enter_context(empty_ctx)
         stack.enter_context(precision_init_ctx)
         stack.enter_context(module_sharded_ctx)
         return stack
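
On the strategy side the pieces are stitched together with an `ExitStack`: an optional meta-device (or empty-init) context, then the precision plugin's module context, then the sharding context. A simplified stand-alone sketch of that stacking order; the placeholder contexts are plain `contextlib`/`torch` objects, not Lightning's:

from contextlib import ExitStack, nullcontext
from typing import ContextManager

import torch


def module_init_context(empty_init: bool, precision_ctx: ContextManager,
                        sharded_ctx: ContextManager) -> ExitStack:
    # order matters: device placement first, then precision, then sharding
    stack = ExitStack()
    if empty_init:
        # defer allocation; parameters get materialized later during setup
        stack.enter_context(torch.device("meta"))  # torch >= 2.0
    stack.enter_context(precision_ctx)
    stack.enter_context(sharded_ctx)
    return stack


# usage: create a module with all entered behaviours active at once
with module_init_context(True, nullcontext(), nullcontext()):
    layer = torch.nn.Linear(1024, 1024)
print(layer.weight.device)  # meta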

src/lightning/fabric/strategies/strategy.py

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:

     def tensor_init_context(self) -> ContextManager:
         """Controls how tensors get created (device, dtype)."""
-        precision_init_ctx = self.precision.init_context()
+        precision_init_ctx = self.precision.tensor_init_context()
         stack = ExitStack()
         if _TORCH_GREATER_EQUAL_2_0:
             stack.enter_context(self.root_device)
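
On the tensor path the strategy now pairs its root device with the precision plugin's new `tensor_init_context()`. A stand-alone mirror of that stacking, with stand-ins for `self.root_device` and the precision context; the device-as-context usage matches the `_TORCH_GREATER_EQUAL_2_0` guard in the diff:

import torch
from contextlib import ExitStack

with ExitStack() as stack:
    stack.enter_context(torch.device("cpu"))  # stands in for self.root_device
    torch.set_default_dtype(torch.float64)    # stands in for precision.tensor_init_context()
    weights = torch.randn(2, 2)               # created directly on device/dtype, no .to() copy
torch.set_default_dtype(torch.float32)        # restore the global default

print(weights.device, weights.dtype)  # cpu torch.float64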
