Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/lightning/fabric/utilities/consolidate_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
+import sys
from argparse import ArgumentParser, Namespace
from pathlib import Path

Expand Down Expand Up @@ -40,23 +41,23 @@ def _parse_cli_args() -> Namespace:
def _process_cli_args(args: Namespace) -> Namespace:
if not _TORCH_GREATER_EQUAL_2_3:
_log.error("Processing distributed checkpoints requires PyTorch >= 2.3.")
-exit(1)
+sys.exit(1)

checkpoint_folder = Path(args.checkpoint_folder)
if not checkpoint_folder.exists():
_log.error(f"The provided checkpoint folder does not exist: {checkpoint_folder}")
-exit(1)
+sys.exit(1)
if not checkpoint_folder.is_dir():
_log.error(
f"The provided checkpoint path must be a folder, containing the checkpoint shards: {checkpoint_folder}"
)
-exit(1)
+sys.exit(1)
if not (checkpoint_folder / _METADATA_FILENAME).is_file():
_log.error(
"Only FSDP-sharded checkpoints saved with Lightning are supported for consolidation. The provided folder"
f" is not in that format: {checkpoint_folder}"
)
-exit(1)
+sys.exit(1)

if args.output_file is None:
output_file = checkpoint_folder.with_suffix(checkpoint_folder.suffix + ".consolidated")
Expand All @@ -67,7 +68,7 @@ def _process_cli_args(args: Namespace) -> Namespace:
"The path for the converted checkpoint already exists. Choose a different path by providing"
f" `--output_file` or move/delete the file first: {output_file}"
)
-exit(1)
+sys.exit(1)

return Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)

Expand Down
3 changes: 2 additions & 1 deletion src/lightning/pytorch/trainer/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import logging
import signal
+import sys
from copy import deepcopy
from typing import Any, Callable, Optional, Union

Expand Down Expand Up @@ -62,7 +63,7 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
launcher = trainer.strategy.launcher
if isinstance(launcher, _SubprocessScriptLauncher):
launcher.kill(_get_sigkill_signal())
-exit(1)
+sys.exit(1)

except BaseException as exception:
_interrupt(trainer, exception)
Expand Down
Loading