forked from vipshop/cache-dit
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathrun_flux_ao.py
More file actions
84 lines (62 loc) · 1.85 KB
/
run_flux_ao.py
File metadata and controls
84 lines (62 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import contextlib
import os
import sys
sys.path.append("..")
import time
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from utils import get_args, strify, cachify, MemoryTracker
import cache_dit
# Parse the command-line options and echo them so the run configuration
# shows up in the log.
args = get_args()
print(args)

# Pick where to load the model from: an explicit args.model_path takes
# priority; otherwise fall back to the FLUX_DIR environment variable, and
# finally to the "black-forest-labs/FLUX.1-dev" checkpoint id.
if args.model_path is not None:
    model_source = args.model_path
else:
    model_source = os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev")

# Load the pipeline with bfloat16 weights and move it onto the CUDA device.
pipe: FluxPipeline = FluxPipeline.from_pretrained(
    model_source,
    torch_dtype=torch.bfloat16,
).to("cuda")
# Optional cache-dit acceleration: utils.cachify receives the parsed args and
# the pipeline (presumably it enables caching per args — confirm in utils).
if args.cache:
    cachify(args, pipe)

# Optional quantization: swap the pipeline's transformer for the model
# returned by cache_dit.quantize, using the quant type chosen on the CLI.
if args.quantize:
    # Narrows the type and guards that we are quantizing a Flux transformer.
    assert isinstance(pipe.transformer, FluxTransformer2DModel)
    quantized_transformer = cache_dit.quantize(
        pipe.transformer,
        quant_type=args.quantize_type,
    )
    pipe.transformer = quantized_transformer
# Text prompt for generation: use args.prompt when given, else a fixed default.
DEFAULT_PROMPT = "A cat holding a sign that says hello world"
prompt = args.prompt if args.prompt is not None else DEFAULT_PROMPT
def run_pipe(warmup: bool = False):
    """Generate one image with the module-level ``pipe``, ``prompt`` and ``args``.

    Args:
        warmup: If True, run a short 5-step pass (used as an untimed warm-up
            before the measured run). Otherwise use ``args.steps``, or 28
            when ``args.steps`` is unset.

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    if warmup:
        steps = 5
    elif args.steps is not None:
        steps = args.steps
    else:
        steps = 28

    # Resolution defaults to 1024x1024 unless overridden on the command line.
    width = args.width if args.width is not None else 1024
    height = args.height if args.height is not None else 1024

    # A fresh CPU generator seeded with 0 on every call keeps outputs reproducible.
    seeded_generator = torch.Generator("cpu").manual_seed(0)

    output = pipe(
        prompt,
        width=width,
        height=height,
        num_inference_steps=steps,
        generator=seeded_generator,
    )
    return output.images[0]
# Optionally compile the transformer's repeated blocks
# (diffusers' compile_repeated_blocks; presumably torch.compile under the hood).
if args.compile:
    # Narrows the type so compile_repeated_blocks is called on a Flux transformer.
    assert isinstance(pipe.transformer, FluxTransformer2DModel)
    pipe.transformer.compile_repeated_blocks()

# Untimed short warm-up run so one-time costs (e.g. compilation, if enabled)
# are not counted in the timed run that follows; the image is discarded.
_ = run_pipe(warmup=True)
# Timed generation, optionally wrapped in utils.MemoryTracker.
# A `with` block guarantees the tracker's __exit__ runs even if generation
# raises (the previous manual __enter__/__exit__ calls skipped __exit__ on
# error); when tracking is off, nullcontext() turns the `with` into a no-op.
memory_tracker = MemoryTracker() if args.track_memory else None
tracking_ctx = (
    memory_tracker if memory_tracker is not None else contextlib.nullcontext()
)
with tracking_ctx:
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-run.
    start = time.perf_counter()
    image = run_pipe()
    end = time.perf_counter()
if memory_tracker is not None:
    memory_tracker.report()

# Print cache-dit's summary for the pipeline.
cache_dit.summary(pipe)

time_cost = end - start
# Output filename embeds a run descriptor produced by utils.strify.
save_path = f"flux.ao.{strify(args, pipe)}.png"
print(f"Time cost: {time_cost:.2f}s")
print(f"Saving image to {save_path}")
image.save(save_path)