@@ -280,6 +280,8 @@ def __call__(
         eligen_entity_masks=None,
         enable_eligen_on_negative=False,
         enable_eligen_inpaint=False,
+        # TeaCache
+        tea_cache_l1_thresh=None,
         # Tile
         tiled=False,
         tile_size=128,
@@ -314,6 +316,9 @@ def __call__(
         # ControlNets
         controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
 
+        # TeaCache
+        tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
+
         # Denoise
         self.load_models_to_device(['dit', 'controlnet'])
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -323,7 +328,7 @@ def __call__(
             inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
                 dit=self.dit, controlnet=self.controlnet,
                 hidden_states=latents, timestep=timestep,
-                **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi,
+                **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
             )
             noise_pred_posi = self.control_noise_via_local_prompts(
                 prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -362,6 +367,48 @@ def __call__(
         return image
 
 
+class TeaCache:
+    def __init__(self, num_inference_steps, rel_l1_thresh):
+        self.num_inference_steps = num_inference_steps
+        self.step = 0
+        self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = None
+        self.rel_l1_thresh = rel_l1_thresh
+        self.previous_residual = None
+        self.previous_hidden_states = None
+
+    def check(self, dit: FluxDiT, hidden_states, conditioning):
+        inp = hidden_states.clone()
+        temb_ = conditioning.clone()
+        modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_)
+        if self.step == 0 or self.step == self.num_inference_steps - 1:
+            should_calc = True
+            self.accumulated_rel_l1_distance = 0
+        else:
+            coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
+            rescale_func = np.poly1d(coefficients)
+            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                should_calc = False
+            else:
+                should_calc = True
+                self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = modulated_inp
+        self.step += 1
+        if self.step == self.num_inference_steps:
+            self.step = 0
+        if should_calc:
+            self.previous_hidden_states = hidden_states.clone()
+        return not should_calc
+
+    def store(self, hidden_states):
+        self.previous_residual = hidden_states - self.previous_hidden_states
+        self.previous_hidden_states = None
+
+    def update(self, hidden_states):
+        hidden_states = hidden_states + self.previous_residual
+        return hidden_states
+
 
 def lets_dance_flux(
     dit: FluxDiT,
@@ -380,6 +427,7 @@ def lets_dance_flux(
     entity_prompt_emb=None,
     entity_masks=None,
     ipadapter_kwargs_list={},
+    tea_cache: TeaCache = None,
     **kwargs
 ):
     if tiled:
@@ -446,36 +494,48 @@ def flux_forward_fn(hl, hr, wl, wr):
     image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
     attention_mask = None
 
-    # Joint Blocks
-    for block_id, block in enumerate(dit.blocks):
-        hidden_states, prompt_emb = block(
-            hidden_states,
-            prompt_emb,
-            conditioning,
-            image_rotary_emb,
-            attention_mask,
-            ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
-        )
-        # ControlNet
-        if controlnet is not None and controlnet_frames is not None:
-            hidden_states = hidden_states + controlnet_res_stack[block_id]
-
-    # Single Blocks
-    hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
-    num_joint_blocks = len(dit.blocks)
-    for block_id, block in enumerate(dit.single_blocks):
-        hidden_states, prompt_emb = block(
-            hidden_states,
-            prompt_emb,
-            conditioning,
-            image_rotary_emb,
-            attention_mask,
-            ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
-        )
-        # ControlNet
-        if controlnet is not None and controlnet_frames is not None:
-            hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
-    hidden_states = hidden_states[:, prompt_emb.shape[1]:]
+    # TeaCache
+    if tea_cache is not None:
+        tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
+    else:
+        tea_cache_update = False
+
+    if tea_cache_update:
+        hidden_states = tea_cache.update(hidden_states)
+    else:
+        # Joint Blocks
+        for block_id, block in enumerate(dit.blocks):
+            hidden_states, prompt_emb = block(
+                hidden_states,
+                prompt_emb,
+                conditioning,
+                image_rotary_emb,
+                attention_mask,
+                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
+            )
+            # ControlNet
+            if controlnet is not None and controlnet_frames is not None:
+                hidden_states = hidden_states + controlnet_res_stack[block_id]
+
+        # Single Blocks
+        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+        num_joint_blocks = len(dit.blocks)
+        for block_id, block in enumerate(dit.single_blocks):
+            hidden_states, prompt_emb = block(
+                hidden_states,
+                prompt_emb,
+                conditioning,
+                image_rotary_emb,
+                attention_mask,
+                ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
+            )
+            # ControlNet
+            if controlnet is not None and controlnet_frames is not None:
+                hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
+        hidden_states = hidden_states[:, prompt_emb.shape[1]:]
+
+    if tea_cache is not None:
+        tea_cache.store(hidden_states)
 
     hidden_states = dit.final_norm_out(hidden_states, conditioning)
     hidden_states = dit.final_proj_out(hidden_states)
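
How the change behaves at inference time: when tea_cache_l1_thresh is set, each denoising step first calls TeaCache.check, which runs only the first joint block's norm1_a modulation, accumulates a polynomial-rescaled relative L1 distance between the current and previous modulated inputs, and always forces a full pass on the first and last steps. While the accumulated distance stays below the threshold, the joint and single blocks are skipped and update re-adds the residual that store recorded after the last full pass; once the threshold is crossed, the counter resets and the blocks run normally. Below is a minimal usage sketch, assuming the surrounding class is DiffSynth-Studio's FluxImagePipeline and that model loading and the seed argument follow the repo's published examples; the checkpoint paths, prompt, and the 0.3 threshold are illustrative, not values prescribed by this commit. Larger thresholds skip more steps at some cost in fidelity.

# Hedged usage sketch: enable TeaCache by passing tea_cache_l1_thresh.
# The FLUX checkpoint paths are placeholders; adjust to your local layout.
import torch
from diffsynth import ModelManager, FluxImagePipeline

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "models/FLUX/FLUX.1-dev/text_encoder_2",
    "models/FLUX/FLUX.1-dev/ae.safetensors",
    "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
])
pipe = FluxImagePipeline.from_model_manager(model_manager)

image = pipe(
    prompt="a photo of a cat",
    num_inference_steps=30,
    seed=0,
    tea_cache_l1_thresh=0.3,  # None (the default) disables TeaCache entirely
)
image.save("image_teacache.jpg")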