Skip to content

Commit a389807

Browse files
authored
Helpful mesages (#210)
1 parent d3b797e commit a389807

File tree

1 file changed

+44
-10
lines changed

1 file changed

+44
-10
lines changed

finetrainers/args.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import sys
23
from typing import Any, Dict, List, Optional, Tuple
34

45
import torch
@@ -11,6 +12,13 @@ class Args:
1112
r"""
1213
The arguments for the finetrainers training script.
1314
15+
For helpful information about arguments, run `python train.py --help`.
16+
17+
TODO(aryan): add `python train.py --recommend_configs --model_name <model_name>` to recommend
18+
good training configs for a model after extensive testing.
19+
TODO(aryan): add `python train.py --memory_requirements --model_name <model_name>` to show
20+
memory requirements per model, per training type with sensible training settings.
21+
1422
MODEL ARGUMENTS
1523
---------------
1624
model_name (`str`):
@@ -424,20 +432,31 @@ def to_dict(self) -> Dict[str, Any]:
424432
}
425433

426434

435+
# TODO(aryan): handle more informative messages
436+
_IS_ARGUMENTS_REQUIRED = "--list_models" not in sys.argv
437+
438+
427439
def parse_arguments() -> Args:
428440
parser = argparse.ArgumentParser()
429441

430-
_add_model_arguments(parser)
431-
_add_dataset_arguments(parser)
432-
_add_dataloader_arguments(parser)
433-
_add_diffusion_arguments(parser)
434-
_add_training_arguments(parser)
435-
_add_optimizer_arguments(parser)
436-
_add_validation_arguments(parser)
437-
_add_miscellaneous_arguments(parser)
442+
if _IS_ARGUMENTS_REQUIRED:
443+
_add_model_arguments(parser)
444+
_add_dataset_arguments(parser)
445+
_add_dataloader_arguments(parser)
446+
_add_diffusion_arguments(parser)
447+
_add_training_arguments(parser)
448+
_add_optimizer_arguments(parser)
449+
_add_validation_arguments(parser)
450+
_add_miscellaneous_arguments(parser)
451+
452+
args = parser.parse_args()
453+
return _map_to_args_type(args)
454+
else:
455+
_add_helper_arguments(parser)
438456

439-
args = parser.parse_args()
440-
return _map_to_args_type(args)
457+
args = parser.parse_args()
458+
_display_helper_messages(args)
459+
sys.exit(0)
441460

442461

443462
def validate_args(args: Args):
@@ -932,6 +951,14 @@ def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
932951
)
933952

934953

954+
def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
955+
parser.add_argument(
956+
"--list_models",
957+
action="store_true",
958+
help="List all the supported models.",
959+
)
960+
961+
935962
_DTYPE_MAP = {
936963
"bf16": torch.bfloat16,
937964
"fp16": torch.float16,
@@ -1089,3 +1116,10 @@ def _validate_validation_args(args: Args):
10891116
assert len(args.validation_prompts) == len(
10901117
args.validation_widths
10911118
), "Validation prompts and widths should be of same length"
1119+
1120+
1121+
def _display_helper_messages(args: argparse.Namespace):
1122+
if args.list_models:
1123+
print("Supported models:")
1124+
for index, model_name in enumerate(SUPPORTED_MODEL_CONFIGS.keys()):
1125+
print(f" {index + 1}. {model_name}")

0 commit comments

Comments
 (0)