Commit 9f8112e

support teacache-hunyuanvideo

1 parent: d9fad82

3 files changed: +163 -7 lines

diffsynth/pipelines/hunyuan_video.py

Lines changed: 98 additions & 2 deletions
@@ -8,6 +8,7 @@
 from einops import rearrange
 import numpy as np
 from PIL import Image
+from tqdm import tqdm



@@ -94,6 +95,7 @@ def __call__(
         embedded_guidance=6.0,
         cfg_scale=1.0,
         num_inference_steps=30,
+        tea_cache_l1_thresh=None,
         tile_size=(17, 30, 30),
         tile_stride=(12, 20, 20),
         step_processor=None,
@@ -126,6 +128,9 @@ def __call__(
         # Extra input
         extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)

+        # TeaCache
+        tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
+
         # Denoise
         self.load_models_to_device([] if self.vram_management else ["dit"])
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -134,9 +139,9 @@ def __call__(

             # Inference
             with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
-                noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
+                noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
                 if cfg_scale != 1.0:
-                    noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
+                    noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
                     noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
                 else:
                     noise_pred = noise_pred_posi
@@ -165,3 +170,94 @@ def __call__(
         frames = self.tensor2video(frames[0])

         return frames
+
+
+
+class TeaCache:
+    def __init__(self, num_inference_steps, rel_l1_thresh):
+        self.num_inference_steps = num_inference_steps
+        self.step = 0
+        self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = None
+        self.rel_l1_thresh = rel_l1_thresh
+        self.previous_residual = None
+        self.previous_hidden_states = None
+
+    def check(self, dit: HunyuanVideoDiT, img, vec):
+        img_ = img.clone()
+        vec_ = vec.clone()
+        img_mod1_shift, img_mod1_scale, _, _, _, _ = dit.double_blocks[0].component_a.mod(vec_).chunk(6, dim=-1)
+        normed_inp = dit.double_blocks[0].component_a.norm1(img_)
+        modulated_inp = normed_inp * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
+        if self.step == 0 or self.step == self.num_inference_steps - 1:
+            should_calc = True
+            self.accumulated_rel_l1_distance = 0
+        else:
+            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
+            rescale_func = np.poly1d(coefficients)
+            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                should_calc = False
+            else:
+                should_calc = True
+                self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = modulated_inp
+        self.step += 1
+        if self.step == self.num_inference_steps:
+            self.step = 0
+        if should_calc:
+            self.previous_hidden_states = img.clone()
+        return not should_calc
+
+    def store(self, hidden_states):
+        self.previous_residual = hidden_states - self.previous_hidden_states
+        self.previous_hidden_states = None
+
+    def update(self, hidden_states):
+        hidden_states = hidden_states + self.previous_residual
+        return hidden_states
+
+
+
+def lets_dance_hunyuan_video(
+    dit: HunyuanVideoDiT,
+    x: torch.Tensor,
+    t: torch.Tensor,
+    prompt_emb: torch.Tensor = None,
+    text_mask: torch.Tensor = None,
+    pooled_prompt_emb: torch.Tensor = None,
+    freqs_cos: torch.Tensor = None,
+    freqs_sin: torch.Tensor = None,
+    guidance: torch.Tensor = None,
+    tea_cache: TeaCache = None,
+    **kwargs
+):
+    B, C, T, H, W = x.shape
+
+    vec = dit.time_in(t, dtype=torch.float32) + dit.vector_in(pooled_prompt_emb) + dit.guidance_in(guidance * 1000, dtype=torch.float32)
+    img = dit.img_in(x)
+    txt = dit.txt_in(prompt_emb, t, text_mask)
+
+    # TeaCache
+    if tea_cache is not None:
+        tea_cache_update = tea_cache.check(dit, img, vec)
+    else:
+        tea_cache_update = False
+
+    if tea_cache_update:
+        print("TeaCache skip forward.")
+        img = tea_cache.update(img)
+    else:
+        for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
+            img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
+
+        x = torch.concat([img, txt], dim=1)
+        for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
+            x = block(x, vec, (freqs_cos, freqs_sin))
+        img = x[:, :-256]

+    if tea_cache is not None:
+        tea_cache.store(img)
+    img = dit.final_layer(img, vec)
+    img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
+    return img
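
The new TeaCache class implements a residual-caching pattern: check() uses the first double-stream block's modulation of the input as a cheap proxy for how much the DiT output is about to change, accumulates a polynomial-rescaled relative L1 distance, and only allows the transformer blocks to be skipped while that accumulator stays below rel_l1_thresh; store() keeps the residual from the last full forward pass, and update() replays it on skipped steps. Below is a minimal, self-contained sketch of the same pattern with a toy model and an unscaled change metric; it is an illustration only, not the HunyuanVideo DiT or the fitted coefficients used above.

```python
import torch


class ToyResidualCache:
    """Toy illustration of the TeaCache pattern: skip an expensive call when
    the input has barely changed, and reuse the residual from the last full pass."""

    def __init__(self, thresh):
        self.thresh = thresh
        self.accumulated = 0.0
        self.prev_input = None
        self.prev_residual = None

    def should_skip(self, x):
        # Accumulate a cheap proxy for "how much would the output change".
        if self.prev_input is not None:
            self.accumulated += ((x - self.prev_input).abs().mean() / self.prev_input.abs().mean()).item()
        self.prev_input = x.clone()
        if self.prev_residual is not None and self.accumulated < self.thresh:
            return True   # input barely moved: reuse the cached residual
        self.accumulated = 0.0
        return False      # run the expensive model and refresh the cache


def expensive_model(x):
    return x + torch.sin(x)  # stand-in for the double/single stream blocks


cache = ToyResidualCache(thresh=0.05)
x = torch.randn(4)
for step in range(10):
    if cache.should_skip(x):
        y = x + cache.prev_residual    # cheap path: replay the cached residual
    else:
        y = expensive_model(x)         # full path
        cache.prev_residual = y - x    # store the new residual
    x = y * 0.99  # pretend scheduler update; inputs drift slowly between steps
```

The fitted coefficients in check() exist because the raw relative L1 distance of the modulated input is only a proxy; the polynomial maps it to an estimate of how much the final output would actually change for this particular model, which the toy sketch above does not attempt.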

examples/TeaCache/README.md

Lines changed: 23 additions & 5 deletions
@@ -4,13 +4,31 @@ TeaCache ([Timestep Embedding Aware Cache](https://github.com/ali-vilab/TeaCache

 ## Examples

-We provide examples on FLUX.1-dev. See [./flux_teacache.py](./flux_teacache.py).
+### FLUX
+
+Script: [./flux_teacache.py](./flux_teacache.py)
+
+Model: FLUX.1-dev

 Steps: 50

 GPU: A100

-|TeaCache is disabled|tea_cache_l1_thresh=0.2|tea_cache_l1_thresh=0.4|tea_cache_l1_thresh=0.6|tea_cache_l1_thresh=0.8|
-|-|-|-|-|-|
-|23s|13s|9s|6s|5s|
-|![image_None](https://github.com/user-attachments/assets/2bf5187a-9693-44d3-9ebb-6c33cd15443f)|![image_0 2](https://github.com/user-attachments/assets/5532ba94-c7e2-446e-a9ba-1c68c0f63350)|![image_0 4](https://github.com/user-attachments/assets/4c57c50d-87cd-493b-8603-1da57ec3b70d)|![image_0 6](https://github.com/user-attachments/assets/1d95a3a9-71f9-4b1a-ad5f-a5ea8d52eca7)|![image_0 8](https://github.com/user-attachments/assets/d8cfdd74-8b45-4048-b1b7-ce480aa23fa1)
+|TeaCache is disabled|tea_cache_l1_thresh=0.2|tea_cache_l1_thresh=0.8|
+|-|-|-|
+|23s|13s|5s|
+|![image_None](https://github.com/user-attachments/assets/2bf5187a-9693-44d3-9ebb-6c33cd15443f)|![image_0 2](https://github.com/user-attachments/assets/5532ba94-c7e2-446e-a9ba-1c68c0f63350)|![image_0 8](https://github.com/user-attachments/assets/d8cfdd74-8b45-4048-b1b7-ce480aa23fa1)
+
+### Hunyuan Video
+
+Script: [./hunyuanvideo_teacache.py](./hunyuanvideo_teacache.py)
+
+Model: Hunyuan Video
+
+Steps: 30
+
+GPU: A100
+
+The following video was generated using TeaCache. It is nearly identical to [the video without TeaCache enabled](https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9), but with double the speed.
+
+https://github.com/user-attachments/assets/cd9801c5-88ce-4efc-b055-2c7737166f34
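
In both examples the only TeaCache-specific knob is the tea_cache_l1_thresh argument of the pipeline call; leaving it at its default of None disables caching. A rough sketch of how the FLUX columns above could be swept (pipeline setup elided; see the linked flux_teacache.py for the real model loading, and note that the prompt, seed, and output names here are placeholders):

```python
# Assumes `pipe` is a FluxImagePipeline already loaded with FLUX.1-dev weights,
# as in ./flux_teacache.py. Prompt, seed, and file names are placeholders.
for thresh in [None, 0.2, 0.8]:
    image = pipe(
        prompt="a placeholder prompt",
        num_inference_steps=50,       # "Steps: 50" above
        seed=0,
        tea_cache_l1_thresh=thresh,   # None = TeaCache disabled
    )
    image.save(f"image_{thresh}.jpg")
```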
examples/TeaCache/hunyuanvideo_teacache.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+import torch
+torch.cuda.set_per_process_memory_fraction(1.0, 0)
+from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
+
+
+download_models(["HunyuanVideo"])
+model_manager = ModelManager()
+
+# The DiT model is loaded in bfloat16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
+    ],
+    torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization.
+    device="cpu"
+)
+
+# The other modules are loaded in float16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideo/text_encoder/model.safetensors",
+        "models/HunyuanVideo/text_encoder_2",
+        "models/HunyuanVideo/vae/pytorch_model.pt",
+    ],
+    torch_dtype=torch.float16,
+    device="cpu"
+)
+
+# We support LoRA inference. You can use the following code to load your LoRA model.
+# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)
+
+# The computation device is "cuda".
+pipe = HunyuanVideoPipeline.from_model_manager(
+    model_manager,
+    torch_dtype=torch.bfloat16,
+    device="cuda"
+)
+
+# Enjoy!
+prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
+video = pipe(prompt, seed=0, tea_cache_l1_thresh=0.15)
+save_video(video, "video_girl.mp4", fps=30, quality=6)
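
To reproduce the side-by-side comparison from the README with this script, the baseline simply omits tea_cache_l1_thresh: it defaults to None, and the pipeline change above then never constructs a TeaCache, so the full DiT runs at every step. A short addition (the output filename is just an example):

```python
# Baseline without TeaCache: tea_cache_l1_thresh defaults to None, so no
# TeaCache instance is created and every step runs the full transformer.
video_baseline = pipe(prompt, seed=0)
save_video(video_baseline, "video_girl_no_teacache.mp4", fps=30, quality=6)
```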
