Skip to content

Commit d3b797e

Browse files
authored
Improve argument handling (#209)
1 parent 81bf590 commit d3b797e

File tree

1 file changed

+11
-17
lines changed

1 file changed

+11
-17
lines changed

finetrainers/args.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
455455
parser.add_argument(
456456
"--pretrained_model_name_or_path",
457457
type=str,
458-
default=None,
458+
required=True,
459459
help="Path to pretrained model or model identifier from huggingface.co/models.",
460460
)
461461
parser.add_argument(
@@ -505,7 +505,7 @@ def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int
505505
parser.add_argument(
506506
"--data_root",
507507
type=str,
508-
default=None,
508+
required=True,
509509
help=("A folder containing the training data."),
510510
)
511511
parser.add_argument(
@@ -632,19 +632,19 @@ def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None:
632632
type=str,
633633
default="none",
634634
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
635-
help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
635+
help='We default to the "none" weighting scheme for uniform sampling and uniform loss',
636636
)
637637
parser.add_argument(
638638
"--flow_logit_mean",
639639
type=float,
640640
default=0.0,
641-
help="mean to use when using the `'logit_normal'` weighting scheme.",
641+
help="Mean to use when using the `'logit_normal'` weighting scheme.",
642642
)
643643
parser.add_argument(
644644
"--flow_logit_std",
645645
type=float,
646646
default=1.0,
647-
help="std to use when using the `'logit_normal'` weighting scheme.",
647+
help="Standard deviation to use when using the `'logit_normal'` weighting scheme.",
648648
)
649649
parser.add_argument(
650650
"--flow_mode_scale",
@@ -659,7 +659,7 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
659659
parser.add_argument(
660660
"--training_type",
661661
type=str,
662-
default=None,
662+
required=True,
663663
help="Type of training to perform. Choose between ['lora']",
664664
)
665665
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
@@ -676,10 +676,10 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
676676
parser.add_argument(
677677
"--batch_size",
678678
type=int,
679-
default=4,
679+
default=1,
680680
help="Batch size (per device) for the training dataloader.",
681681
)
682-
parser.add_argument("--train_epochs", type=int, default=1)
682+
parser.add_argument("--train_epochs", type=int, default=1, help="Number of training epochs.")
683683
parser.add_argument(
684684
"--train_steps",
685685
type=int,
@@ -735,13 +735,11 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
735735
parser.add_argument(
736736
"--enable_slicing",
737737
action="store_true",
738-
default=False,
739738
help="Whether or not to use VAE slicing for saving memory.",
740739
)
741740
parser.add_argument(
742741
"--enable_tiling",
743742
action="store_true",
744-
default=False,
745743
help="Whether or not to use VAE tiling for saving memory.",
746744
)
747745

@@ -756,7 +754,6 @@ def _add_optimizer_arguments(parser: argparse.ArgumentParser) -> None:
756754
parser.add_argument(
757755
"--scale_lr",
758756
action="store_true",
759-
default=False,
760757
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
761758
)
762759
parser.add_argument(
@@ -877,7 +874,6 @@ def _add_validation_arguments(parser: argparse.ArgumentParser) -> None:
877874
parser.add_argument(
878875
"--enable_model_cpu_offload",
879876
action="store_true",
880-
default=False,
881877
help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
882878
)
883879

@@ -904,7 +900,7 @@ def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
904900
parser.add_argument(
905901
"--output_dir",
906902
type=str,
907-
default="finetrainer-training",
903+
default="finetrainers-training",
908904
help="The output directory where the model predictions and checkpoints will be written.",
909905
)
910906
parser.add_argument(
@@ -931,10 +927,8 @@ def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
931927
"--report_to",
932928
type=str,
933929
default="none",
934-
help=(
935-
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
936-
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
937-
),
930+
choices=["none", "wandb"],
931+
help="The integration to report the results and logs to.",
938932
)
939933

940934

0 commit comments

Comments (0)