Add support for Magcache #12744
Conversation
@leffff could you review as well if possible?
Hi @AlanPonnachan @sayakpaul
@leffff, thank you for your review. To address this, I am implementing a Calibration Mode. My plan is to add a `calibrate=True` flag to `MagCacheConfig`. Users can then simply run one calibration pass for their specific model/scheduler, copy the output ratios, and pass them into `MagCacheConfig(mag_ratios=...)`. I am working on this update now and will push the changes shortly!
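For intuition, here is a rough sketch of what each calibration ratio is meant to capture, under the assumption (following the MagCache idea) that it compares the magnitude of the model's residual output at consecutive denoising steps; the exact formula lives in the hook implementation:

```python
import torch

def magnitude_ratio(residual_t: torch.Tensor, residual_prev: torch.Tensor) -> float:
    # Illustrative only: ratio of mean residual magnitudes between two consecutive steps.
    return (residual_t.abs().mean() / residual_prev.abs().mean()).item()

# A ratio close to 1.0 means the residual barely changed between steps, so the cached
# residual from the previous step can be reused and the current step skipped.
prev = torch.randn(1, 4096, 64)
curr = 0.98 * prev
print(magnitude_ratio(curr, prev))  # ~0.98
```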
Sounds great!
Thanks for the thoughtful discussions here @AlanPonnachan and @leffff! I will leave my two cents below.

Cc'ing @DN6 to get his thoughts here, too.
Thanks @sayakpaul and @leffff for the feedback! I have updated the PR to address these points. Instead of a standalone utility script, I integrated the calibration logic directly into the hook configuration for better usability.

Ready for review!
Looks great! Could you please provide a usage example?

And to be sure it works, please provide generations for SD3.5 Medium, Flux, and Wan T2V 2.1 1.3B. I also believe, since caching is suitable for all tasks, we could also try Kandinsky 5.0 Video Pro I2V: kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers
**1. Usage Example**

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")

# CALIBRATION STEP
config = MagCacheConfig(calibrate=True, num_inference_steps=4)
apply_mag_cache(pipe.transformer, config)
pipe("A cat playing chess", num_inference_steps=4)
# Logs: [1.0, 1.37, 0.97, 0.87]

# INFERENCE STEP
config = MagCacheConfig(mag_ratios=[1.0, 1.37, 0.97, 0.87], num_inference_steps=4)
apply_mag_cache(pipe.transformer, config)
pipe("A cat playing chess", num_inference_steps=4)
```

**2. Benchmark Results**

I validated the implementation on Flux, SD 3.5, and Wan 2.1 using a T4 Colab environment.

**3. Generations**

Attached below are the outputs for the successful runs.
Here is the Colab notebook used to generate the benchmarks above. It includes the full setup, memory optimizations (sequential offloading/dummy embeds), and the execution logs:
@bot /style

Style bot fixed some files and pushed the changes.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

This looks good!
@AlanPonnachan thanks for your great work thus far! Some minor questions (mostly out of curiosity) below.

Additionally, I could obtain outputs with Wan 1.3B and they look reasonable to me.

Code:

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache
from diffusers.utils import export_to_video
# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
num_inference_steps = 50
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# config = MagCacheConfig(calibrate=True, num_inference_steps=num_inference_steps)
# apply_mag_cache(pipe.transformer, config)
config = MagCacheConfig(
mag_ratios=[1.0, 1.0337707996368408, 0.9908783435821533, 0.9898878931999207, 0.990186870098114, 0.989551305770874, 0.9898356199264526, 0.9901290535926819, 0.9913457632064819, 0.9893063902854919, 0.990695059299469, 0.9892956614494324, 0.9910416603088379, 0.9908630847930908, 0.9897039532661438, 0.9907404184341431, 0.98955237865448, 0.9905906915664673, 0.9881031513214111, 0.98977130651474, 0.9878108501434326, 0.9873648285865784, 0.98862624168396, 0.9870336055755615, 0.9855726957321167, 0.9857151508331299, 0.98496013879776, 0.9846605658531189, 0.9835416674613953, 0.984062671661377, 0.9805435538291931, 0.9828993678092957, 0.9804039001464844, 0.9776313304901123, 0.9769471883773804, 0.9752448201179504, 0.973810076713562, 0.9708614349365234, 0.9703076481819153, 0.9666262865066528, 0.9658275246620178, 0.9612534046173096, 0.9553734064102173, 0.9522399306297302, 0.9467942118644714, 0.9430344104766846, 0.9335862994194031, 0.9285727739334106, 0.9244886636734009, 0.9560992121696472],
num_inference_steps=num_inference_steps
)
apply_mag_cache(pipe.transformer, config)
prompt = "A cat walks on the grass, realistic"
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
num_frames=81,
guidance_scale=5.0,
num_inference_steps=num_inference_steps,
).frames[0]
export_to_video(output, "output.mp4", fps=15)Outputs: # Calibation
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:35<00:00, 1.91s/it]
# After using the `mag_ratios`
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:27<00:00, 1.82it/s]
```

Video output: output.mp4

However, there seems to be a problem when using Kandinsky 5, and the error seems obvious to me. Error: https://pastebin.com/F7arxTWg

Code:

```python
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache
from diffusers.utils import export_to_video
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
num_inference_steps = 50
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
config = MagCacheConfig(calibrate=True, num_inference_steps=num_inference_steps)
apply_mag_cache(pipe.transformer, config)
# config = MagCacheConfig(
# mag_ratios=[...],
# num_inference_steps=num_inference_steps
# )
# apply_mag_cache(pipe.transformer, config)
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=121, # ~5 seconds at 24fps
num_inference_steps=num_inference_steps,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output_kandinsky.mp4", fps=24, quality=9)For this, instead of a line like the following maybe we could pass it to the cache config? I understand this could be difficult for the users but my thought is since they have to perform calibration anyway, this is still reasonable? Just for curiosity, I changed to: diff --git a/src/diffusers/hooks/mag_cache.py b/src/diffusers/hooks/mag_cache.py
index 71ebfcb25..0a7c333db 100644
--- a/src/diffusers/hooks/mag_cache.py
+++ b/src/diffusers/hooks/mag_cache.py
@@ -183,7 +183,7 @@ class MagCacheHeadHook(ModelHook):
self.state_manager.set_context("inference")
# Capture input hidden_states
- hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+ hidden_states = self._metadata._get_parameter_from_args_kwargs("visual_embed", args, kwargs)
state: MagCacheState = self.state_manager.get_state()
state.head_block_input = hidden_states
@@ -297,7 +297,7 @@ class MagCacheBlockHook(ModelHook):
state: MagCacheState = self.state_manager.get_state()
if not state.should_compute:
- hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+ hidden_states = self._metadata._get_parameter_from_args_kwargs("visual_embed", args, kwargs)
if self.is_tail:
# Still need to advance step index even if we skip
self._advance_step(state)
```

And ran the above code. But I am getting a pair of calibration outputs:

```
[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):
[1.0, 1.0096147060394287, 0.8601706027984619, 1.0066865682601929, 1.1018145084381104, 1.0066889524459839, 1.07235848903656, 1.006271243095398, 1.0583757162094116, 1.0066468715667725, 1.0803261995315552, 1.0059221982955933, 1.0304542779922485, 1.0061317682266235, 1.0251237154006958, 1.006355881690979, 1.0230522155761719, 1.0063568353652954, 1.0354706048965454, 1.006076455116272, 1.0154225826263428, 1.0064369440078735, 1.0257697105407715, 1.0066747665405273, 1.012341856956482, 1.0068379640579224, 1.017471432685852, 1.0070058107376099, 1.008599877357483, 1.00702702999115, 1.0158008337020874, 1.0070949792861938, 1.0113613605499268, 1.0063375234603882, 1.0122487545013428, 1.0064034461975098, 1.0091496706008911, 1.0062494277954102, 1.0109937191009521, 1.0061204433441162, 1.0084550380706787, 1.0059889554977417, 1.006821870803833, 1.0058847665786743, 1.0106556415557861, 1.005847454071045, 1.0057544708251953, 1.0058276653289795, 1.0092748403549194, 1.005746841430664]
[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):
[1.0, 1.0056898593902588, 1.0074970722198486, 1.005563735961914, 1.0061627626419067, 1.0054070949554443, 1.0053973197937012, 1.0052893161773682, 1.0067739486694336, 1.0051906108856201, 1.0049010515213013, 1.0050380229949951, 1.0056493282318115, 1.0049028396606445, 1.0056771039962769, 1.0048167705535889, 1.0038255453109741, 1.0047082901000977, 1.0041747093200684, 1.004562258720398, 1.002451777458191, 1.0044060945510864, 1.0022073984146118, 1.0042728185653687, 1.0011045932769775, 1.0041989088058472, 0.9996317625045776, 1.0040632486343384, 0.9980409741401672, 1.0038821697235107, 0.9960299134254456, 1.004146933555603, 0.9924721717834473, 1.0041824579238892, 0.9876144528388977, 1.0041331052780151, 0.9839898943901062, 1.003833293914795, 0.976319432258606, 1.0032036304473877, 0.9627748131752014, 1.002505898475647, 0.9450504779815674, 1.001646637916565, 0.9085856080055237, 0.9999536275863647, 0.8368133306503296, 0.9975034594535828, 0.6354470252990723, 0.9997955560684204]
```

When applying the first one, I got: output_kandinsky.mp4

When applying the second one, I got: output_kandinsky_2.mp4

Thought this would help :)
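To make the suggestion above concrete, here is a hypothetical sketch of passing the argument name through the config instead of hard-coding `hidden_states` in the hook. The `hidden_states_param_name` field does not exist in this PR, and the simplified config class below is only a stand-in for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class MagCacheConfig:  # simplified stand-in for the PR's config class (illustration only)
    num_inference_steps: int
    mag_ratios: Optional[List[float]] = None
    calibrate: bool = False
    hidden_states_param_name: str = "hidden_states"  # hypothetical new field

# Kandinsky 5 passes its latent as `visual_embed`, so calibration could be pointed at it:
config = MagCacheConfig(
    num_inference_steps=50,
    calibrate=True,
    hidden_states_param_name="visual_embed",
)
print(config)
```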
@sayakpaul thank you for running inferences from your side; it helped a lot.

1. Regarding
Makes sense, yeah!

This is awesome. Let's make sure we document it once we're at that point.

Okay, then this needs to be documented as well. However, there are some small models where we run CFG in a batched manner (see the illustration below). Would that affect

Cc: @Zehong-Ma! Hey, maybe you would like to review the PR as well :)
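For context, "batched CFG" refers to running the conditional and unconditional branches in a single forward pass over a doubled batch, as many pipelines do; a generic illustration with made-up shapes (not this PR's code):

```python
import torch

guidance_scale = 5.0
latents = torch.randn(1, 16, 60, 104)                # made-up latent shape
latent_model_input = torch.cat([latents] * 2)        # [uncond, cond] stacked on the batch dim

# noise_pred = transformer(latent_model_input, ...)  # one transformer call per step
noise_pred = torch.randn_like(latent_model_input)    # placeholder for the model output

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 16, 60, 104])
```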
sayakpaul left a comment
Thanks a lot for working on this!
I have left some comments, LMK what you think of them.
Let's add documentation and button up testing :)
```python
)
_import_structure["hooks"].extend(
    [
        "FLUX_MAG_RATIOS",
```
It's fine if we don't expose this publicly. Since these are still a bit experimental in nature, I would prefer it to stay within the core MagCache implementation file.
| "blocks", | ||
| "transformer_blocks", | ||
| "single_transformer_blocks", | ||
| "layers", |
For ZImage I am guessing?
```python
TransformerBlockRegistry.register(
    model_class=JointTransformerBlock,
    metadata=TransformerBlockMetadata(
        return_hidden_states_index=1,
        return_encoder_hidden_states_index=0,
    ),
)
```
For SD3 I am guessing?
```python
def nearest_interp(src_array: np.ndarray, target_length: int) -> np.ndarray:
    """
    Interpolate the source array to the target length using nearest neighbor interpolation.
    """
    src_length = len(src_array)
    if target_length == 1:
        return np.array([src_array[-1]])

    scale = (src_length - 1) / (target_length - 1)
    mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
    return src_array[mapped_indices]
```
Can we keep it to PyTorch-only? I am guessing there will be a performance advantage to be had because these operations can be performed on an accelerator if needed.
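For example, a PyTorch-only version could look roughly like this (a sketch under the assumption that the ratios are held in a 1D tensor, not the PR's final code):

```python
import torch

def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor:
    """Nearest-neighbor interpolation of a 1D tensor to `target_length` entries."""
    src_length = src_array.shape[0]
    if target_length == 1:
        return src_array[-1:].clone()

    scale = (src_length - 1) / (target_length - 1)
    # Indices are created on the same device as the source, so this can run on an accelerator.
    mapped_indices = torch.round(torch.arange(target_length, device=src_array.device) * scale).long()
    return src_array[mapped_indices]
```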
```python
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
```
We can follow similar testing logic to #12569.
```python
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
```
I think we can decorate this function with torch.compiler.disable so that we can also apply compilation alongside caching for potential performance gains?
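A minimal sketch of the suggestion; only the decorator is the point here, and the class body is a placeholder rather than the PR's actual hook:

```python
import torch

class MagCacheBlockHook:  # placeholder; the real hook subclasses diffusers' ModelHook
    @torch.compiler.disable
    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        # The dynamic skip/compute logic stays in eager mode, so a torch.compile'd
        # transformer does not graph-break or recompile around it.
        ...
```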
Thanks for your review, and thanks for the contribution from @AlanPonnachan. I have briefly reviewed the pull request. Most of your discussion is correct and concise. There may be two important things that should be clearly discussed or fixed.






What does this PR do?
This PR adds support for MagCache (Magnitude-aware Cache), a training-free inference acceleration method for diffusion models, specifically targeting Transformer-based architectures like Flux.
This implementation follows the `ModelHook` pattern (similar to `FirstBlockCache`) to integrate seamlessly into Diffusers.

Key features:

- `MagCacheConfig`: configuration class to control threshold, retention ratio, and skipping limits.
- A `calibrate=True` flag. When enabled, the hook runs full inference and calculates/prints the magnitude ratios for the specific model and scheduler. This makes MagCache compatible with any transformer model (e.g., Hunyuan, Wan, SD3), not just Flux.
- `mag_ratios` must be explicitly provided in the config (or calibration enabled).
- `FLUX_MAG_RATIOS` is provided as a constant for convenience, derived from the official implementation (see the usage sketch below).

Fixes Magcache Support. #12697
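As a usage sketch, assuming `FLUX_MAG_RATIOS` remains importable from `diffusers.hooks` as proposed in this PR (the review above suggests keeping it private to the MagCache module, so treat the import as illustrative):

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks import FLUX_MAG_RATIOS, MagCacheConfig, apply_mag_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")

# Reuse the bundled Flux ratios instead of running a calibration pass.
config = MagCacheConfig(mag_ratios=FLUX_MAG_RATIOS, num_inference_steps=4)
apply_mag_cache(pipe.transformer, config)

image = pipe("A cat playing chess", num_inference_steps=4).images[0]
image.save("magcache_flux.png")
```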
Who can review?
@sayakpaul