 from einops import rearrange
 import numpy as np
 from PIL import Image
+from tqdm import tqdm
 
 
 
@@ -94,6 +95,7 @@ def __call__(
         embedded_guidance=6.0,
         cfg_scale=1.0,
         num_inference_steps=30,
+        tea_cache_l1_thresh=None,
         tile_size=(17, 30, 30),
         tile_stride=(12, 20, 20),
         step_processor=None,
@@ -126,6 +128,9 @@ def __call__(
         # Extra input
         extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
 
+        # TeaCache
+        tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
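+        # rel_l1_thresh trades quality for speed: a higher threshold lets TeaCache skip more DiT forward passes.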
+
         # Denoise
         self.load_models_to_device([] if self.vram_management else ["dit"])
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -134,9 +139,9 @@ def __call__(
 
             # Inference
             with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
-                noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
+                noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
                 if cfg_scale != 1.0:
-                    noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
+                    noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
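+                    # Note: only the positive-prompt pass receives tea_cache_kwargs; the CFG negative pass always runs the full model.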
                     noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
                 else:
                     noise_pred = noise_pred_posi
@@ -165,3 +170,94 @@ def __call__(
         frames = self.tensor2video(frames[0])
 
         return frames
+
+
+
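+# TeaCache ("Timestep Embedding Aware Cache"): when the timestep-modulated input
+# changes little between denoising steps, skip the transformer forward pass and
+# reuse the cached output residual from the last fully computed step.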
+class TeaCache:
+    def __init__(self, num_inference_steps, rel_l1_thresh):
+        self.num_inference_steps = num_inference_steps
+        self.step = 0
+        self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = None
+        self.rel_l1_thresh = rel_l1_thresh
+        self.previous_residual = None
+        self.previous_hidden_states = None
+
+    def check(self, dit: HunyuanVideoDiT, img, vec):
+        img_ = img.clone()
+        vec_ = vec.clone()
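+        # Reproduce the first double block's AdaLN modulation as a cheap proxy for how much this step's input moved.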
+        img_mod1_shift, img_mod1_scale, _, _, _, _ = dit.double_blocks[0].component_a.mod(vec_).chunk(6, dim=-1)
+        normed_inp = dit.double_blocks[0].component_a.norm1(img_)
+        modulated_inp = normed_inp * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
+        if self.step == 0 or self.step == self.num_inference_steps - 1:
+            should_calc = True
+            self.accumulated_rel_l1_distance = 0
+        else:
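+            # A fitted polynomial maps the raw relative L1 change of the modulated input to a skip score
+            # (the coefficients appear to be calibrated specifically for HunyuanVideo).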
+            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
+            rescale_func = np.poly1d(coefficients)
+            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                should_calc = False
+            else:
+                should_calc = True
+                self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = modulated_inp
+        self.step += 1
+        if self.step == self.num_inference_steps:
+            self.step = 0
+        if should_calc:
+            self.previous_hidden_states = img.clone()
+        return not should_calc
+
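+    # Called after a full forward: cache the residual between the output and the snapshot taken in check().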
+    def store(self, hidden_states):
+        self.previous_residual = hidden_states - self.previous_hidden_states
+        self.previous_hidden_states = None
+
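+    # Called on a skipped step: approximate the output by replaying the cached residual onto the current input.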
+    def update(self, hidden_states):
+        hidden_states = hidden_states + self.previous_residual
+        return hidden_states
+
+
+
+def lets_dance_hunyuan_video(
+    dit: HunyuanVideoDiT,
+    x: torch.Tensor,
+    t: torch.Tensor,
+    prompt_emb: torch.Tensor = None,
+    text_mask: torch.Tensor = None,
+    pooled_prompt_emb: torch.Tensor = None,
+    freqs_cos: torch.Tensor = None,
+    freqs_sin: torch.Tensor = None,
+    guidance: torch.Tensor = None,
+    tea_cache: TeaCache = None,
+    **kwargs
+):
+    B, C, T, H, W = x.shape
+
+    vec = dit.time_in(t, dtype=torch.float32) + dit.vector_in(pooled_prompt_emb) + dit.guidance_in(guidance * 1000, dtype=torch.float32)
+    img = dit.img_in(x)
+    txt = dit.txt_in(prompt_emb, t, text_mask)
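+    # vec: timestep + pooled-prompt + guidance conditioning; img: patchified video latents; txt: per-token prompt embeddings.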
+
+    # TeaCache
+    if tea_cache is not None:
+        tea_cache_update = tea_cache.check(dit, img, vec)
+    else:
+        tea_cache_update = False
+
+    if tea_cache_update:
+        print("TeaCache skip forward.")
+        img = tea_cache.update(img)
+    else:
+        for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
+            img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
+
+        x = torch.concat([img, txt], dim=1)
+        for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
+            x = block(x, vec, (freqs_cos, freqs_sin))
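+        # The last 256 tokens are the text tokens; keep only the image tokens.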
+        img = x[:, :-256]
+
+        if tea_cache is not None:
+            tea_cache.store(img)
+
+    img = dit.final_layer(img, vec)
+    img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
+    return img
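
Usage sketch (illustrative, not part of the commit: it assumes the surrounding HunyuanVideoPipeline object `pipe` and its `prompt` argument, neither of which appears in this diff; the 0.15 threshold is a hypothetical starting value):

    # TeaCache is off by default (tea_cache_l1_thresh=None), preserving the original behavior.
    video = pipe(
        prompt="a cat running on the grass",  # hypothetical prompt; parameter assumed from the pipeline
        num_inference_steps=30,
        tea_cache_l1_thresh=0.15,  # higher -> more skipped DiT forwards, faster but lossier
    )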