forked from vipshop/cache-dit
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path run_flux_nunchaku.py
More file actions
115 lines (95 loc) · 2.9 KB
/
run_flux_nunchaku.py
File metadata and controls
115 lines (95 loc) · 2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import sys
# Make sibling example utilities (utils.py in the parent directory) importable.
sys.path.append("..")
import time
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from nunchaku.models.transformers.transformer_flux_v2 import (
    NunchakuFluxTransformer2DModelV2,
)
from utils import get_args, strify, MemoryTracker
import cache_dit
# Parse the example-suite CLI flags (cache/compile/prompt/thresholds are read
# below); echo them so benchmark runs are reproducible from the log alone.
args = get_args()
print(args)
# Resolve the directory (or Hugging Face hub id) holding the quantized
# Nunchaku FLUX weights.
# NOTE: the env var was historically spelled "NUNCHAKA_FLUX_DIR" — a typo of
# "NUNCHAKU_FLUX_DIR" (the default path itself spells it "nunchaku").  Accept
# the corrected name first and fall back to the old one so existing setups
# that exported the misspelled variable keep working.
nunchaku_flux_dir = os.environ.get(
    "NUNCHAKU_FLUX_DIR",
    os.environ.get(
        "NUNCHAKA_FLUX_DIR",
        "nunchaku-tech/nunchaku-flux.1-dev",
    ),
)
# Load the SVDQuant INT4 rank-32 FLUX.1-dev transformer checkpoint.
transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
    f"{nunchaku_flux_dir}/svdq-int4_r32-flux.1-dev.safetensors",
)
# Build the FLUX.1-dev pipeline around the quantized transformer and move it
# to the GPU.  Base-model path priority: --model_path CLI flag, then the
# FLUX_DIR env var, then the default Hugging Face hub id.
pipe: FluxPipeline = FluxPipeline.from_pretrained(
    (
        args.model_path
        if args.model_path is not None
        else os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev")
    ),
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
# Optionally enable cache-dit's DBCache (dual-block cache) on the pipeline.
# All thresholds/step limits come from the CLI flags parsed above.
if args.cache:
    from cache_dit import (
        ParamsModifier,
        DBCacheConfig,
        TaylorSeerCalibratorConfig,
    )
    cache_dit.enable_cache(
        pipe,
        cache_config=DBCacheConfig(
            Fn_compute_blocks=args.Fn,
            Bn_compute_blocks=args.Bn,
            max_warmup_steps=args.max_warmup_steps,
            max_cached_steps=args.max_cached_steps,
            max_continuous_cached_steps=args.max_continuous_cached_steps,
            residual_diff_threshold=args.rdt,
        ),
        # TaylorSeer calibration is opt-in via --taylorseer; None disables it.
        calibrator_config=(
            TaylorSeerCalibratorConfig(
                taylorseer_order=args.taylorseer_order,
            )
            if args.taylorseer
            else None
        ),
        # Per-block-group overrides: cache-dit appears to apply these in order
        # to the transformer's two block lists (TODO confirm against cache-dit
        # docs).  single_transformer_blocks get a 3x looser residual threshold.
        params_modifiers=[
            ParamsModifier(
                # transformer_blocks
                cache_config=DBCacheConfig().reset(residual_diff_threshold=args.rdt),
            ),
            ParamsModifier(
                # single_transformer_blocks
                cache_config=DBCacheConfig().reset(residual_diff_threshold=args.rdt * 3),
            ),
        ],
    )
# Pick the generation prompt: honor --prompt when supplied, otherwise fall
# back to a fixed default so benchmark runs stay comparable.
_DEFAULT_PROMPT = "A cat holding a sign that says hello world"
prompt = _DEFAULT_PROMPT if args.prompt is None else args.prompt
def run_pipe(pipe: FluxPipeline):
    """Run one text-to-image generation and return the first output image.

    Uses the module-level ``prompt`` with a fixed CPU seed and step count so
    the warmup and timed runs produce comparable work.
    """
    result = pipe(
        prompt,
        num_inference_steps=28,
        generator=torch.Generator("cpu").manual_seed(0),
    )
    return result.images[0]
# Optionally compile the transformer with torch.compile.
if args.compile:
    # Sanity-check the (possibly cache-wrapped) transformer type before
    # compiling — NOTE(review): assumes the nunchaku transformer subclasses
    # FluxTransformer2DModel; confirm if this assert ever fires.
    assert isinstance(pipe.transformer, FluxTransformer2DModel)
    cache_dit.set_compile_configs()
    pipe.transformer = torch.compile(pipe.transformer)
    # Warmup run so torch.compile's one-time tracing/compilation cost is
    # excluded from the timed run below.
    _ = run_pipe(pipe)

from contextlib import nullcontext  # stdlib; local import keeps this fix self-contained

# Track peak memory only when requested.  Using a context manager (with a
# nullcontext stand-in) instead of manual __enter__/__exit__ guarantees the
# tracker is exited even if the timed run raises.
memory_tracker = MemoryTracker() if args.track_memory else None
with memory_tracker if memory_tracker is not None else nullcontext():
    start = time.time()
    image = run_pipe(pipe)
    end = time.time()
if memory_tracker is not None:
    memory_tracker.report()

# Print cache statistics collected by cache-dit during the run.
cache_dit.summary(pipe)
time_cost = end - start
# Output filename encodes the cache/compile configuration (see utils.strify).
save_path = f"flux.nunchaku.int4.{strify(args, pipe)}.png"
print(f"Time cost: {time_cost:.2f}s")
print(f"Saving image to {save_path}")
image.save(save_path)