11import argparse
2+ import sys
23from typing import Any , Dict , List , Optional , Tuple
34
45import torch
@@ -11,6 +12,13 @@ class Args:
1112 r"""
1213 The arguments for the finetrainers training script.
1314
15+ For helpful information about arguments, run `python train.py --help`.
16+
17+ TODO(aryan): add `python train.py --recommend_configs --model_name <model_name>` to recommend
18+ good training configs for a model after extensive testing.
19+ TODO(aryan): add `python train.py --memory_requirements --model_name <model_name>` to show
20+ memory requirements per model, per training type with sensible training settings.
21+
1422 MODEL ARGUMENTS
1523 ---------------
1624 model_name (`str`):
@@ -424,20 +432,31 @@ def to_dict(self) -> Dict[str, Any]:
424432 }
425433
426434
435+ # TODO(aryan): handle more informative messages
436+ _IS_ARGUMENTS_REQUIRED = "--list_models" not in sys .argv
437+
438+
427439def parse_arguments () -> Args :
428440 parser = argparse .ArgumentParser ()
429441
430- _add_model_arguments (parser )
431- _add_dataset_arguments (parser )
432- _add_dataloader_arguments (parser )
433- _add_diffusion_arguments (parser )
434- _add_training_arguments (parser )
435- _add_optimizer_arguments (parser )
436- _add_validation_arguments (parser )
437- _add_miscellaneous_arguments (parser )
442+ if _IS_ARGUMENTS_REQUIRED :
443+ _add_model_arguments (parser )
444+ _add_dataset_arguments (parser )
445+ _add_dataloader_arguments (parser )
446+ _add_diffusion_arguments (parser )
447+ _add_training_arguments (parser )
448+ _add_optimizer_arguments (parser )
449+ _add_validation_arguments (parser )
450+ _add_miscellaneous_arguments (parser )
451+
452+ args = parser .parse_args ()
453+ return _map_to_args_type (args )
454+ else :
455+ _add_helper_arguments (parser )
438456
439- args = parser .parse_args ()
440- return _map_to_args_type (args )
457+ args = parser .parse_args ()
458+ _display_helper_messages (args )
459+ sys .exit (0 )
441460
442461
443462def validate_args (args : Args ):
@@ -932,6 +951,14 @@ def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
932951 )
933952
934953
954+ def _add_helper_arguments (parser : argparse .ArgumentParser ) -> None :
955+ parser .add_argument (
956+ "--list_models" ,
957+ action = "store_true" ,
958+ help = "List all the supported models." ,
959+ )
960+
961+
935962_DTYPE_MAP = {
936963 "bf16" : torch .bfloat16 ,
937964 "fp16" : torch .float16 ,
@@ -1089,3 +1116,10 @@ def _validate_validation_args(args: Args):
10891116 assert len (args .validation_prompts ) == len (
10901117 args .validation_widths
10911118 ), "Validation prompts and widths should be of same length"
1119+
1120+
1121+ def _display_helper_messages (args : argparse .Namespace ):
1122+ if args .list_models :
1123+ print ("Supported models:" )
1124+ for index , model_name in enumerate (SUPPORTED_MODEL_CONFIGS .keys ()):
1125+ print (f" { index + 1 } . { model_name } " )
0 commit comments