|
6 | 6 |
|
7 | 7 | def create_parser(): |
8 | 8 | """Creates CLI args parser.""" |
9 | | - parser = argparse.ArgumentParser() |
| 9 | + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) |
10 | 10 |
|
11 | 11 | # general options |
12 | | - parser.add_argument("--ckpt", type=str, default="black-forest-labs/FLUX.1-schnell") |
13 | | - parser.add_argument("--prompt", type=str, default="A cat playing with a ball of yarn") |
14 | | - parser.add_argument("--cache-dir", type=str, default=os.path.expandvars("$HOME/.cache/flux-fast")) |
15 | | - parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda") |
16 | | - parser.add_argument("--num_inference_steps", type=int, default=4) |
17 | | - parser.add_argument("--output-file", type=str, default="output.png") |
| 12 | + parser.add_argument("--ckpt", type=str, default="black-forest-labs/FLUX.1-schnell", |
| 13 | + help="Model checkpoint path") |
| 14 | + parser.add_argument("--prompt", type=str, default="A cat playing with a ball of yarn", |
| 15 | + help="Text prompt") |
| 16 | + parser.add_argument("--cache-dir", type=str, default=os.path.expandvars("$HOME/.cache/flux-fast"), |
| 17 | + help="Cache directory for storing exported models") |
| 18 | + parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda", |
| 19 | + help="Device to use") |
| 20 | + parser.add_argument("--num_inference_steps", type=int, default=4, |
| 21 | + help="Number of denoising steps") |
| 22 | + parser.add_argument("--output-file", type=str, default="output.png", |
| 23 | + help="Output image file path") |
18 | 24 | # file path for optional output PyTorch Profiler trace |
19 | | - parser.add_argument("--trace-file", type=str, default=None) |
| 25 | + parser.add_argument("--trace-file", type=str, default=None, |
| 26 | + help="Output PyTorch Profiler trace file path") |
20 | 27 |
|
21 | 28 | # optimizations - all are on by default but each can be disabled |
22 | | - parser.add_argument("--disable_bf16", action="store_true") |
| 29 | + parser.add_argument("--disable_bf16", action="store_true", |
| 30 | + help="Disables usage of torch.bfloat16") |
23 | 31 | # torch.compile OR torch.export + AOTI OR neither |
24 | 32 | parser.add_argument("--compile_export_mode", type=str, default="export_aoti", |
25 | | - choices=["compile", "export_aoti", "disabled"]) |
| 33 | + choices=["compile", "export_aoti", "disabled"], |
| 34 | + help="Configures how torch.compile or torch.export + AOTI are used") |
26 | 35 | # fused (q, k, v) projections |
27 | | - parser.add_argument("--disable_fused_projections", action="store_true") |
| 36 | + parser.add_argument("--disable_fused_projections", action="store_true", |
| 37 | + help="Disables fused q,k,v projections") |
28 | 38 | # channels_last memory format |
29 | | - parser.add_argument("--disable_channels_last", action="store_true") |
| 39 | + parser.add_argument("--disable_channels_last", action="store_true", |
| 40 | + help="Disables usage of torch.channels_last memory format") |
30 | 41 | # Flash Attention v3 |
31 | | - parser.add_argument("--disable_fa3", action="store_true") |
| 42 | + parser.add_argument("--disable_fa3", action="store_true", |
| 43 | + help="Disables use of Flash Attention V3") |
32 | 44 | # dynamic float8 quantization |
33 | | - parser.add_argument("--disable_quant", action="store_true") |
| 45 | + parser.add_argument("--disable_quant", action="store_true", |
| 46 | + help="Disables usage of dynamic float8 quantization") |
34 | 47 | # flags for tuning inductor |
35 | | - parser.add_argument("--disable_inductor_tuning_flags", action="store_true") |
| 48 | + parser.add_argument("--disable_inductor_tuning_flags", action="store_true", |
| 49 | + help="Disables use of inductor tuning flags") |
36 | 50 | return parser |
37 | 51 |
|
38 | 52 |
|
|
0 commit comments