
Commit 1028f53

support teacache (#35)

linchunze and steven-kl authored
Co-authored-by: Steven <[email protected]>
1 parent 8ddc4df commit 1028f53

File tree

5 files changed: +154 -7 lines


README.md

Lines changed: 13 additions & 4 deletions
@@ -191,13 +191,16 @@ python3 generate_video_df.py \
 --overlap_history 17 \
 --prompt "A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed." \
 --addnoise_condition 20 \
---offload
+--offload \
+--teacache \
+--use_ret_steps \
+--teacache_thresh 0.3
 ```
 
 asynchronous generation for 30s video
 ```shell
 model_id=Skywork/SkyReels-V2-DF-14B-540P
-# synchronous inference
+# asynchronous inference
 python3 generate_video_df.py \
 --model_id ${model_id} \
 --resolution 540P \
@@ -232,7 +235,10 @@ python3 generate_video.py \
 --shift 8.0 \
 --fps 24 \
 --prompt "A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface." \
---offload
+--offload \
+--teacache \
+--use_ret_steps \
+--teacache_thresh 0.3
 ```
 > **Note**:
 > - When using an **image-to-video (I2V)** model, you must provide an input image using the `--image ${image_path}` parameter. `--guidance_scale 5.0` and `--shift 3.0` are recommended for I2V models.
@@ -269,7 +275,10 @@ Below are the key parameters you can customize for video generation:
 | --offload | True | Offloads model components to CPU to reduce VRAM usage (recommended) |
 | --use_usp | True | Enables multi-GPU acceleration with xDiT USP |
 | --outdir | ./video_out | Directory where generated videos will be saved |
-| --prompt_enhancer | True | expand the prompt into a more detailed description |
+| --prompt_enhancer | True | Expands the prompt into a more detailed description |
+| --teacache | False | Enables TeaCache for faster inference |
+| --teacache_thresh | 0.2 | Higher values give more speedup at the cost of quality |
+| --use_ret_steps | False | Uses retention steps for TeaCache |
 
 **Diffusion Forcing Additional Parameters**
 | Parameter | Recommended Value | Description |

generate_video.py

Lines changed: 15 additions & 0 deletions
@@ -48,6 +48,16 @@
     default="A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface.",
 )
 parser.add_argument("--prompt_enhancer", action="store_true")
+parser.add_argument("--teacache", action="store_true")
+parser.add_argument(
+    "--teacache_thresh",
+    type=float,
+    default=0.2,
+    help="Higher values give more speedup at the cost of quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup")
+parser.add_argument(
+    "--use_ret_steps",
+    action="store_true",
+    help="Using retention steps results in faster generation and better quality.")
 args = parser.parse_args()
 
 args.model_id = download_model(args.model_id)
@@ -116,6 +126,11 @@
         height, width = width, height
         args.image = resizecrop(args.image, height, width)
 
+    if args.teacache:
+        pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=args.inference_steps,
+                                             teacache_thresh=args.teacache_thresh, use_ret_steps=args.use_ret_steps,
+                                             ckpt_dir=args.model_id)
+
     prompt_input = args.prompt
     if args.prompt_enhancer and image is not None:
         prompt_input = prompt_enhancer(prompt_input)
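
One detail worth noting in the hunk above: `initialize_teacache` receives `num_steps=args.inference_steps`, yet the transformer (see transformer.py below) doubles that value for its retention and cutoff counters. That is consistent with classifier-free guidance running the model twice per scheduler step, once conditioned and once unconditioned. A minimal, hypothetical illustration of that call pattern (the stub and variable names here are invented for this sketch, not repo code):

```python
import numpy as np

def transformer_forward(latents, t, context):
    # Hypothetical stub standing in for the DiT forward pass.
    return 0.9 * latents + 0.01 * t * context

latents = np.zeros((2, 4))
prompt_ctx, null_ctx = np.ones((2, 4)), np.zeros((2, 4))
guidance_scale = 6.0

for t in np.linspace(1.0, 0.0, num=30):                  # 30 scheduler steps ...
    cond = transformer_forward(latents, t, prompt_ctx)   # ... TeaCache cnt even
    uncond = transformer_forward(latents, t, null_ctx)   # ... TeaCache cnt odd
    noise_pred = uncond + guidance_scale * (cond - uncond)
    # a real scheduler would now update `latents` from `noise_pred`
```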

generate_video_df.py

Lines changed: 20 additions & 0 deletions
@@ -39,6 +39,16 @@
     default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.",
 )
 parser.add_argument("--prompt_enhancer", action="store_true")
+parser.add_argument("--teacache", action="store_true")
+parser.add_argument(
+    "--teacache_thresh",
+    type=float,
+    default=0.2,
+    help="Higher values give more speedup at the cost of quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup")
+parser.add_argument(
+    "--use_ret_steps",
+    action="store_true",
+    help="Using retention steps results in faster generation and better quality.")
 args = parser.parse_args()
 
 args.model_id = download_model(args.model_id)
@@ -117,6 +127,16 @@
 
     if args.causal_attention:
         pipe.transformer.set_ar_attention(args.causal_block_size)
+
+    if args.teacache:
+        if args.ar_step > 0:
+            num_steps = args.inference_steps + (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1) * args.ar_step
+            print('num_steps:', num_steps)
+        else:
+            num_steps = args.inference_steps
+        pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=num_steps,
+                                             teacache_thresh=args.teacache_thresh, use_ret_steps=args.use_ret_steps,
+                                             ckpt_dir=args.model_id)
 
     print(f"prompt:{prompt_input}")
     print(f"guidance_scale:{guidance_scale}")

skyreels_v2_infer/modules/transformer.py

Lines changed: 103 additions & 3 deletions
@@ -1,6 +1,6 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import math
-
+import numpy as np
 import torch
 import torch.amp as amp
 import torch.nn as nn
@@ -484,6 +484,7 @@ def __init__(
         self.num_frame_per_block = 1
         self.flag_causal_attention = False
         self.block_mask = None
+        self.enable_teacache = False
 
         # embeddings
         self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
@@ -574,6 +575,50 @@ def attention_mask(b, h, q_idx, kv_idx):
 
         return block_mask
 
+    def initialize_teacache(self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir=''):
+        self.enable_teacache = enable_teacache
+        print('using teacache')
+        self.cnt = 0
+        self.num_steps = num_steps
+        self.teacache_thresh = teacache_thresh
+        self.accumulated_rel_l1_distance_even = 0
+        self.accumulated_rel_l1_distance_odd = 0
+        self.previous_e0_even = None
+        self.previous_e0_odd = None
+        self.previous_residual_even = None
+        self.previous_residual_odd = None
+        self.use_ret_steps = use_ret_steps
+        # Rescaling-polynomial coefficients are fitted per model variant and resolution.
+        if "I2V" in ckpt_dir:
+            if use_ret_steps:
+                if '540P' in ckpt_dir:
+                    self.coefficients = [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
+                if '720P' in ckpt_dir:
+                    self.coefficients = [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
+                self.ret_steps = 5 * 2
+                self.cutoff_steps = num_steps * 2
+            else:
+                if '540P' in ckpt_dir:
+                    self.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
+                if '720P' in ckpt_dir:
+                    self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+                self.ret_steps = 1 * 2
+                self.cutoff_steps = num_steps * 2 - 2
+        else:
+            if use_ret_steps:
+                if '1.3B' in ckpt_dir:
+                    self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
+                if '14B' in ckpt_dir:
+                    self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
+                self.ret_steps = 5 * 2
+                self.cutoff_steps = num_steps * 2
+            else:
+                if '1.3B' in ckpt_dir:
+                    self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
+                if '14B' in ckpt_dir:
+                    self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
+                self.ret_steps = 1 * 2
+                self.cutoff_steps = num_steps * 2 - 2
+
     def forward(self, x, t, context, clip_fea=None, y=None, fps=None):
         r"""
         Forward pass through the diffusion model
@@ -664,13 +709,68 @@ def forward(self, x, t, context, clip_fea=None, y=None, fps=None):
 
         # arguments
         kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
-        for block in self.blocks:
-            x = block(x, **kwargs)
+        if self.enable_teacache:
+            modulated_inp = e0 if self.use_ret_steps else e
+            # TeaCache: decide whether this pass can reuse the cached residual.
+            if self.cnt % 2 == 0:  # even -> conditional pass
+                self.is_even = True
+                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                    should_calc_even = True
+                    self.accumulated_rel_l1_distance_even = 0
+                else:
+                    rescale_func = np.poly1d(self.coefficients)
+                    self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
+                    if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
+                        should_calc_even = False
+                    else:
+                        should_calc_even = True
+                        self.accumulated_rel_l1_distance_even = 0
+                self.previous_e0_even = modulated_inp.clone()
+
+            else:  # odd -> unconditional pass
+                self.is_even = False
+                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                    should_calc_odd = True
+                    self.accumulated_rel_l1_distance_odd = 0
+                else:
+                    rescale_func = np.poly1d(self.coefficients)
+                    self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
+                    if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
+                        should_calc_odd = False
+                    else:
+                        should_calc_odd = True
+                        self.accumulated_rel_l1_distance_odd = 0
+                self.previous_e0_odd = modulated_inp.clone()
+
+        if self.enable_teacache:
+            if self.is_even:
+                if not should_calc_even:
+                    x += self.previous_residual_even
+                else:
+                    ori_x = x.clone()
+                    for block in self.blocks:
+                        x = block(x, **kwargs)
+                    self.previous_residual_even = x - ori_x
+            else:
+                if not should_calc_odd:
+                    x += self.previous_residual_odd
+                else:
+                    ori_x = x.clone()
+                    for block in self.blocks:
+                        x = block(x, **kwargs)
+                    self.previous_residual_odd = x - ori_x
+        else:
+            for block in self.blocks:
+                x = block(x, **kwargs)
 
         x = self.head(x, e)
 
         # unpatchify
         x = self.unpatchify(x, grid_sizes)
+        self.cnt += 1
+        if self.cnt >= self.num_steps:
+            self.cnt = 0
         return x.float()
 
     def unpatchify(self, x, grid_sizes):
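
Stripped of model details, the logic added above is: measure how much the (modulated) timestep embedding moved since the last pass of the same parity, rescale that relative L1 distance with the fitted polynomial, accumulate it, and skip the transformer blocks (reapplying the cached residual) while the accumulation stays under `teacache_thresh`. A self-contained sketch of one such stream, with toy coefficients and a stub in place of the transformer blocks (all names and values here are illustrative, not the fitted values above):

```python
import numpy as np
import torch

# Toy rescaling polynomial -- the real coefficients are fitted per model/resolution.
rescale_func = np.poly1d([1.0, 0.0])  # identity: rescaled distance == raw distance
teacache_thresh = 0.2

previous_inp = None       # modulated timestep embedding from the previous pass
previous_residual = None  # cached (blocks(x) - x) from the last full computation
accumulated = 0.0

def run_blocks(x):
    # Stub standing in for `for block in self.blocks: x = block(x, **kwargs)`.
    return x + torch.tanh(x)

def teacache_pass(x, modulated_inp):
    global previous_inp, previous_residual, accumulated
    should_calc = True
    if previous_inp is not None and previous_residual is not None:
        rel_l1 = ((modulated_inp - previous_inp).abs().mean()
                  / previous_inp.abs().mean()).item()
        accumulated += rescale_func(rel_l1)
        if accumulated < teacache_thresh:
            should_calc = False  # embedding barely moved: reuse cached residual
    previous_inp = modulated_inp.clone()
    if should_calc:
        accumulated = 0.0
        out = run_blocks(x)
        previous_residual = out - x  # cache residual for cheap later passes
        return out
    return x + previous_residual     # skip all blocks this pass

x = torch.randn(2, 8)
for _ in range(6):
    x = teacache_pass(x, 1.0 + 0.01 * torch.randn(2, 8))
```

The committed implementation keeps two such streams (even/conditional and odd/unconditional) because the embeddings of the two classifier-free-guidance passes drift at different rates.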

skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py

Lines changed: 3 additions & 0 deletions
@@ -328,6 +328,9 @@ def __call__(
                 finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
                 left_frame_num = latent_length - finished_frame_num
                 base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
+                if ar_step > 0 and self.transformer.enable_teacache:
+                    num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step
+                    self.transformer.num_steps = num_steps
             else:  # i == 0
                 base_num_frames_iter = base_num_frames
             latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
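
Because the final sliding window of a long video can be shorter than `base_num_frames`, the hunk above recomputes the TeaCache step count for each window. A small sketch of that bookkeeping with assumed window sizes (illustrative values, not derived from this diff):

```python
# Per-window TeaCache step count in the long-video loop (illustrative values).
num_inference_steps = 30
overlap_history_frames = 5   # latent frames reused between windows
causal_block_size = 5
ar_step = 5

for base_num_frames_iter in [25, 25, 15]:  # the last window is shorter
    num_steps = num_inference_steps + (
        (base_num_frames_iter - overlap_history_frames) // causal_block_size - 1
    ) * ar_step
    print(base_num_frames_iter, "->", num_steps)  # 25 -> 45, 25 -> 45, 15 -> 35
```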
