Commit 042496e

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d1051e7 commit 042496e

File tree

12 files changed (+61, -20 lines)


src/lightning/fabric/connector.py

Lines changed: 5 additions & 2 deletions
@@ -142,7 +142,7 @@ def __init__(
         elif self._accelerator_flag == "gpu":
             self._accelerator_flag = self._choose_gpu_accelerator_backend()
         elif isinstance(self._accelerator_flag, Accelerator):
-            pass # for 3rd party accelerator, just do nothing
+            pass  # for 3rd party accelerator, just do nothing
 
         self._set_parallel_devices_and_init_accelerator()
@@ -463,7 +463,10 @@ def _check_and_init_precision(self) -> Precision:
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_input)  # type: ignore
         if isinstance(self.strategy, FSDPStrategy):
-            return FSDPPrecision(precision=self._precision_input, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)  # type: ignore[arg-type]
+            return FSDPPrecision(
+                precision=self._precision_input,
+                device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None,
+            )  # type: ignore[arg-type]
         mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
         if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
             raise ValueError(
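For context on the second hunk: the device passed to FSDPPrecision comes from the accelerator's get_device() hook, so a 3rd-party backend can steer where mixed precision runs. A minimal sketch of that call path, where MyAccelerator and its "xpu" device type are hypothetical and the isinstance guard from the connector is omitted:

# Sketch only: MyAccelerator is not a real Lightning class.
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

class MyAccelerator:
    @staticmethod
    def get_device() -> str:
        # Report the device type that autocast/grad scaling should target.
        return "xpu"

device = MyAccelerator.get_device()  # the connector passes None for non-Accelerator flags
precision = FSDPPrecision(precision="16-mixed", device=device)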

src/lightning/fabric/plugins/precision/amp.py

Lines changed: 5 additions & 1 deletion
@@ -50,7 +50,11 @@ def __init__(
 
         self.precision = precision
         if scaler is None and self.precision == "16-mixed":
-            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else getattr(torch, f"{device.split(':')[0]}").amp.GradScaler()
+            scaler = (
+                torch.amp.GradScaler(device=device)
+                if _TORCH_GREATER_EQUAL_2_4
+                else getattr(torch, f"{device.split(':')[0]}").amp.GradScaler()
+            )
         if scaler is not None and self.precision == "bf16-mixed":
             raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
         self.device = device
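The reflowed conditional chooses between the device-agnostic torch.amp.GradScaler (torch >= 2.4) and the per-backend torch.<device>.amp.GradScaler lookup. A standalone sketch of the same dispatch, using packaging.version in place of Lightning's _TORCH_GREATER_EQUAL_2_4 flag and assuming a CUDA device string:

import torch
from packaging.version import Version

device = "cuda"  # the plugin receives this as its device string
if Version(torch.__version__) >= Version("2.4.0"):
    # torch >= 2.4: one GradScaler class that takes the device type directly
    scaler = torch.amp.GradScaler(device=device)
else:
    # older torch: resolve the backend module, e.g. torch.cuda.amp.GradScaler
    scaler = getattr(torch, device.split(":")[0]).amp.GradScaler()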

src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 6 additions & 2 deletions
@@ -48,7 +48,9 @@ class FSDPPrecision(Precision):
 
     """
 
-    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: str = None) -> None:
+    def __init__(
+        self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: str = None
+    ) -> None:
         supported_precision = get_args(_PRECISION_INPUT)
         if precision not in supported_precision:
             raise ValueError(
@@ -111,7 +113,9 @@ def module_init_context(self) -> ContextManager:
     @override
     def forward_context(self) -> ContextManager:
         if "mixed" in self.precision:
-            return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
+            return torch.autocast(
+                self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
+            )
         return self.tensor_init_context()
 
     @override
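forward_context builds the autocast manager from the stored device type, which is why FSDPPrecision grew the device argument. A small, self-contained sketch of what that context does, assuming a CUDA device is available:

import torch

device_type = "cuda"      # in the plugin this is self.device, reported by the accelerator
precision = "bf16-mixed"  # or "16-mixed"
dtype = torch.bfloat16 if precision == "bf16-mixed" else torch.float16

with torch.autocast(device_type, dtype=dtype):
    x = torch.randn(8, 8, device=device_type)
    y = x @ x  # the matmul runs in the reduced-precision dtype inside this block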

src/lightning/fabric/strategies/ddp.py

Lines changed: 7 additions & 1 deletion
@@ -124,7 +124,13 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()) if device_ids is not None else nullcontext()
+        ctx = (
+            getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
+                getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
+            )
+            if device_ids is not None
+            else nullcontext()
+        )
         with ctx:
             return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
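The getattr(torch, ...) indirection resolves the backend namespace (torch.cuda, torch.xpu, ...) from the root device type, so the DDP wrapper is built under a side stream on whatever accelerator is active. A simplified sketch of the pattern, assuming a CUDA root device:

import torch
from contextlib import nullcontext

root_device = torch.device("cuda", 0)
device_ids = [root_device.index]  # None would mean CPU, where no stream is used

backend = getattr(torch, root_device.type.split(":")[0])  # torch.cuda here
ctx = backend.stream(backend.Stream()) if device_ids is not None else nullcontext()
with ctx:
    ...  # DistributedDataParallel(module, device_ids=device_ids, ...) is built here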

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 3 additions & 1 deletion
@@ -506,7 +506,9 @@ def load_checkpoint(
 
         optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())
 
-        getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache() if self.root_device.type != "cpu" else None
+        getattr(
+            torch, f"{self.root_device.type.split(':')[0]}"
+        ).empty_cache() if self.root_device.type != "cpu" else None
         _, client_state = engine.load_checkpoint(
             path,
             tag="checkpoint",

src/lightning/fabric/strategies/strategy.py

Lines changed: 3 additions & 1 deletion
@@ -325,7 +325,9 @@ def load_checkpoint(
             given, the full checkpoint will be returned.
 
         """
-        getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache() if self.root_device.type != "cpu" else None
+        getattr(
+            torch, f"{self.root_device.type.split(':')[0]}"
+        ).empty_cache() if self.root_device.type != "cpu" else None
         checkpoint = self.checkpoint_io.load_checkpoint(path)
         if not state:
             return checkpoint

src/lightning/pytorch/accelerators/accelerator.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
 
         """
         raise NotImplementedError
-
+
     @staticmethod
     def get_device() -> str:
         """Get the device for the current process."""

src/lightning/pytorch/plugins/precision/amp.py

Lines changed: 5 additions & 1 deletion
@@ -50,7 +50,11 @@ def __init__(
 
         self.precision = precision
         if scaler is None and self.precision == "16-mixed":
-            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else getattr(torch, f"{device.split(':')[0]}").amp.GradScaler()
+            scaler = (
+                torch.amp.GradScaler(device=device)
+                if _TORCH_GREATER_EQUAL_2_4
+                else getattr(torch, f"{device.split(':')[0]}").amp.GradScaler()
+            )
         if scaler is not None and self.precision == "bf16-mixed":
             raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
         self.device = device

src/lightning/pytorch/plugins/precision/fsdp.py

Lines changed: 6 additions & 2 deletions
@@ -47,7 +47,9 @@ class FSDPPrecision(Precision):
 
     """
 
-    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: str = None) -> None:
+    def __init__(
+        self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: str = None
+    ) -> None:
         supported_precision = get_args(_PRECISION_INPUT)
         if precision not in supported_precision:
             raise ValueError(
@@ -120,7 +122,9 @@ def module_init_context(self) -> ContextManager:
     @override
     def forward_context(self) -> ContextManager:
         if "mixed" in self.precision:
-            return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
+            return torch.autocast(
+                self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
+            )
         return _DtypeContextManager(self._desired_input_dtype)
 
     @override

src/lightning/pytorch/strategies/ddp.py

Lines changed: 7 additions & 1 deletion
@@ -190,7 +190,13 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()) if device_ids is not None else nullcontext()
+        ctx = (
+            getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
+                getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
+            )
+            if device_ids is not None
+            else nullcontext()
+        )
         with ctx:
             return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
