3 changes: 0 additions & 3 deletions pyproject.toml
@@ -101,8 +101,6 @@ ignore = [
"S603", # todo: `subprocess` call: check for execution of untrusted input
"S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell`
"S607", # todo: Starting a process with a partial executable path
"RET504", # todo:Unnecessary variable assignment before `return` statement
"RET503",
]
"tests/**" = [
"S101", # Use of `assert` detected
@@ -118,7 +116,6 @@ ignore = [
"S603", # todo: `subprocess` call: check for execution of untrusted input
"S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell`
"S607", # todo: Starting a process with a partial executable path
"RET504", # todo:Unnecessary variable assignment before `return` statement
"PT004", # todo: Fixture `tmpdir_unittest_fixture` does not return anything, add leading underscore
"PT012", # todo: `pytest.raises()` block should contain a single simple statement
"PT019", # todo: Fixture `_` without value is injected as parameter, use `@pytest.mark.usefixtures` instead
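For reference, RET504 flags a value that is assigned to a temporary variable and immediately returned, and RET503 flags functions that can fall off the end and return `None` implicitly; removing these two ignores is what drives every hunk below. A minimal sketch of the two patterns and their fixes (illustrative code, not taken from this repository):

```python
from typing import Optional

# RET504: a value is bound to a temporary and then returned immediately.
def scale(value: float, factor: float) -> float:
    result = value * factor  # flagged: unnecessary assignment before `return`
    return result

def scale_fixed(value: float, factor: float) -> float:
    return value * factor  # the fix: return the expression directly

# RET503: one code path falls off the end of the function, returning None implicitly.
def first_positive(values: list) -> Optional[int]:
    for value in values:
        if value > 0:
            return value
    return None  # the fix: make the fall-through return explicit
```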
3 changes: 1 addition & 2 deletions src/lightning/fabric/fabric.py
@@ -476,8 +476,7 @@ def _setup_dataloader(
dataloader = self._strategy.process_dataloader(dataloader)
device = self.device if move_to_device and not isinstance(self._strategy, XLAStrategy) else None
fabric_dataloader = _FabricDataLoader(dataloader=dataloader, device=device)
-fabric_dataloader = cast(DataLoader, fabric_dataloader)
-return fabric_dataloader
+return cast(DataLoader, fabric_dataloader)

def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] = None, **kwargs: Any) -> None:
r"""Replaces ``loss.backward()`` in your training loop. Handles precision automatically for you.
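Folding the `cast` into the `return` does not change behavior: `typing.cast` is a no-op at run time that simply hands back its argument and only narrows the type for static checkers. A small self-contained illustration (hypothetical class name):

```python
from typing import cast

class _WrappedList(list):
    """Stand-in for a wrapper type such as a dataloader wrapper."""

wrapped = _WrappedList([1, 2, 3])
narrowed = cast(list, wrapped)
# `cast` neither converts nor copies; the same object comes back.
assert narrowed is wrapped
```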
3 changes: 1 addition & 2 deletions src/lightning/fabric/loggers/tensorboard.py
@@ -157,8 +157,7 @@ def log_dir(self) -> str:
if isinstance(self.sub_dir, str):
log_dir = os.path.join(log_dir, self.sub_dir)
log_dir = os.path.expandvars(log_dir)
-log_dir = os.path.expanduser(log_dir)
-return log_dir
+return os.path.expanduser(log_dir)

@property
def sub_dir(self) -> Optional[str]:
@@ -103,8 +103,7 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
)
elif self.replace_layers in (None, True):
_convert_layers(module)
-module = module.to(dtype=self.weights_dtype)
-return module
+return module.to(dtype=self.weights_dtype)

@override
def tensor_init_context(self) -> AbstractContextManager:
13 changes: 5 additions & 8 deletions src/lightning/fabric/strategies/fsdp.py
@@ -795,19 +795,18 @@ def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
)


-def _get_sharded_state_dict_context(module: Module) -> Generator[None, None, None]:
+def _get_sharded_state_dict_context(module: Module) -> Generator:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType

state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
-state_dict_type_context = FSDP.state_dict_type(
+return FSDP.state_dict_type(
module=module,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=state_dict_config,
optim_state_dict_config=optim_state_dict_config,
-)
-return state_dict_type_context # type: ignore[return-value]
+) # type: ignore[return-value]


def _get_full_state_dict_context(
@@ -819,14 +818,12 @@ def _get_full_state_dict_context(

state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
-state_dict_type_context = FSDP.state_dict_type(
+return FSDP.state_dict_type(
module=module,
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=state_dict_config,
optim_state_dict_config=optim_state_dict_config,
-)
-
-return state_dict_type_context # type: ignore[return-value]
+) # type: ignore[return-value]


def _is_sharded_checkpoint(path: Path) -> bool:
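Both helpers now return the object produced by `FSDP.state_dict_type(...)` directly; since callers enter that object in a `with` block, returning it is equivalent to binding it to a local first. A toy sketch of the general pattern using plain `contextlib` (not the FSDP API):

```python
from contextlib import contextmanager

@contextmanager
def _timer(label: str):
    print(f"enter {label}")
    try:
        yield
    finally:
        print(f"exit {label}")

def get_timer_context(label: str):
    # Returning the context manager directly, as the helpers above now do,
    # behaves exactly like assigning it to a variable and returning that.
    return _timer(label)

with get_timer_context("demo"):
    pass  # prints "enter demo" then "exit demo"
```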
3 changes: 1 addition & 2 deletions src/lightning/fabric/strategies/parallel.py
@@ -104,8 +104,7 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
decision,
reduce_op=ReduceOp.SUM, # type: ignore[arg-type]
)
-decision = bool(decision == self.world_size) if all else bool(decision)
-return decision
+return bool(decision == self.world_size) if all else bool(decision)

@override
def teardown(self) -> None:
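The decision is reduced with `ReduceOp.SUM`, so a sum equal to `world_size` means every rank voted `True` (logical AND), while any non-zero sum means at least one rank did (logical OR). A single-process sketch of the same arithmetic with the collective replaced by a plain list (illustrative only):

```python
def reduce_boolean_decision(decisions: list, all_ranks: bool = True) -> bool:
    # Stand-in for the all-reduce with ReduceOp.SUM across ranks.
    summed = sum(decisions)
    world_size = len(decisions)
    # AND: every rank voted True  /  OR: at least one rank voted True.
    return summed == world_size if all_ranks else bool(summed)

assert reduce_boolean_decision([True, True, True]) is True
assert reduce_boolean_decision([True, False, True]) is False
assert reduce_boolean_decision([True, False, True], all_ranks=False) is True
```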
3 changes: 1 addition & 2 deletions src/lightning/fabric/strategies/xla.py
@@ -200,8 +200,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
import torch_xla.core.xla_model as xm

tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-tensor = tensor.to(original_device)
-return tensor
+return tensor.to(original_device)

@override
def all_reduce(
3 changes: 1 addition & 2 deletions src/lightning/fabric/strategies/xla_fsdp.py
@@ -334,8 +334,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
import torch_xla.core.xla_model as xm

tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-tensor = tensor.to(original_device)
-return tensor
+return tensor.to(original_device)

@override
def all_reduce(
1 change: 1 addition & 0 deletions src/lightning/fabric/utilities/throughput.py
@@ -632,6 +632,7 @@ def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) ->
rank_zero_warn(f"FLOPs not found for TPU {device_name!r} with {dtype}")
return None
return int(_TPU_FLOPS[chip])
+return None


def _plugin_to_compute_dtype(plugin: "Precision") -> torch.dtype:
4 changes: 1 addition & 3 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -623,9 +623,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

# If using multiple devices, make sure all processes are unanimous on the decision.
-should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
-
-return should_update_best_and_save
+return trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))

def _format_checkpoint_name(
self,
3 changes: 1 addition & 2 deletions src/lightning/pytorch/core/datamodule.py
@@ -315,5 +315,4 @@ def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info
# Retrieve information for each dataloader method
dataloader_info = extract_loader_info(datamodule_loader_methods)
# Format the information
-dataloader_str = format_loader_info(dataloader_info)
-return dataloader_str
+return format_loader_info(dataloader_info)
6 changes: 2 additions & 4 deletions src/lightning/pytorch/core/module.py
@@ -360,8 +360,7 @@ def _apply_batch_transfer_handler(
) -> Any:
device = device or self.device
batch = self._call_batch_hook("transfer_batch_to_device", batch, device, dataloader_idx)
-batch = self._call_batch_hook("on_after_batch_transfer", batch, dataloader_idx)
-return batch
+return self._call_batch_hook("on_after_batch_transfer", batch, dataloader_idx)

def print(self, *args: Any, **kwargs: Any) -> None:
r"""Prints only from process 0. Use this in any distributed mode to log only once.
@@ -666,8 +665,7 @@ def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
f" You can try doing `self.log({name}, {value}.mean())`"
)
-value = value.squeeze()
-return value
+return value.squeeze()

def all_gather(
self, data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, sync_grads: bool = False
1 change: 1 addition & 0 deletions src/lightning/pytorch/core/optimizer.py
@@ -409,6 +409,7 @@ def step(self, closure: Callable[[], float]) -> float: ...
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
if closure is not None:
return closure()
+return None

@override
def zero_grad(self, set_to_none: Optional[bool] = True) -> None:
12 changes: 4 additions & 8 deletions src/lightning/pytorch/demos/boring_classes.py
@@ -253,20 +253,16 @@ def setup(self, stage: str) -> None:
]

def train_dataloader(self) -> Iterable[DataLoader]:
-combined_train = apply_to_collection(self.train_datasets, Dataset, lambda x: DataLoader(x))
-return combined_train
+return apply_to_collection(self.train_datasets, Dataset, lambda x: DataLoader(x))

def val_dataloader(self) -> DataLoader:
-combined_val = apply_to_collection(self.val_datasets, Dataset, lambda x: DataLoader(x))
-return combined_val
+return apply_to_collection(self.val_datasets, Dataset, lambda x: DataLoader(x))

def test_dataloader(self) -> DataLoader:
-combined_test = apply_to_collection(self.test_datasets, Dataset, lambda x: DataLoader(x))
-return combined_test
+return apply_to_collection(self.test_datasets, Dataset, lambda x: DataLoader(x))

def predict_dataloader(self) -> DataLoader:
-combined_predict = apply_to_collection(self.predict_datasets, Dataset, lambda x: DataLoader(x))
-return combined_predict
+return apply_to_collection(self.predict_datasets, Dataset, lambda x: DataLoader(x))


class ManualOptimBoringModel(BoringModel):
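`apply_to_collection` recursively visits a (possibly nested) collection and applies the given function to every element of the requested type, so each `Dataset` held by the datamodule ends up wrapped in its own `DataLoader`. A rough stand-in for that behavior, assuming simple list/tuple/dict containers (not the real helper's full feature set):

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset

def wrap_datasets(collection):
    # Simplified sketch of apply_to_collection(collection, Dataset, DataLoader):
    # recurse through common containers and wrap every Dataset found.
    if isinstance(collection, Dataset):
        return DataLoader(collection)
    if isinstance(collection, (list, tuple)):
        return type(collection)(wrap_datasets(item) for item in collection)
    if isinstance(collection, dict):
        return {key: wrap_datasets(value) for key, value in collection.items()}
    return collection

datasets = [TensorDataset(torch.randn(4, 2)), TensorDataset(torch.randn(4, 2))]
loaders = wrap_datasets(datasets)
assert all(isinstance(loader, DataLoader) for loader in loaders)
```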
12 changes: 4 additions & 8 deletions src/lightning/pytorch/demos/transformer.py
@@ -59,8 +59,7 @@ def __init__(
def generate_square_subsequent_mask(self, size: int) -> Tensor:
"""Generate a square mask for the sequence to prevent future tokens from being seen."""
mask = torch.triu(torch.ones(size, size), diagonal=1)
-mask = mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)
-return mask
+return mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)

def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor:
_, t = inputs.shape
@@ -78,8 +77,7 @@ def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None)
output = self.transformer(src, target, tgt_mask=mask)
output = self.decoder(output)
output = F.log_softmax(output, dim=-1)
-output = output.view(-1, self.vocab_size)
-return output
+return output.view(-1, self.vocab_size)


class PositionalEncoding(nn.Module):
@@ -106,8 +104,7 @@ def _init_pos_encoding(self, device: torch.device) -> Tensor:
div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
-pe = pe.unsqueeze(0)
-return pe
+return pe.unsqueeze(0)


class WikiText2(Dataset):
@@ -200,8 +197,7 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
inputs, target = batch
output = self(inputs, target)
-loss = torch.nn.functional.nll_loss(output, target.view(-1))
-return loss
+return torch.nn.functional.nll_loss(output, target.view(-1))

def configure_optimizers(self) -> torch.optim.Optimizer:
return torch.optim.SGD(self.model.parameters(), lr=0.1)
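For reference, `generate_square_subsequent_mask` builds the standard causal mask: zeros on and below the diagonal and `-inf` strictly above it, so position `i` cannot attend to any later position. A small standalone check of the same construction for `size=3`:

```python
import torch

size = 3
mask = torch.triu(torch.ones(size, size), diagonal=1)
mask = mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)
print(mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```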
1 change: 1 addition & 0 deletions src/lightning/pytorch/loggers/comet.py
@@ -397,6 +397,7 @@ def version(self) -> Optional[str]:
# Don't create an experiment if we don't have one
if self._experiment is not None:
return self._experiment.get_key()
+return None

def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
3 changes: 1 addition & 2 deletions src/lightning/pytorch/loggers/tensorboard.py
@@ -136,8 +136,7 @@ def log_dir(self) -> str:
if isinstance(self.sub_dir, str):
log_dir = os.path.join(log_dir, self.sub_dir)
log_dir = os.path.expandvars(log_dir)
-log_dir = os.path.expanduser(log_dir)
-return log_dir
+return os.path.expanduser(log_dir)

@property
@override
3 changes: 1 addition & 2 deletions src/lightning/pytorch/loggers/utilities.py
@@ -52,8 +52,7 @@ def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict)
checkpoints = sorted(
(Path(p).stat().st_mtime, p, s, tag) for p, (s, tag) in checkpoints.items() if Path(p).is_file()
)
-checkpoints = [c for c in checkpoints if c[1] not in logged_model_time or logged_model_time[c[1]] < c[0]]
-return checkpoints
+return [c for c in checkpoints if c[1] not in logged_model_time or logged_model_time[c[1]] < c[0]]


def _log_hyperparams(trainer: "pl.Trainer") -> None:
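Each element of `checkpoints` is a `(mtime, path, score, tag)` tuple, so the returned list keeps only checkpoints that were never logged before or whose file changed after the recorded timestamp. A self-contained sketch of that filter with made-up paths and times:

```python
# (mtime, path, score, tag) tuples, already sorted by modification time.
checkpoints = [
    (100.0, "epoch=0.ckpt", 0.90, "best"),
    (200.0, "epoch=1.ckpt", 0.85, "best"),
    (300.0, "last.ckpt", None, "latest"),
]
# path -> mtime recorded when that checkpoint was last logged.
logged_model_time = {"epoch=0.ckpt": 100.0, "last.ckpt": 250.0}

new_or_updated = [
    c for c in checkpoints
    if c[1] not in logged_model_time or logged_model_time[c[1]] < c[0]
]
# epoch=0.ckpt is unchanged and skipped; epoch=1.ckpt is new; last.ckpt was rewritten.
assert [c[1] for c in new_or_updated] == ["epoch=1.ckpt", "last.ckpt"]
```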
3 changes: 1 addition & 2 deletions src/lightning/pytorch/profilers/pytorch.py
@@ -375,12 +375,11 @@ def _total_steps(self) -> Union[int, float]:
)
return num_val_batches + num_sanity_val_batches
if self._schedule.is_testing:
-num_test_batches = (
+return (
sum(trainer.num_test_batches)
if isinstance(trainer.num_test_batches, list)
else trainer.num_test_batches
)
-return num_test_batches
if self._schedule.is_predicting:
return sum(trainer.num_predict_batches)
raise NotImplementedError("Unsupported schedule")
4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -581,6 +581,7 @@ def save_checkpoint(
return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
else:
raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")
+return None

@override
def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
@@ -624,8 +625,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
optim.load_state_dict(flattened_osd)

# Load metadata (anything not a module or optimizer)
-metadata = torch.load(path / _METADATA_FILENAME)
-return metadata
+return torch.load(path / _METADATA_FILENAME)

if _is_full_checkpoint(path):
checkpoint = _lazy_load(path)
1 change: 1 addition & 0 deletions src/lightning/pytorch/strategies/model_parallel.py
@@ -327,6 +327,7 @@ def save_checkpoint(
if _is_sharded_checkpoint(path):
shutil.rmtree(path)
return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
+return None

@override
def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
3 changes: 1 addition & 2 deletions src/lightning/pytorch/strategies/parallel.py
@@ -110,8 +110,7 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
decision,
reduce_op=ReduceOp.SUM, # type: ignore[arg-type]
)
-decision = bool(decision == self.world_size) if all else bool(decision)
-return decision
+return bool(decision == self.world_size) if all else bool(decision)

@contextmanager
def block_backward_sync(self) -> Generator:
3 changes: 1 addition & 2 deletions src/lightning/pytorch/strategies/xla.py
@@ -345,8 +345,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
import torch_xla.core.xla_model as xm

tensor = xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-tensor = tensor.to(original_device)
-return tensor
+return tensor.to(original_device)

@override
def teardown(self) -> None:
6 changes: 2 additions & 4 deletions src/lightning/pytorch/trainer/trainer.py
@@ -1277,8 +1277,7 @@ def training_step(self, batch, batch_idx):
else:
dirpath = self.default_root_dir

-dirpath = self.strategy.broadcast(dirpath)
-return dirpath
+return self.strategy.broadcast(dirpath)

@property
def is_global_zero(self) -> bool:
@@ -1731,5 +1730,4 @@ def configure_optimizers(self):
assert self.max_epochs is not None
max_estimated_steps = math.ceil(total_batches / self.accumulate_grad_batches) * max(self.max_epochs, 1)

-max_estimated_steps = min(max_estimated_steps, self.max_steps) if self.max_steps != -1 else max_estimated_steps
-return max_estimated_steps
+return min(max_estimated_steps, self.max_steps) if self.max_steps != -1 else max_estimated_steps
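The estimate works out to `ceil(total_batches / accumulate_grad_batches) * max(max_epochs, 1)`, capped by `max_steps` when a cap is set. A quick worked example with made-up numbers:

```python
import math

total_batches = 1000          # batches per epoch
accumulate_grad_batches = 4   # one optimizer step every 4 batches
max_epochs = 3
max_steps = 600               # -1 would mean "no cap"

max_estimated_steps = math.ceil(total_batches / accumulate_grad_batches) * max(max_epochs, 1)
assert max_estimated_steps == 750
estimated = min(max_estimated_steps, max_steps) if max_steps != -1 else max_estimated_steps
assert estimated == 600  # the cap wins here
```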
3 changes: 1 addition & 2 deletions src/lightning/pytorch/utilities/model_registry.py
@@ -137,8 +137,7 @@ def _determine_model_folder(model_name: str, default_root_dir: str) -> str:
# download the latest checkpoint from the model registry
model_name = model_name.replace("/", "_")
model_name = model_name.replace(":", "_")
-local_model_dir = os.path.join(default_root_dir, model_name)
-return local_model_dir
+return os.path.join(default_root_dir, model_name)


def find_model_local_ckpt_path(
@@ -313,14 +313,13 @@ def total_flops(self) -> int:
@property
def flop_counts(self) -> dict[str, dict[Any, int]]:
flop_counts = self._flop_counter.get_flop_counts()
-ret = {
+return {
name: flop_counts.get(
f"{type(self._model).__name__}.{name}",
{},
)
for name in self.layer_names
}
-return ret

def summarize(self) -> dict[str, LayerSummary]:
summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
3 changes: 1 addition & 2 deletions tests/parity_fabric/models.py
@@ -60,8 +60,7 @@ def forward(self, x):
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
-x = self.fc3(x)
-return x
+return self.fc3(x)

def get_optimizer(self):
return torch.optim.SGD(self.parameters(), lr=0.0001)
@@ -83,8 +83,7 @@ def _parallelize_feed_forward_fsdp2(model, device_mesh):

def _parallelize_feed_forward_fsdp2_tp(model, device_mesh):
model = _parallelize_feed_forward_tp(model, device_mesh)
-model = _parallelize_feed_forward_fsdp2(model, device_mesh)
-return model
+return _parallelize_feed_forward_fsdp2(model, device_mesh)


@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)