
Commit 91c5879

Merge of parents 6106fc1 and 586f8f7.

Commit message: merge conflicts.

File tree: 5 files changed (+57, -5 lines)

README.md

Lines changed: 13 additions & 0 deletions
@@ -18,6 +18,19 @@ Summary of the optimizations:
 * `coordinate_descent_check_all_directions = True`
 * `torch.export` + Ahead-of-time Inductor (AOTI) + CUDAGraphs

+All of the above optimizations are lossless (outside of minor numerical differences sometimes
+introduced through the use of `torch.compile` / `torch.export`) EXCEPT FOR dynamic float8 quantization.
+Disable quantization if you want the same quality results as the baseline while still being
+quite a bit faster.
+
+Here are some example outputs for the prompt `"A cat playing with a ball of yarn"`:
+
+**Baseline:**
+![baseline_output](https://github.com/user-attachments/assets/8ba746d2-fbf3-4e30-adc4-11303231c146)
+
+**Fully-optimized (with quantization):**
+![fast_output](https://github.com/user-attachments/assets/1a31dec4-38d5-45b2-8ae6-c7fb2e6413a4)
+
 ## Setup
 We rely primarily on pure PyTorch for the optimizations. Currently, a relatively recent nightly version of PyTorch is required.
 The numbers reported here were gathered using:

gen_image.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+import random
+import time
+import torch
+from torch.profiler import profile, record_function, ProfilerActivity
+from utils.benchmark_utils import annotate, create_parser
+from utils.pipeline_utils import load_pipeline  # noqa: E402
+
+
+def set_rand_seeds(seed):
+    random.seed(seed)
+    torch.manual_seed(seed)
+
+
+def main(args):
+    pipeline = load_pipeline(args)
+    set_rand_seeds(args.seed)
+
+    image = pipeline(
+        args.prompt, num_inference_steps=args.num_inference_steps, guidance_scale=0.0
+    ).images[0]
+    image.save(args.output_file)
+
+
+if __name__ == "__main__":
+    parser = create_parser()
+    args = parser.parse_args()
+    # use the cached model to minimize latency
+    args.use_cached_model = True
+    main(args)
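
The new gen_image.py is a thin entry point around the shared pipeline utilities: it seeds the RNGs, runs a single generation, and saves the image. An invocation might look like the following sketch, using only the flags defined in utils/benchmark_utils.py (all of them have defaults, so each flag is optional):

    python gen_image.py --prompt "A cat playing with a ball of yarn" --seed 42 --num_inference_steps 4 --output-file cat.png

Because the script hard-codes args.use_cached_model = True, it attempts to reuse a previously exported model from --cache-dir rather than re-exporting one (per the flag's help text).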

run_benchmark.py

Lines changed: 7 additions & 4 deletions
@@ -12,8 +12,15 @@ def _determine_pipe_call_kwargs(args):
     kwargs = {"max_sequence_length": 512, "guidance_scale": 3.5}
     return kwargs

+def set_rand_seeds(seed):
+    random.seed(seed)
+    torch.manual_seed(seed)
+
+
 def main(args):
+    set_rand_seeds(args.seed)
     pipeline = load_pipeline(args)
+    set_rand_seeds(args.seed)

     # warmup
     for _ in range(3):
@@ -66,10 +73,6 @@ def main(args):


 if __name__ == "__main__":
-    seed = 42
-    random.seed(seed)
-    torch.manual_seed(seed)
-
     parser = create_parser()
     args = parser.parse_args()
     main(args)
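
Seeding moves from module import time into main(), and it is applied both before and after load_pipeline, presumably because loading and compiling the pipeline can itself consume RNG state and shift the generation stream. A minimal standalone illustration of why re-seeding after a noisy phase matters (hypothetical, not part of the commit):

    import random
    import torch

    def set_rand_seeds(seed):
        random.seed(seed)
        torch.manual_seed(seed)

    set_rand_seeds(42)
    _ = torch.randn(8)        # stand-in for RNG draws during pipeline loading
    set_rand_seeds(42)        # re-seed so generation starts from a known state
    a = torch.randn(4)

    set_rand_seeds(42)
    b = torch.randn(4)
    assert torch.equal(a, b)  # identical draws after re-seeding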

utils/benchmark_utils.py

Lines changed: 4 additions & 0 deletions
@@ -15,12 +15,16 @@ def create_parser():
                         help="Text prompt")
     parser.add_argument("--cache-dir", type=str, default=os.path.expandvars("$HOME/.cache/flux-fast"),
                         help="Cache directory for storing exported models")
+    parser.add_argument("--use-cached-model", action="store_true",
+                        help="Attempt to use cached model only (don't re-export)")
     parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda",
                         help="Device to use")
     parser.add_argument("--num_inference_steps", type=int, default=4,
                         help="Number of denoising steps")
     parser.add_argument("--output-file", type=str, default="output.png",
                         help="Output image file path")
+    parser.add_argument("--seed", type=int, default=42,
+                        help="Random seed to use")
     # file path for optional output PyTorch Profiler trace
     parser.add_argument("--trace-file", type=str, default=None,
                         help="Output PyTorch Profiler trace file path")

utils/pipeline_utils.py

Lines changed: 4 additions & 1 deletion
@@ -387,7 +387,10 @@ def optimize(pipeline, args):
     elif args.compile_export_mode == "export_aoti":
         # NB: Using a cached export + AOTI model is not supported yet
         pipeline = use_export_aoti(
-            pipeline, cache_dir=args.cache_dir, serialize=True, is_timestep_distilled=is_timestep_distilled
+            pipeline,
+            cache_dir=args.cache_dir,
+            serialize=(not args.use_cached_model),
+            is_timestep_distilled=is_timestep_distilled
         )
     elif args.compile_export_mode == "disabled":
         pass
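
Here serialize=(not args.use_cached_model) means the AOTI export is serialized to the cache only when --use-cached-model is absent; with the flag set, the previously exported model is reused instead. A hypothetical sketch of that cache-or-export pattern (export_or_load and the artifact name are illustrative, not the repo's actual API):

    import os

    def export_or_load(cache_dir: str, name: str, serialize: bool) -> str:
        """Return a path to a cached artifact, re-creating it only when serialize=True."""
        path = os.path.join(cache_dir, name)
        if serialize or not os.path.exists(path):
            os.makedirs(cache_dir, exist_ok=True)
            with open(path, "wb") as f:
                f.write(b"exported-model-bytes")  # stand-in for the real AOTI export
        return path

    # illustrative call mirroring the default cache dir and a made-up artifact name
    artifact = export_or_load(os.path.expandvars("$HOME/.cache/flux-fast"), "transformer.pt2", serialize=False)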
