Skip to content

Commit c0889c2

Browse files
committed
support teacache
1 parent 913591c commit c0889c2

File tree

3 files changed

+122
-31
lines changed

3 files changed

+122
-31
lines changed

diffsynth/pipelines/flux_image.py

Lines changed: 91 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ def __call__(
280280
eligen_entity_masks=None,
281281
enable_eligen_on_negative=False,
282282
enable_eligen_inpaint=False,
283+
# TeaCache
284+
tea_cache_l1_thresh=None,
283285
# Tile
284286
tiled=False,
285287
tile_size=128,
@@ -314,6 +316,9 @@ def __call__(
314316
# ControlNets
315317
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
316318

319+
# TeaCache
320+
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
321+
317322
# Denoise
318323
self.load_models_to_device(['dit', 'controlnet'])
319324
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -323,7 +328,7 @@ def __call__(
323328
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
324329
dit=self.dit, controlnet=self.controlnet,
325330
hidden_states=latents, timestep=timestep,
326-
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi,
331+
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
327332
)
328333
noise_pred_posi = self.control_noise_via_local_prompts(
329334
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -362,6 +367,48 @@ def __call__(
362367
return image
363368

364369

370+
class TeaCache:
371+
def __init__(self, num_inference_steps, rel_l1_thresh):
372+
self.num_inference_steps = num_inference_steps
373+
self.step = 0
374+
self.accumulated_rel_l1_distance = 0
375+
self.previous_modulated_input = None
376+
self.rel_l1_thresh = rel_l1_thresh
377+
self.previous_residual = None
378+
self.previous_hidden_states = None
379+
380+
def check(self, dit: FluxDiT, hidden_states, conditioning):
381+
inp = hidden_states.clone()
382+
temb_ = conditioning.clone()
383+
modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_)
384+
if self.step == 0 or self.step == self.num_inference_steps - 1:
385+
should_calc = True
386+
self.accumulated_rel_l1_distance = 0
387+
else:
388+
coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
389+
rescale_func = np.poly1d(coefficients)
390+
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
391+
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
392+
should_calc = False
393+
else:
394+
should_calc = True
395+
self.accumulated_rel_l1_distance = 0
396+
self.previous_modulated_input = modulated_inp
397+
self.step += 1
398+
if self.step == self.num_inference_steps:
399+
self.step = 0
400+
if should_calc:
401+
self.previous_hidden_states = hidden_states.clone()
402+
return not should_calc
403+
404+
def store(self, hidden_states):
405+
self.previous_residual = hidden_states - self.previous_hidden_states
406+
self.previous_hidden_states = None
407+
408+
def update(self, hidden_states):
409+
hidden_states = hidden_states + self.previous_residual
410+
return hidden_states
411+
365412

366413
def lets_dance_flux(
367414
dit: FluxDiT,
@@ -380,6 +427,7 @@ def lets_dance_flux(
380427
entity_prompt_emb=None,
381428
entity_masks=None,
382429
ipadapter_kwargs_list={},
430+
tea_cache: TeaCache = None,
383431
**kwargs
384432
):
385433
if tiled:
@@ -446,36 +494,48 @@ def flux_forward_fn(hl, hr, wl, wr):
446494
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
447495
attention_mask = None
448496

449-
# Joint Blocks
450-
for block_id, block in enumerate(dit.blocks):
451-
hidden_states, prompt_emb = block(
452-
hidden_states,
453-
prompt_emb,
454-
conditioning,
455-
image_rotary_emb,
456-
attention_mask,
457-
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
458-
)
459-
# ControlNet
460-
if controlnet is not None and controlnet_frames is not None:
461-
hidden_states = hidden_states + controlnet_res_stack[block_id]
462-
463-
# Single Blocks
464-
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
465-
num_joint_blocks = len(dit.blocks)
466-
for block_id, block in enumerate(dit.single_blocks):
467-
hidden_states, prompt_emb = block(
468-
hidden_states,
469-
prompt_emb,
470-
conditioning,
471-
image_rotary_emb,
472-
attention_mask,
473-
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
474-
)
475-
# ControlNet
476-
if controlnet is not None and controlnet_frames is not None:
477-
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
478-
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
497+
# TeaCache
498+
if tea_cache is not None:
499+
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
500+
else:
501+
tea_cache_update = False
502+
503+
if tea_cache_update:
504+
hidden_states = tea_cache.update(hidden_states)
505+
else:
506+
# Joint Blocks
507+
for block_id, block in enumerate(dit.blocks):
508+
hidden_states, prompt_emb = block(
509+
hidden_states,
510+
prompt_emb,
511+
conditioning,
512+
image_rotary_emb,
513+
attention_mask,
514+
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
515+
)
516+
# ControlNet
517+
if controlnet is not None and controlnet_frames is not None:
518+
hidden_states = hidden_states + controlnet_res_stack[block_id]
519+
520+
# Single Blocks
521+
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
522+
num_joint_blocks = len(dit.blocks)
523+
for block_id, block in enumerate(dit.single_blocks):
524+
hidden_states, prompt_emb = block(
525+
hidden_states,
526+
prompt_emb,
527+
conditioning,
528+
image_rotary_emb,
529+
attention_mask,
530+
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
531+
)
532+
# ControlNet
533+
if controlnet is not None and controlnet_frames is not None:
534+
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
535+
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
536+
537+
if tea_cache is not None:
538+
tea_cache.store(hidden_states)
479539

480540
hidden_states = dit.final_norm_out(hidden_states, conditioning)
481541
hidden_states = dit.final_proj_out(hidden_states)

examples/TeaCache/README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# TeaCache
2+
3+
TeaCache ([Timestep Embedding Aware Cache](https://github.com/ali-vilab/TeaCache)) is a training-free caching approach that estimates and leverages the fluctuating differences among model outputs across timesteps, thereby accelerating the inference.
4+
5+
## Examples
6+
7+
We provide examples on FLUX.1-dev. See [./flux_teacache.py](./flux_teacache.py).
8+
9+
Steps: 50
10+
11+
GPU: A100
12+
13+
|TeaCache is disabled|tea_cache_l1_thresh=0.2|tea_cache_l1_thresh=0.4|tea_cache_l1_thresh=0.6|tea_cache_l1_thresh=0.8|
14+
|-|-|-|-|-|
15+
|23s|13s|9s|6s|5s|
16+
|![image_None](https://github.com/user-attachments/assets/2bf5187a-9693-44d3-9ebb-6c33cd15443f)|![image_0 2](https://github.com/user-attachments/assets/5532ba94-c7e2-446e-a9ba-1c68c0f63350)|![image_0 4](https://github.com/user-attachments/assets/4c57c50d-87cd-493b-8603-1da57ec3b70d)|![image_0 6](https://github.com/user-attachments/assets/1d95a3a9-71f9-4b1a-ad5f-a5ea8d52eca7)|![image_0 8](https://github.com/user-attachments/assets/d8cfdd74-8b45-4048-b1b7-ce480aa23fa1)

examples/TeaCache/flux_teacache.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch
2+
from diffsynth import ModelManager, FluxImagePipeline
3+
4+
5+
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
6+
pipe = FluxImagePipeline.from_model_manager(model_manager)
7+
8+
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
9+
10+
for tea_cache_l1_thresh in [None, 0.2, 0.4, 0.6, 0.8]:
11+
image = pipe(
12+
prompt=prompt, embedded_guidance=3.5, seed=0,
13+
num_inference_steps=50, tea_cache_l1_thresh=tea_cache_l1_thresh
14+
)
15+
image.save(f"image_{tea_cache_l1_thresh}.png")

0 commit comments

Comments
 (0)