@@ -11,10 +11,215 @@ class Args:
r"""
The arguments for the finetrainers training script.

- Args:
-     flow_resolution_shifting (`bool`, defaults to `False`):
-         Resolution-dependant shifting of timestep schedules.
-         [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)
+ MODEL ARGUMENTS
+ ---------------
+ model_name (`str`):
+     Name of model to train. To get a list of models, run `python train.py --list_models`.
+ pretrained_model_name_or_path (`str`):
+     Path to pretrained model or model identifier from https://huggingface.co/models. The model should be
+     loadable based on the specified `model_name`.
+ revision (`str`, defaults to `None`):
+     If provided, the model will be loaded from a specific branch of the model repository.
+ variant (`str`, defaults to `None`):
+     Variant of model weights to use. Some models provide weight variants, such as `fp16`, to reduce disk
+     storage requirements.
+ cache_dir (`str`, defaults to `None`):
+     The directory where the downloaded models and datasets will be stored, or loaded from.
+ text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
+     Data type for the text encoder when generating text embeddings.
+ text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
+     Data type for text encoder 2 when generating text embeddings.
+ text_encoder_3_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
+     Data type for text encoder 3 when generating text embeddings.
+ transformer_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
+     Data type for the transformer model.
+ vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
+     Data type for the VAE model.
+
+ DATASET ARGUMENTS
+ -----------------
+ data_root (`str`):
+     A folder containing the training data.
+ dataset_file (`str`, defaults to `None`):
+     Path to a CSV/JSON/JSONL file containing metadata for training. This should be provided if you're not
+     using a directory dataset format containing, for example, a simple `prompts.txt` and
+     `videos.txt`/`images.txt`.
+ video_column (`str`):
+     The column of the dataset containing videos. Or, the name of the file in the `data_root` folder
+     containing the line-separated paths to video data.
+ caption_column (`str`):
+     The column of the dataset containing the instance prompt for each video. Or, the name of the file in
+     the `data_root` folder containing the line-separated instance prompts.
+ id_token (`str`, defaults to `None`):
+     Identifier token prepended to each prompt, if provided. This is useful for LoRA-type training.
+ image_resolution_buckets (`List[Tuple[int, int]]`, defaults to `None`):
+     Resolution buckets for images. This should be a list of integer tuples, where each tuple represents the
+     resolution (height, width) of the image. All images will be resized to the nearest bucket resolution.
+ video_resolution_buckets (`List[Tuple[int, int, int]]`, defaults to `None`):
+     Resolution buckets for videos. This should be a list of integer tuples, where each tuple represents the
+     resolution (num_frames, height, width) of the video. All videos will be resized to the nearest bucket
+     resolution.
+ video_reshape_mode (`str`, defaults to `None`):
+     All input videos are reshaped to this mode. Choose between ['center', 'random', 'none'].
+     TODO(aryan): We don't support this.
+ caption_dropout_p (`float`, defaults to `0.00`):
+     Probability of dropout for the caption tokens. This is useful to improve the unconditional generation
+     quality of the model.
+ caption_dropout_technique (`str`, defaults to `empty`):
+     Technique to use for caption dropout. Choose between ['empty', 'zero']. Some models apply caption
+     dropout by setting the prompt condition to an empty string, while others zero out the text embedding
+     tensors.
+ precompute_conditions (`bool`, defaults to `False`):
+     Whether or not to precompute the conditionings for the model. This is useful for faster training and
+     reduces memory requirements.
+ remove_common_llm_caption_prefixes (`bool`, defaults to `False`):
+     Whether or not to remove common LLM caption prefixes. This is useful for improving the quality of the
+     captions used for training.
+
+ DATALOADER ARGUMENTS
+ --------------------
+ See https://pytorch.org/docs/stable/data.html for more information.
+
+ dataloader_num_workers (`int`, defaults to `0`):
+     Number of subprocesses to use for data loading. `0` means that the data will be loaded in a blocking
+     manner on the main process.
+ pin_memory (`bool`, defaults to `False`):
+     Whether or not to use pinned memory in the PyTorch dataloader. This is useful for faster data loading.
+
+ DIFFUSION ARGUMENTS
+ -------------------
+ flow_resolution_shifting (`bool`, defaults to `False`):
+     Resolution-dependent shifting of timestep schedules.
+     [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206).
+     TODO(aryan): We don't support this yet.
+ flow_base_seq_len (`int`, defaults to `256`):
+     Base number of tokens for images/video when applying resolution-dependent shifting.
+ flow_max_seq_len (`int`, defaults to `4096`):
+     Maximum number of tokens for images/video when applying resolution-dependent shifting.
+ flow_base_shift (`float`, defaults to `0.5`):
+     Base shift for timestep schedules when applying resolution-dependent shifting.
+ flow_max_shift (`float`, defaults to `1.15`):
+     Maximum shift for timestep schedules when applying resolution-dependent shifting.
+ flow_shift (`float`, defaults to `1.0`):
+     Instead of training with uniform/logit-normal sigmas, shift them as
+     (shift * sigma) / (1 + (shift - 1) * sigma). Setting this higher is helpful when trying to train models
+     for high-resolution generation or to produce better samples with a lower number of inference steps.
+ flow_weighting_scheme (`str`, defaults to `none`):
+     Weighting scheme for the flow-matching loss. Choose between ['sigma_sqrt', 'logit_normal', 'mode',
+     'cosmap', 'none']. The default, "none", corresponds to uniform sampling and a uniform loss.
+ flow_logit_mean (`float`, defaults to `0.0`):
+     Mean to use when using the `'logit_normal'` weighting scheme.
+ flow_logit_std (`float`, defaults to `1.0`):
+     Standard deviation to use when using the `'logit_normal'` weighting scheme.
+ flow_mode_scale (`float`, defaults to `1.29`):
+     Scale of the mode weighting scheme. Only effective when `'mode'` is used as the `flow_weighting_scheme`.
+
+ TRAINING ARGUMENTS
+ ------------------
+ training_type (`str`, defaults to `None`):
+     Type of training to perform. Choose between ['lora'].
+ seed (`int`, defaults to `42`):
+     A seed for reproducible training.
+ mixed_precision (`str`, defaults to `None`):
+     Whether to use mixed precision. Choose between ['no', 'fp8', 'fp16', 'bf16'].
+ batch_size (`int`, defaults to `1`):
+     Per-device batch size.
+ train_epochs (`int`, defaults to `1`):
+     Number of training epochs.
+ train_steps (`int`, defaults to `None`):
+     Total number of training steps to perform. If provided, overrides `train_epochs`.
+ rank (`int`, defaults to `128`):
+     The rank for LoRA matrices.
+ lora_alpha (`float`, defaults to `64`):
+     The lora_alpha value used to compute the scaling factor (lora_alpha / rank) for LoRA matrices.
+ target_modules (`List[str]`, defaults to `["to_k", "to_q", "to_v", "to_out.0"]`):
+     The target modules for LoRA. Make sure to modify this based on the model.
+ gradient_accumulation_steps (`int`, defaults to `1`):
+     Number of gradient steps to accumulate before performing an optimizer step.
+ gradient_checkpointing (`bool`, defaults to `False`):
+     Whether or not to use gradient/activation checkpointing to save memory at the expense of a slower
+     backward pass.
+ checkpointing_steps (`int`, defaults to `500`):
+     Save a checkpoint of the training state every X training steps. These checkpoints can serve as the
+     final checkpoint in case they are better than the last one, and they are also suitable for resuming
+     training via `resume_from_checkpoint`.
+ checkpointing_limit (`int`, defaults to `None`):
+     Maximum number of checkpoints to store.
+ resume_from_checkpoint (`str`, defaults to `None`):
+     Whether training should be resumed from a previous checkpoint. Use a path saved by
+     `checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.
+
+ OPTIMIZER ARGUMENTS
+ -------------------
+ optimizer (`str`, defaults to `adamw`):
+     The optimizer type to use. Choose between ['adam', 'adamw'].
+ use_8bit_bnb (`bool`, defaults to `False`):
+     Whether to use the 8-bit variant of the chosen `optimizer`, via `bitsandbytes`.
+ lr (`float`, defaults to `1e-4`):
+     Initial learning rate (after the potential warmup period) to use.
+ scale_lr (`bool`, defaults to `False`):
+     Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.
+ lr_scheduler (`str`, defaults to `cosine_with_restarts`):
+     The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',
+     'constant', 'constant_with_warmup'].
+ lr_warmup_steps (`int`, defaults to `500`):
+     Number of steps for the warmup in the lr scheduler.
+ lr_num_cycles (`int`, defaults to `1`):
+     Number of hard restarts of the learning rate in the `cosine_with_restarts` scheduler.
+ lr_power (`float`, defaults to `1.0`):
+     Power factor of the polynomial scheduler.
+ beta1 (`float`, defaults to `0.9`):
+ beta2 (`float`, defaults to `0.95`):
+ beta3 (`float`, defaults to `0.999`):
+ weight_decay (`float`, defaults to `0.0001`):
+     Penalty for large weights in the model.
+ epsilon (`float`, defaults to `1e-8`):
+     Small value to avoid division by zero in the optimizer.
+ max_grad_norm (`float`, defaults to `1.0`):
+     Maximum gradient norm for gradient clipping.
+
+ VALIDATION ARGUMENTS
+ --------------------
+ validation_prompts (`List[str]`, defaults to `None`):
+     List of prompts to use for validation. If not provided, a random prompt will be selected from the
+     training dataset.
+ validation_images (`List[str]`, defaults to `None`):
+     List of image paths to use for validation.
+ validation_videos (`List[str]`, defaults to `None`):
+     List of video paths to use for validation.
+ validation_heights (`List[int]`, defaults to `None`):
+     List of heights for the validation videos.
+ validation_widths (`List[int]`, defaults to `None`):
+     List of widths for the validation videos.
+ validation_num_frames (`List[int]`, defaults to `None`):
+     List of frame counts for the validation videos.
+ num_validation_videos_per_prompt (`int`, defaults to `1`):
+     Number of videos to generate for validation per prompt.
+ validation_every_n_epochs (`int`, defaults to `None`):
+     Perform validation every `n` training epochs.
+ validation_every_n_steps (`int`, defaults to `None`):
+     Perform validation every `n` training steps.
+ enable_model_cpu_offload (`bool`, defaults to `False`):
+     Whether or not to offload different modeling components to CPU during validation.
+
+ MISCELLANEOUS ARGUMENTS
+ -----------------------
+ tracker_name (`str`, defaults to `finetrainers`):
+     Name of the tracker/project to use for logging training metrics.
+ push_to_hub (`bool`, defaults to `False`):
+     Whether or not to push the model to the Hugging Face Hub.
+ hub_token (`str`, defaults to `None`):
+     The API token to use for pushing the model to the Hugging Face Hub.
+ hub_model_id (`str`, defaults to `None`):
+     The model identifier to use for pushing the model to the Hugging Face Hub.
+ output_dir (`str`, defaults to `None`):
+     The directory where the model checkpoints and logs will be stored.
+ logging_dir (`str`, defaults to `logs`):
+     The directory where the logs will be stored.
+ allow_tf32 (`bool`, defaults to `False`):
+     Whether or not to allow the use of TF32 matmul on compatible hardware.
+ nccl_timeout (`int`, defaults to `1800`):
+     Timeout for the NCCL communication.
+ report_to (`str`, defaults to `wandb`):
+     The name of the logger to use for logging training metrics. Choose between ['wandb'].
"""

# Model arguments
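A note on the resolution-bucket arguments documented above: both take lists of integer tuples. A minimal sketch of plausible values follows; the concrete numbers are illustrative only and are not defaults from the script.

```python
# Each image bucket is (height, width); each video bucket is (num_frames, height, width).
# The values below are made up for illustration.
image_resolution_buckets = [(480, 720), (720, 1280)]
video_resolution_buckets = [(49, 480, 720), (49, 720, 1280)]

# Per the docstring, every sample is resized to the nearest bucket resolution,
# so a mixed-resolution dataset is grouped into a small set of fixed shapes.
```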
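The `caption_dropout_technique` entry distinguishes two behaviours: conditioning on an empty prompt string versus zeroing out the text embeddings. A rough, self-contained sketch of that logic is shown below; the helper name and signature are hypothetical and are not the script's actual code.

```python
import random

import torch


def apply_caption_dropout(prompt: str, prompt_embeds: torch.Tensor, p: float, technique: str):
    """Hypothetical illustration of the two documented caption-dropout techniques."""
    if random.random() < p:
        if technique == "empty":
            prompt = ""  # condition on an empty prompt string
        elif technique == "zero":
            prompt_embeds = torch.zeros_like(prompt_embeds)  # zero out the embedding tensor
    return prompt, prompt_embeds
```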
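Similarly, the `rank` and `lora_alpha` entries define the usual LoRA scaling factor. A small worked example using the documented defaults; the `W_eff` formula is the standard LoRA formulation, not code quoted from finetrainers.

```python
# Standard LoRA: a frozen weight W is augmented with a low-rank update,
#   W_eff = W + (lora_alpha / rank) * (B @ A)
# With the documented defaults, the update is scaled by 64 / 128 = 0.5.
rank, lora_alpha = 128, 64
scaling = lora_alpha / rank
print(scaling)  # 0.5
```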
@@ -390,7 +595,7 @@ def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--flow_resolution_shifting",
        action="store_true",
-        help="Resolution-dependant shifting of timestep schedules.",
+        help="Resolution-dependent shifting of timestep schedules.",
    )
    parser.add_argument(
        "--flow_base_seq_len",