"""
Default values taken from
https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml
when applicable.
"""

import argparse


def _get_model_args(parser: argparse.ArgumentParser) -> None:
    """Arguments for loading the pretrained model."""
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--cast_dit",
        action="store_true",
        help="Whether to cast the DiT parameters to a lower precision.",
    )
    parser.add_argument(
        "--compile_dit",
        action="store_true",
        help="Whether to compile the DiT.",
    )


def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
    """Arguments for the training dataset and dataloader."""
    parser.add_argument(
        "--data_root",
        type=str,
        default=None,
        help="A folder containing the training data.",
    )
    parser.add_argument(
        "--caption_dropout",
        type=float,
        default=None,
        help="Probability of randomly dropping out a caption.",
    )

    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--pin_memory",
        action="store_true",
        help="Whether or not to use pinned memory in the PyTorch dataloader.",
    )


def _get_validation_args(parser: argparse.ArgumentParser) -> None:
    """Arguments for validation runs performed during training."""
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="One or more prompts used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
    )
    parser.add_argument(
        "--validation_images",
        type=str,
        default=None,
        help="One or more image paths/URLs used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.",
    )
    parser.add_argument(
        "--validation_prompt_separator",
        type=str,
        default=":::",
        help="String that separates multiple validation prompts.",
    )
    parser.add_argument(
        "--num_validation_videos",
        type=int,
        default=1,
        help="Number of videos that should be generated during validation per `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.",
    )
    parser.add_argument(
        "--enable_slicing",
        action="store_true",
        default=False,
        help="Whether or not to use VAE slicing for saving memory.",
    )
    parser.add_argument(
        "--enable_tiling",
        action="store_true",
        default=False,
        help="Whether or not to use VAE tiling for saving memory.",
    )
    parser.add_argument(
        "--enable_model_cpu_offload",
        action="store_true",
        default=False,
        help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=30,
        help="FPS to use when serializing the output videos.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Height of the generated validation videos.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=848,
        help="Width of the generated validation videos.",
    )


def _get_training_args(parser: argparse.ArgumentParser) -> None:
    """Arguments controlling the LoRA training run."""
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--rank", type=int, default=16, help="The rank for LoRA matrices.")
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=16,
        help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
    )
    parser.add_argument(
        "--target_modules",
        nargs="+",
        type=str,
        default=["to_k", "to_q", "to_v", "to_out.0"],
        help="Target modules to train LoRA for.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="mochi-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=200,
        help="Number of steps for the warmup in the lr scheduler.",
    )


def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
    """Arguments for the optimizer."""
    parser.add_argument(
        "--optimizer",
        type=lambda s: s.lower(),
        default="adam",
        choices=["adam", "adamw"],
        help="The optimizer type to use.",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="Weight decay to use for the optimizer.",
    )


def _get_configuration_args(parser: argparse.ArgumentParser) -> None:
    """Arguments for logging, tracking, and the Hugging Face Hub."""
    parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name.")
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default=None,
        help="The integration to report results and logs to (for example, 'wandb').",
    )


def get_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script for Mochi-1.")

    _get_model_args(parser)
    _get_dataset_args(parser)
    _get_training_args(parser)
    _get_validation_args(parser)
    _get_optimizer_args(parser)
    _get_configuration_args(parser)

    return parser.parse_args()
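

# Minimal usage sketch (an illustrative addition, not part of the original script):
# a training entry point would typically import `get_args` and read fields off the
# parsed namespace. Running this module directly parses and prints the arguments,
# which is handy for checking defaults, e.g.
#   python args.py --pretrained_model_name_or_path "genmo/mochi-1-preview" --data_root ./videos
# The file name `args.py` and the example values above are assumptions.
if __name__ == "__main__":
    args = get_args()
    print(args)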