3131
3232_log = logging .getLogger (__name__ )
3333
34- _CLICK_AVAILABLE = RequirementCache ("click " )
34+ _JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache ("jsonargparse[signatures]>=4.27.7 " )
3535_LIGHTNING_SDK_AVAILABLE = RequirementCache ("lightning_sdk" )
3636
3737_SUPPORTED_ACCELERATORS = ("cpu" , "gpu" , "cuda" , "mps" , "tpu" , "auto" )
@@ -45,127 +45,160 @@ def _get_supported_strategies() -> list[str]:
4545 return [strategy for strategy in available_strategies if not re .match (excluded , strategy )]
4646
4747
48- if _CLICK_AVAILABLE :
49- import click
50-
51- @click .group ()
52- def _main () -> None :
53- pass
54-
55- @_main .command (
56- "run" ,
57- context_settings = {
58- "ignore_unknown_options" : True ,
59- },
60- )
61- @click .argument (
62- "script" ,
63- type = click .Path (exists = True ),
64- )
65- @click .option (
66- "--accelerator" ,
67- type = click .Choice (_SUPPORTED_ACCELERATORS ),
68- default = None ,
69- help = "The hardware accelerator to run on." ,
70- )
71- @click .option (
72- "--strategy" ,
73- type = click .Choice (_get_supported_strategies ()),
74- default = None ,
75- help = "Strategy for how to run across multiple devices." ,
76- )
77- @click .option (
78- "--devices" ,
79- type = str ,
80- default = "1" ,
81- help = (
82- "Number of devices to run on (``int``), which devices to run on (``list`` or ``str``), or ``'auto'``."
83- " The value applies per node."
84- ),
85- )
86- @click .option (
87- "--num-nodes" ,
88- "--num_nodes" ,
89- type = int ,
90- default = 1 ,
91- help = "Number of machines (nodes) for distributed execution." ,
92- )
93- @click .option (
94- "--node-rank" ,
95- "--node_rank" ,
96- type = int ,
97- default = 0 ,
98- help = (
99- "The index of the machine (node) this command gets started on. Must be a number in the range"
100- " 0, ..., num_nodes - 1."
101- ),
102- )
103- @click .option (
104- "--main-address" ,
105- "--main_address" ,
106- type = str ,
107- default = "127.0.0.1" ,
108- help = "The hostname or IP address of the main machine (usually the one with node_rank = 0)." ,
109- )
110- @click .option (
111- "--main-port" ,
112- "--main_port" ,
113- type = int ,
114- default = 29400 ,
115- help = "The main port to connect to the main machine." ,
116- )
117- @click .option (
118- "--precision" ,
119- type = click .Choice (get_args (_PRECISION_INPUT_STR ) + get_args (_PRECISION_INPUT_STR_ALIAS )),
120- default = None ,
121- help = (
122- "Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``32``), "
123- "half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)"
124- ),
125- )
126- @click .argument ("script_args" , nargs = - 1 , type = click .UNPROCESSED )
127- def _run (** kwargs : Any ) -> None :
128- """Run a Lightning Fabric script.
129-
130- SCRIPT is the path to the Python script with the code to run. The script must contain a Fabric object.
131-
132- SCRIPT_ARGS are the remaining arguments that you can pass to the script itself and are expected to be parsed
133- there.
134-
135- """
136- script_args = list (kwargs .pop ("script_args" , []))
137- main (args = Namespace (** kwargs ), script_args = script_args )
138-
139- @_main .command (
140- "consolidate" ,
141- context_settings = {
142- "ignore_unknown_options" : True ,
143- },
144- )
145- @click .argument (
146- "checkpoint_folder" ,
147- type = click .Path (exists = True ),
148- )
149- @click .option (
150- "--output_file" ,
151- type = click .Path (exists = True ),
152- default = None ,
153- help = (
154- "Path to the file where the converted checkpoint should be saved. The file should not already exist."
155- " If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
156- " and a '.consolidated' suffix."
157- ),
158- )
159- def _consolidate (checkpoint_folder : str , output_file : Optional [str ]) -> None :
160- """Convert a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`.
161-
162- Only supports FSDP sharded checkpoints at the moment.
163-
164- """
165- args = Namespace (checkpoint_folder = checkpoint_folder , output_file = output_file )
166- config = _process_cli_args (args )
167- checkpoint = _load_distributed_checkpoint (config .checkpoint_folder )
168- torch .save (checkpoint , config .output_file )
if _JSONARGPARSE_SIGNATURES_AVAILABLE:
    from jsonargparse import ArgumentParser, register_unresolvable_import_paths

    # Align with pytorch CLI behavior
    register_unresolvable_import_paths(torch)  # Required until the upstream PyTorch issue is fixed

    # Prefer `set_parsing_settings`; fall back to `set_config_read_mode` if it cannot be imported.
    try:
        from jsonargparse import set_parsing_settings

        set_parsing_settings(config_read_mode_fsspec_enabled=True)
    except ImportError:
        from jsonargparse import set_config_read_mode

        set_config_read_mode(fsspec_enabled=True)

    class FabricCLI:
        """Lightning Fabric command-line tool."""

        def __init__(self, args: Optional[list[str]] = None, run: bool = True) -> None:
            """Build the parser, parse the command line and optionally execute the selected subcommand.

            Args:
                args: Command-line arguments handed to the parser. ``None`` is forwarded unchanged, leaving the
                    choice of default to the parser.
                run: Whether to immediately execute the parsed subcommand via :meth:`run`.

            """
            self.parser = self.init_parser()
            self._add_subcommands(self.parser)
            # NOTE(review): jsonargparse documents `parse_known_args` as not implemented for callers outside of
            # jsonargparse itself (to avoid silently ignoring typos). Confirm this call works with the pinned
            # jsonargparse version; otherwise the script arguments need to be separated before parsing.
            self.config, self.unknown_args = self.parser.parse_known_args(args)

            if run:
                self.run()

        def init_parser(self) -> ArgumentParser:
            """Method that instantiates the argument parser."""
            return ArgumentParser(prog="lightning-fabric", description=self.__class__.__doc__)

        def _add_subcommands(self, parser: ArgumentParser) -> None:
            """Adds subcommands to the parser.

            jsonargparse does not implement argparse's ``add_subparsers``; sub-commands are registered through
            ``add_subcommands``. The name of the selected sub-command is stored under ``command`` and its options
            under an attribute named after the sub-command (``config.run`` / ``config.consolidate``).

            """
            subcommands = parser.add_subcommands(dest="command", required=True)
            self.add_run_subcommand(subcommands)
            self.add_consolidate_subcommand(subcommands)

        def add_run_subcommand(self, subparsers: Any) -> None:
            """Adds the `run` subcommand to the parser.

            Args:
                subparsers: The sub-command container returned by ``ArgumentParser.add_subcommands``.

            """
            parser = ArgumentParser(description="Run a Lightning Fabric script.")
            parser.add_argument(
                "script",
                type=str,
                help="Path to the Python script with the code to run. The script must contain a Fabric object.",
            )
            parser.add_argument(
                "--accelerator",
                choices=_SUPPORTED_ACCELERATORS,
                default=None,
                help="The hardware accelerator to run on.",
            )
            parser.add_argument(
                "--strategy",
                choices=_get_supported_strategies(),
                default=None,
                help="Strategy for how to run across multiple devices.",
            )
            parser.add_argument(
                "--devices",
                type=str,
                default="1",
                help=(
                    "Number of devices to run on (int), which devices to run on (list or str), or 'auto'."
                    " The value applies per node."
                ),
            )
            parser.add_argument(
                "--num-nodes",
                "--num_nodes",
                type=int,
                default=1,
                help="Number of machines (nodes) for distributed execution.",
            )
            parser.add_argument(
                "--node-rank",
                "--node_rank",
                type=int,
                default=0,
                help="The index of the machine (node) this command gets started on. Must be 0, ..., num_nodes - 1.",
            )
            parser.add_argument(
                "--main-address",
                "--main_address",
                type=str,
                default="127.0.0.1",
                help="The hostname or IP address of the main machine (usually the one with node_rank = 0).",
            )
            parser.add_argument(
                "--main-port",
                "--main_port",
                type=int,
                default=29400,
                help="The main port to connect to the main machine.",
            )
            parser.add_argument(
                "--precision",
                choices=list(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)),
                default=None,
                help=(
                    "Double precision ('64-true' or '64'), full precision ('32-true' or '32'), "
                    "half precision ('16-mixed' or '16') or bfloat16 precision ('bf16-mixed' or 'bf16')."
                ),
            )
            # jsonargparse requires an already-built parser to be attached; `add_parser` is not implemented.
            subparsers.add_subcommand("run", parser, help="Run a Lightning Fabric script.")

        def add_consolidate_subcommand(self, subparsers: Any) -> None:
            """Adds the `consolidate` subcommand to the parser.

            Args:
                subparsers: The sub-command container returned by ``ArgumentParser.add_subcommands``.

            """
            parser = ArgumentParser(description="Convert a distributed/sharded checkpoint into a single file.")
            parser.add_argument(
                "checkpoint_folder",
                type=str,
                help="Path to the input checkpoint folder.",
            )
            parser.add_argument(
                "--output_file",
                type=str,
                default=None,
                help=(
                    "Path to the file where the converted checkpoint should be saved. The file should not already exist."
                    " If no path is provided, the file will be saved next to the input checkpoint folder with the same"
                    " name and a '.consolidated' suffix."
                ),
            )
            # jsonargparse requires an already-built parser to be attached; `add_parser` is not implemented.
            subparsers.add_subcommand(
                "consolidate", parser, help="Convert a distributed/sharded checkpoint into a single file."
            )

        def run(self) -> None:
            """Runs the subcommand selected on the command line (``run`` or ``consolidate``)."""
            if self.config.command == "run":
                self._run_script()
            elif self.config.command == "consolidate":
                self._consolidate_checkpoint()

        def _run_script(self) -> None:
            """Runs the script with the given arguments.

            Raises:
                SystemExit: If the script path does not exist.

            """
            config = self.config.run
            if not os.path.exists(config.script):
                raise SystemExit(f"Script not found: {config.script}")

            # `main` expects a plain argparse namespace; arguments the parser did not recognize are
            # forwarded to the script untouched.
            args = Namespace(**vars(config))
            main(args=args, script_args=self.unknown_args)

        def _consolidate_checkpoint(self) -> None:
            """Consolidates the checkpoint.

            Loads the distributed checkpoint from the given folder and saves it with ``torch.save`` to the
            output file resolved by ``_process_cli_args``.

            Raises:
                SystemExit: If the checkpoint folder is not an existing directory.

            """
            config = self.config.consolidate
            if not os.path.isdir(config.checkpoint_folder):
                raise SystemExit(f"Checkpoint folder not found: {config.checkpoint_folder}")

            args = Namespace(**vars(config))
            processed_args = _process_cli_args(args)
            checkpoint = _load_distributed_checkpoint(processed_args.checkpoint_folder)
            torch.save(checkpoint, processed_args.output_file)

    def _entrypoint() -> None:
        """The CLI entrypoint."""
        FabricCLI()
169202
170203
171204def _set_env_variables (args : Namespace ) -> None :
@@ -235,11 +268,11 @@ def main(args: Namespace, script_args: Optional[list[str]] = None) -> None:
235268
236269
237270if __name__ == "__main__" :
238- if not _CLICK_AVAILABLE : # pragma: no cover
271+ if not _JSONARGPARSE_SIGNATURES_AVAILABLE : # pragma: no cover
239272 _log .error (
240- "To use the Lightning Fabric CLI, you must have `click` installed."
241- " Install it by running ` pip install -U click` ."
273+ "To use the Lightning Fabric CLI, you must have 'jsonargparse[signatures]>=4.27.7' installed."
274+ " Install it by running: pip install -U 'jsonargparse[signatures]>=4.27.7' ."
242275 )
243276 raise SystemExit (1 )
244277
245- _run ()
278+ _entrypoint ()
0 commit comments