Skip to content

Commit dd7f19d

Browse files
committed
fabric: unify CLI with jsonargparse
1 parent 04e103b commit dd7f19d

File tree

2 files changed

+160
-128
lines changed

2 files changed

+160
-128
lines changed

requirements/fabric/test.txt

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,5 @@ pytest-cov ==6.2.1
55
pytest-timeout ==2.4.0
66
pytest-rerunfailures ==16.0
77
pytest-random-order ==1.2.0
8-
click ==8.1.8; python_version < "3.11"
9-
click ==8.2.1; python_version > "3.10"
8+
jsonargparse[signatures,jsonnet] >=4.39.0, <4.41.0
109
tensorboardX >=2.6, <2.7.0 # todo: relax it back to `>=2.2` after fixing tests

src/lightning/fabric/cli.py

Lines changed: 159 additions & 126 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131

3232
_log = logging.getLogger(__name__)
3333

34-
_CLICK_AVAILABLE = RequirementCache("click")
34+
_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.7")
3535
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
3636

3737
_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "auto")
@@ -45,127 +45,160 @@ def _get_supported_strategies() -> list[str]:
4545
return [strategy for strategy in available_strategies if not re.match(excluded, strategy)]
4646

4747

48-
if _CLICK_AVAILABLE:
49-
import click
50-
51-
@click.group()
52-
def _main() -> None:
53-
pass
54-
55-
@_main.command(
56-
"run",
57-
context_settings={
58-
"ignore_unknown_options": True,
59-
},
60-
)
61-
@click.argument(
62-
"script",
63-
type=click.Path(exists=True),
64-
)
65-
@click.option(
66-
"--accelerator",
67-
type=click.Choice(_SUPPORTED_ACCELERATORS),
68-
default=None,
69-
help="The hardware accelerator to run on.",
70-
)
71-
@click.option(
72-
"--strategy",
73-
type=click.Choice(_get_supported_strategies()),
74-
default=None,
75-
help="Strategy for how to run across multiple devices.",
76-
)
77-
@click.option(
78-
"--devices",
79-
type=str,
80-
default="1",
81-
help=(
82-
"Number of devices to run on (``int``), which devices to run on (``list`` or ``str``), or ``'auto'``."
83-
" The value applies per node."
84-
),
85-
)
86-
@click.option(
87-
"--num-nodes",
88-
"--num_nodes",
89-
type=int,
90-
default=1,
91-
help="Number of machines (nodes) for distributed execution.",
92-
)
93-
@click.option(
94-
"--node-rank",
95-
"--node_rank",
96-
type=int,
97-
default=0,
98-
help=(
99-
"The index of the machine (node) this command gets started on. Must be a number in the range"
100-
" 0, ..., num_nodes - 1."
101-
),
102-
)
103-
@click.option(
104-
"--main-address",
105-
"--main_address",
106-
type=str,
107-
default="127.0.0.1",
108-
help="The hostname or IP address of the main machine (usually the one with node_rank = 0).",
109-
)
110-
@click.option(
111-
"--main-port",
112-
"--main_port",
113-
type=int,
114-
default=29400,
115-
help="The main port to connect to the main machine.",
116-
)
117-
@click.option(
118-
"--precision",
119-
type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)),
120-
default=None,
121-
help=(
122-
"Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``32``), "
123-
"half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)"
124-
),
125-
)
126-
@click.argument("script_args", nargs=-1, type=click.UNPROCESSED)
127-
def _run(**kwargs: Any) -> None:
128-
"""Run a Lightning Fabric script.
129-
130-
SCRIPT is the path to the Python script with the code to run. The script must contain a Fabric object.
131-
132-
SCRIPT_ARGS are the remaining arguments that you can pass to the script itself and are expected to be parsed
133-
there.
134-
135-
"""
136-
script_args = list(kwargs.pop("script_args", []))
137-
main(args=Namespace(**kwargs), script_args=script_args)
138-
139-
@_main.command(
140-
"consolidate",
141-
context_settings={
142-
"ignore_unknown_options": True,
143-
},
144-
)
145-
@click.argument(
146-
"checkpoint_folder",
147-
type=click.Path(exists=True),
148-
)
149-
@click.option(
150-
"--output_file",
151-
type=click.Path(exists=True),
152-
default=None,
153-
help=(
154-
"Path to the file where the converted checkpoint should be saved. The file should not already exist."
155-
" If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
156-
" and a '.consolidated' suffix."
157-
),
158-
)
159-
def _consolidate(checkpoint_folder: str, output_file: Optional[str]) -> None:
160-
"""Convert a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`.
161-
162-
Only supports FSDP sharded checkpoints at the moment.
163-
164-
"""
165-
args = Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)
166-
config = _process_cli_args(args)
167-
checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
168-
torch.save(checkpoint, config.output_file)
48+
if _JSONARGPARSE_SIGNATURES_AVAILABLE:
49+
from jsonargparse import ArgumentParser, register_unresolvable_import_paths
50+
51+
# Align with pytorch CLI behavior
52+
register_unresolvable_import_paths(torch) # Required until the upstream PyTorch issue is fixed
53+
54+
try:
55+
from jsonargparse import set_parsing_settings
56+
57+
set_parsing_settings(config_read_mode_fsspec_enabled=True)
58+
except ImportError:
59+
from jsonargparse import set_config_read_mode
60+
61+
set_config_read_mode(fsspec_enabled=True)
62+
63+
class FabricCLI:
    """Lightning Fabric command-line tool."""

    # NOTE: the class docstring above doubles as the parser description (see ``init_parser``), so it is
    # user-facing text and is intentionally kept short.

    def __init__(self, args: Optional[list[str]] = None, run: bool = True) -> None:
        """Build the parser, parse the command line and optionally dispatch the selected subcommand.

        Args:
            args: Command-line tokens to parse; the value is forwarded unchanged to the parser.
            run: When true, :meth:`run` is called right after parsing.

        """
        self.parser = self.init_parser()
        self._add_subcommands(self.parser)
        # Tokens the parser does not recognise are kept aside; the ``run`` subcommand forwards them to the
        # launched script as its own arguments.
        # NOTE(review): the jsonargparse API reference describes ``ArgumentParser.parse_known_args`` as raising
        # ``NotImplementedError`` for external callers -- confirm this call works with the supported versions.
        self.config, self.unknown_args = self.parser.parse_known_args(args)

        if run:
            self.run()

    def init_parser(self) -> ArgumentParser:
        """Instantiate the top-level argument parser, using the class docstring as its description."""
        return ArgumentParser(prog="lightning-fabric", description=self.__class__.__doc__)

    def _add_subcommands(self, parser: ArgumentParser) -> None:
        """Register the ``run`` and ``consolidate`` subcommands on ``parser``.

        The chosen subcommand name is stored under ``command`` and selecting one is mandatory.

        """
        subparsers = parser.add_subparsers(dest="command", required=True)
        self.add_run_subcommand(subparsers)
        self.add_consolidate_subcommand(subparsers)

    def add_run_subcommand(self, subparsers: Any) -> None:
        """Add the ``run`` subcommand and its launch options.

        Args:
            subparsers: The object returned by ``parser.add_subparsers(...)``.

        """
        parser = subparsers.add_parser("run", help="Run a Lightning Fabric script.")
        parser.add_argument(
            "script",
            type=str,
            help="Path to the Python script with the code to run. The script must contain a Fabric object.",
        )
        parser.add_argument(
            "--accelerator",
            choices=_SUPPORTED_ACCELERATORS,
            default=None,
            help="The hardware accelerator to run on.",
        )
        parser.add_argument(
            "--strategy",
            choices=_get_supported_strategies(),
            default=None,
            help="Strategy for how to run across multiple devices.",
        )
        parser.add_argument(
            "--devices",
            type=str,
            default="1",
            help=(
                "Number of devices to run on (int), which devices to run on (list or str), or 'auto'."
                " The value applies per node."
            ),
        )
        # Multi-word options accept both the hyphenated and the underscore spelling.
        parser.add_argument(
            "--num-nodes",
            "--num_nodes",
            type=int,
            default=1,
            help="Number of machines (nodes) for distributed execution.",
        )
        parser.add_argument(
            "--node-rank",
            "--node_rank",
            type=int,
            default=0,
            help="The index of the machine (node) this command gets started on. Must be 0, ..., num_nodes - 1.",
        )
        parser.add_argument(
            "--main-address",
            "--main_address",
            type=str,
            default="127.0.0.1",
            help="The hostname or IP address of the main machine (usually the one with node_rank = 0).",
        )
        parser.add_argument(
            "--main-port",
            "--main_port",
            type=int,
            default=29400,
            help="The main port to connect to the main machine.",
        )
        parser.add_argument(
            "--precision",
            choices=list(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)),
            default=None,
            help=(
                "Double precision ('64-true' or '64'), full precision ('32-true' or '32'), "
                "half precision ('16-mixed' or '16') or bfloat16 precision ('bf16-mixed' or 'bf16')."
            ),
        )

    def add_consolidate_subcommand(self, subparsers: Any) -> None:
        """Add the ``consolidate`` subcommand and its options.

        Args:
            subparsers: The object returned by ``parser.add_subparsers(...)``.

        """
        parser = subparsers.add_parser(
            "consolidate", help="Convert a distributed/sharded checkpoint into a single file."
        )
        parser.add_argument(
            "checkpoint_folder",
            type=str,
            help="Path to the input checkpoint folder.",
        )
        # Accept the hyphenated spelling as well, like the multi-word options of the ``run`` subcommand.
        # The underscore form stays first so the parsed attribute is still ``output_file``.
        parser.add_argument(
            "--output_file",
            "--output-file",
            type=str,
            default=None,
            help=(
                "Path to the file where the converted checkpoint should be saved. The file should not already exist."
                " If no path is provided, the file will be saved next to the input checkpoint folder with the same"
                " name and a '.consolidated' suffix."
            ),
        )

    def run(self) -> None:
        """Dispatch to the handler of the subcommand stored in ``self.config.command``."""
        if self.config.command == "run":
            self._run_script()
        elif self.config.command == "consolidate":
            self._consolidate_checkpoint()

    def _run_script(self) -> None:
        """Launch the user script through ``main`` with the parsed ``run`` options.

        Raises:
            SystemExit: If the script path does not exist.

        """
        config = self.config.run
        if not os.path.exists(config.script):
            raise SystemExit(f"Script not found: {config.script}")

        # Re-pack the subcommand options into the ``Namespace`` type that ``main`` is annotated to take.
        args = Namespace(**vars(config))
        main(args=args, script_args=self.unknown_args)

    def _consolidate_checkpoint(self) -> None:
        """Load the checkpoint from the given folder and save it to the resolved output file.

        Raises:
            SystemExit: If the checkpoint folder is not an existing directory.

        """
        config = self.config.consolidate
        if not os.path.isdir(config.checkpoint_folder):
            raise SystemExit(f"Checkpoint folder not found: {config.checkpoint_folder}")

        args = Namespace(**vars(config))
        processed_args = _process_cli_args(args)
        checkpoint = _load_distributed_checkpoint(processed_args.checkpoint_folder)
        torch.save(checkpoint, processed_args.output_file)
198+
199+
def _entrypoint() -> None:
    """Entry point of the command-line tool: build a ``FabricCLI`` with its default arguments."""
    FabricCLI(args=None, run=True)
169202

170203

171204
def _set_env_variables(args: Namespace) -> None:
@@ -235,11 +268,11 @@ def main(args: Namespace, script_args: Optional[list[str]] = None) -> None:
235268

236269

237270
if __name__ == "__main__":
238-
if not _CLICK_AVAILABLE: # pragma: no cover
271+
if not _JSONARGPARSE_SIGNATURES_AVAILABLE: # pragma: no cover
239272
_log.error(
240-
"To use the Lightning Fabric CLI, you must have `click` installed."
241-
" Install it by running `pip install -U click`."
273+
"To use the Lightning Fabric CLI, you must have 'jsonargparse[signatures]>=4.27.7' installed."
274+
" Install it by running: pip install -U 'jsonargparse[signatures]>=4.27.7'."
242275
)
243276
raise SystemExit(1)
244277

245-
_run()
278+
_entrypoint()

0 commit comments

Comments
 (0)