|
@@ -1,6 +1,6 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import math
-
+import numpy as np
 import torch
 import torch.amp as amp
 import torch.nn as nn
@@ -484,6 +484,7 @@ def __init__( |
         self.num_frame_per_block = 1
         self.flag_causal_attention = False
         self.block_mask = None
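+        # TeaCache step-skipping stays off until initialize_teacache() is called.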
+        self.enable_teacache = False

         # embeddings
         self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
@@ -574,6 +575,50 @@ def attention_mask(b, h, q_idx, kv_idx): |

         return block_mask

+    def initialize_teacache(self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir=''):
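+        """Configure TeaCache-style skipping of the transformer blocks.
+
+        Args:
+            enable_teacache: turn caching on or off.
+            num_steps: used to derive the caching cutoff and to reset the internal call counter.
+            teacache_thresh: skip threshold; larger values skip more block evaluations
+                (faster, potentially lower fidelity).
+            use_ret_steps: select the alternative coefficient set and warm-up/cutoff schedule.
+            ckpt_dir: checkpoint path, inspected only to choose the fitted rescaling coefficients.
+        """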
+        self.enable_teacache = enable_teacache
+        if enable_teacache:
+            print('using teacache')
+        self.cnt = 0
+        self.num_steps = num_steps
+        self.teacache_thresh = teacache_thresh
+        self.accumulated_rel_l1_distance_even = 0
+        self.accumulated_rel_l1_distance_odd = 0
+        self.previous_e0_even = None
+        self.previous_e0_odd = None
+        self.previous_residual_even = None
+        self.previous_residual_odd = None
+        self.use_ref_steps = use_ret_steps
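+        # The coefficients below define the polynomial (np.poly1d) that rescales the relative
+        # L1 change of the modulated input; they are fitted per checkpoint and resolution,
+        # so the matching set is chosen from the checkpoint path.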
+        if "I2V" in ckpt_dir:
+            if use_ret_steps:
+                if '540P' in ckpt_dir:
+                    self.coefficients = [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
+                if '720P' in ckpt_dir:
+                    self.coefficients = [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
+                self.ret_steps = 5 * 2
+                self.cutoff_steps = num_steps * 2
+            else:
+                if '540P' in ckpt_dir:
+                    self.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
+                if '720P' in ckpt_dir:
+                    self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+                self.ret_steps = 1 * 2
+                self.cutoff_steps = num_steps * 2 - 2
+        else:
+            if use_ret_steps:
+                if '1.3B' in ckpt_dir:
+                    self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
+                if '14B' in ckpt_dir:
+                    self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
+                self.ret_steps = 5 * 2
+                self.cutoff_steps = num_steps * 2
+            else:
+                if '1.3B' in ckpt_dir:
+                    self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
+                if '14B' in ckpt_dir:
+                    self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
+                self.ret_steps = 1 * 2
+                self.cutoff_steps = num_steps * 2 - 2
+
     def forward(self, x, t, context, clip_fea=None, y=None, fps=None):
         r"""
         Forward pass through the diffusion model
@@ -664,13 +709,68 @@ def forward(self, x, t, context, clip_fea=None, y=None, fps=None): |

         # arguments
         kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
-        for block in self.blocks:
-            x = block(x, **kwargs)
+        if self.enable_teacache:
+            modulated_inp = e0 if self.use_ref_steps else e
+            # TeaCache: track how much the modulated input changed since the previous call
+            # of the same kind to decide whether the transformer blocks can be skipped.
+            if self.cnt % 2 == 0:  # even call -> conditional branch
+                self.is_even = True
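+                # Always run the blocks during the warm-up steps and after the cutoff;
+                # in between, accumulate the rescaled relative L1 change and skip while
+                # it stays below the threshold.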
+                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                    should_calc_even = True
+                    self.accumulated_rel_l1_distance_even = 0
+                else:
+                    rescale_func = np.poly1d(self.coefficients)
+                    self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
+                    if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
+                        should_calc_even = False
+                    else:
+                        should_calc_even = True
+                        self.accumulated_rel_l1_distance_even = 0
+                self.previous_e0_even = modulated_inp.clone()
+
+            else:  # odd call -> unconditional branch
+                self.is_even = False
+                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                    should_calc_odd = True
+                    self.accumulated_rel_l1_distance_odd = 0
+                else:
+                    rescale_func = np.poly1d(self.coefficients)
+                    self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
+                    if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
+                        should_calc_odd = False
+                    else:
+                        should_calc_odd = True
+                        self.accumulated_rel_l1_distance_odd = 0
+                self.previous_e0_odd = modulated_inp.clone()
+
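+        # Either reuse the cached residual from the last full pass through the blocks,
+        # or run every block and refresh the cached residual.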
+        if self.enable_teacache:
+            if self.is_even:
+                if not should_calc_even:
+                    x += self.previous_residual_even
+                else:
+                    ori_x = x.clone()
+                    for block in self.blocks:
+                        x = block(x, **kwargs)
+                    self.previous_residual_even = x - ori_x
+            else:
+                if not should_calc_odd:
+                    x += self.previous_residual_odd
+                else:
+                    ori_x = x.clone()
+                    for block in self.blocks:
+                        x = block(x, **kwargs)
+                    self.previous_residual_odd = x - ori_x
+
+        else:
+            for block in self.blocks:
+                x = block(x, **kwargs)

         x = self.head(x, e)

         # unpatchify
         x = self.unpatchify(x, grid_sizes)
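+        # Advance the TeaCache call counter and wrap around after a full sampling run.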
+        if self.enable_teacache:
+            self.cnt += 1
+            if self.cnt >= self.num_steps:
+                self.cnt = 0
         return x.float()

     def unpatchify(self, x, grid_sizes):
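Note (not part of this commit): a minimal sketch of how the new API might be called; the model handle, step count, and checkpoint path below are illustrative assumptions.

    # `model` is assumed to be an already-constructed instance of this transformer class.
    model.initialize_teacache(
        enable_teacache=True,
        num_steps=30,                            # illustrative value
        teacache_thresh=0.15,                    # larger -> more skipped block passes
        use_ret_steps=True,
        ckpt_dir='path/to/Wan2.1-I2V-14B-720P',  # inspected only to pick the fitted coefficients
    )
    # Later forward() calls reuse the cached block residual whenever the accumulated,
    # rescaled relative L1 change of the modulated input stays below teacache_thresh.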
|