diff --git a/tests/hooks/test_mag_cache.py b/tests/hooks/test_mag_cache.py
deleted file mode 100644
index a7e1b52d3b69..000000000000
--- a/tests/hooks/test_mag_cache.py
+++ /dev/null
@@ -1,244 +0,0 @@
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-
-from diffusers import MagCacheConfig, apply_mag_cache
-from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
-from diffusers.models import ModelMixin
-from diffusers.utils import logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class DummyBlock(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
-        # Output is double the input.
-        # This ensures Residual = 2*Input - Input = Input
-        return hidden_states * 2.0
-
-
-class DummyTransformer(ModelMixin):
-    def __init__(self):
-        super().__init__()
-        self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()])
-
-    def forward(self, hidden_states, encoder_hidden_states=None):
-        for block in self.transformer_blocks:
-            hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
-        return hidden_states
-
-
-class TupleOutputBlock(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
-        # Returns a tuple
-        return hidden_states * 2.0, encoder_hidden_states
-
-
-class TupleTransformer(ModelMixin):
-    def __init__(self):
-        super().__init__()
-        self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock()])
-
-    def forward(self, hidden_states, encoder_hidden_states=None):
-        for block in self.transformer_blocks:
-            # Emulate Flux-like behavior
-            output = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
-            hidden_states = output[0]
-            encoder_hidden_states = output[1]
-        return hidden_states, encoder_hidden_states
-
-
-class MagCacheTests(unittest.TestCase):
-    def setUp(self):
-        # Register standard dummy block
-        TransformerBlockRegistry.register(
-            DummyBlock,
-            TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None),
-        )
-        # Register tuple block (Flux style)
-        TransformerBlockRegistry.register(
-            TupleOutputBlock,
-            TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1),
-        )
-
-    def _set_context(self, model, context_name):
-        """Helper to set context on all hooks in the model."""
-        for module in model.modules():
-            if hasattr(module, "_diffusers_hook"):
-                module._diffusers_hook._set_context(context_name)
-
-    def _get_calibration_data(self, model):
-        for module in model.modules():
-            if hasattr(module, "_diffusers_hook"):
-                hook = module._diffusers_hook.get_hook("mag_cache_block_hook")
-                if hook:
-                    return hook.state_manager.get_state().calibration_ratios
-        return []
-
-    def test_mag_cache_validation(self):
-        """Test that missing mag_ratios raises ValueError."""
-        with self.assertRaises(ValueError):
-            MagCacheConfig(num_inference_steps=10, calibrate=False)
-
-    def test_mag_cache_skipping_logic(self):
-        """
-        Tests that MagCache correctly calculates residuals and skips blocks when conditions are met.
-        """
-        model = DummyTransformer()
-
-        # Dummy ratios: [1.0, 1.0] imply zero accumulated error if we skip
-        ratios = np.array([1.0, 1.0])
-
-        config = MagCacheConfig(
-            threshold=100.0,
-            num_inference_steps=2,
-            retention_ratio=0.0,  # Enable immediate skipping
-            max_skip_steps=5,
-            mag_ratios=ratios,
-        )
-
-        apply_mag_cache(model, config)
-        self._set_context(model, "test_context")
-
-        # Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each)
-        # HeadInput=10. Output=40. Residual=30.
-        input_t0 = torch.tensor([[[10.0]]])
-        output_t0 = model(input_t0)
-        self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed")
-
-        # Step 1: Input 11.0.
-        # If skipped: Output = Input(11) + Residual(30) = 41.0
-        # If computed: Output = 11 * 4 = 44.0
-        input_t1 = torch.tensor([[[11.0]]])
-        output_t1 = model(input_t1)
-
-        self.assertTrue(
-            torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected skip (41.0), got {output_t1.item()}"
-        )
-
-    def test_mag_cache_retention(self):
-        """Test that retention_ratio prevents skipping even if the error is low."""
-        model = DummyTransformer()
-        # Ratios that imply zero error, so it *would* skip if retention allowed it
-        ratios = np.array([1.0, 1.0])
-
-        config = MagCacheConfig(
-            threshold=100.0,
-            num_inference_steps=2,
-            retention_ratio=1.0,  # Force retention for ALL steps
-            mag_ratios=ratios,
-        )
-
-        apply_mag_cache(model, config)
-        self._set_context(model, "test_context")
-
-        # Step 0
-        model(torch.tensor([[[10.0]]]))
-
-        # Step 1: Should COMPUTE (44.0), not SKIP (41.0), because of retention
-        input_t1 = torch.tensor([[[11.0]]])
-        output_t1 = model(input_t1)
-
-        self.assertTrue(
-            torch.allclose(output_t1, torch.tensor([[[44.0]]])),
-            f"Expected compute (44.0) due to retention, got {output_t1.item()}",
-        )
-
-    def test_mag_cache_tuple_outputs(self):
-        """Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
-        model = TupleTransformer()
-        ratios = np.array([1.0, 1.0])
-
-        config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios)
-
-        apply_mag_cache(model, config)
-        self._set_context(model, "test_context")
-
-        # Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x)
-        # Residual = 10.0
-        input_t0 = torch.tensor([[[10.0]]])
-        enc_t0 = torch.tensor([[[1.0]]])
-        out_0, _ = model(input_t0, encoder_hidden_states=enc_t0)
-        self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]])))
-
-        # Step 1: Skip. Input 11.0.
-        # Skipped Output = 11 + 10 = 21.0
-        input_t1 = torch.tensor([[[11.0]]])
-        out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
-
-        self.assertTrue(
-            torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
-        )
-
-    def test_mag_cache_reset(self):
-        """Test that state resets correctly after num_inference_steps."""
-        model = DummyTransformer()
-        config = MagCacheConfig(
-            threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0])
-        )
-        apply_mag_cache(model, config)
-        self._set_context(model, "test_context")
-
-        input_t = torch.ones(1, 1, 1)
-
-        model(input_t)  # Step 0
-        model(input_t)  # Step 1 (skipped)
-
-        # Step 2 (reset -> step 0) -> should compute
-        # Input 2.0 -> Output 8.0
-        input_t2 = torch.tensor([[[2.0]]])
-        output_t2 = model(input_t2)
-
-        self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly")
-
-    def test_mag_cache_calibration(self):
-        """Test that calibration mode records ratios."""
-        model = DummyTransformer()
-        config = MagCacheConfig(num_inference_steps=2, calibrate=True)
-        apply_mag_cache(model, config)
-        self._set_context(model, "test_context")
-
-        # Step 0
-        # HeadInput = 10. Output = 40. Residual = 30.
-        # Ratio 0 is the placeholder 1.0
-        model(torch.tensor([[[10.0]]]))
-
-        # Check intermediate state
-        ratios = self._get_calibration_data(model)
-        self.assertEqual(len(ratios), 1)
-        self.assertEqual(ratios[0], 1.0)
-
-        # Step 1
-        # HeadInput = 10. Output = 40. Residual = 30.
-        # PrevResidual = 30. CurrResidual = 30.
-        # Ratio = 30/30 = 1.0
-        model(torch.tensor([[[10.0]]]))
-
-        # Calibration always computes fully (no skip); had it skipped, the
-        # output would be 41.0 instead of 40.0. Since the input is unchanged
-        # (10.0), the output is again 40.0. Step 1 is the final step, so the
-        # state resets and the recorded ratio list should be empty afterwards.
-        ratios_after = self._get_calibration_data(model)
-        self.assertEqual(ratios_after, [])
diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py
index ea076b3ec774..54e34eab2e7c 100644
--- a/tests/models/testing_utils/__init__.py
+++ b/tests/models/testing_utils/__init__.py
@@ -5,8 +5,12 @@
     FasterCacheTesterMixin,
     FirstBlockCacheConfigMixin,
     FirstBlockCacheTesterMixin,
+    MagCacheConfigMixin,
+    MagCacheTesterMixin,
     PyramidAttentionBroadcastConfigMixin,
     PyramidAttentionBroadcastTesterMixin,
+    TaylorSeerCacheConfigMixin,
+    TaylorSeerCacheTesterMixin,
 )
 from .common import BaseModelTesterConfig, ModelTesterMixin
 from .compile import TorchCompileTesterMixin
@@ -50,6 +54,8 @@
     "FasterCacheTesterMixin",
     "FirstBlockCacheConfigMixin",
     "FirstBlockCacheTesterMixin",
+    "MagCacheConfigMixin",
+    "MagCacheTesterMixin",
     "GGUFCompileTesterMixin",
     "GGUFConfigMixin",
     "GGUFTesterMixin",
@@ -65,6 +71,8 @@
     "ModelTesterMixin",
     "PyramidAttentionBroadcastConfigMixin",
     "PyramidAttentionBroadcastTesterMixin",
+    "TaylorSeerCacheConfigMixin",
+    "TaylorSeerCacheTesterMixin",
     "QuantizationCompileTesterMixin",
     "QuantizationTesterMixin",
     "QuantoCompileTesterMixin",
diff --git a/tests/models/testing_utils/cache.py b/tests/models/testing_utils/cache.py
index f1c2ecba88a7..e8a835f6bccf 100644
--- a/tests/models/testing_utils/cache.py
+++ b/tests/models/testing_utils/cache.py
@@ -18,10 +18,18 @@
 import pytest
 import torch
 
-from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig
+from diffusers.hooks import (
+    FasterCacheConfig,
+    FirstBlockCacheConfig,
+    MagCacheConfig,
+    PyramidAttentionBroadcastConfig,
+    TaylorSeerCacheConfig,
+)
 from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
 from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
+from diffusers.hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
 from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
+from diffusers.hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
 from diffusers.models.cache_utils import CacheMixin
 
 from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device
@@ -554,3 +562,192 @@ def test_faster_cache_context_manager(self):
     @require_cache_mixin
     def test_faster_cache_reset_stateful_cache(self):
         self._test_reset_stateful_cache()
+
+
+@is_cache
+class MagCacheConfigMixin:
+    """
+    Base mixin providing MagCache config.
+
+    Expected class attributes:
+    - model_class: The model class to test (must use CacheMixin)
+    """
+
+    # Default MagCache config - can be overridden by subclasses.
+    # Uses neutral ratios [1.0, 1.0] and a high threshold so the second
+    # inference step is always skipped, which is required by _test_cache_inference.
+    MAG_CACHE_CONFIG = {
+        "num_inference_steps": 2,
+        "retention_ratio": 0.0,
+        "threshold": 100.0,
+        "mag_ratios": [1.0, 1.0],
+    }
+
+    def _get_cache_config(self):
+        return MagCacheConfig(**self.MAG_CACHE_CONFIG)
+
+    def _get_hook_names(self):
+        return [_MAG_CACHE_LEADER_BLOCK_HOOK, _MAG_CACHE_BLOCK_HOOK]
+
+
+@is_cache
+class MagCacheTesterMixin(MagCacheConfigMixin, CacheTesterMixin):
+    """
+    Mixin class for testing MagCache on models.
+
+    Expected class attributes:
+    - model_class: The model class to test (must use CacheMixin)
+
+    Expected methods to be implemented by subclasses:
+    - get_init_dict(): Returns dict of arguments to initialize the model
+    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
+
+    Pytest mark: cache
+    Use `pytest -m "not cache"` to skip these tests
+    """
+
+    @require_cache_mixin
+    def test_mag_cache_enable_disable_state(self):
+        self._test_cache_enable_disable_state()
+
+    @require_cache_mixin
+    def test_mag_cache_double_enable_raises_error(self):
+        self._test_cache_double_enable_raises_error()
+
+    @require_cache_mixin
+    def test_mag_cache_hooks_registered(self):
+        self._test_cache_hooks_registered()
+
+    @require_cache_mixin
+    def test_mag_cache_inference(self):
+        self._test_cache_inference()
+
+    @require_cache_mixin
+    def test_mag_cache_context_manager(self):
+        self._test_cache_context_manager()
+
+    @require_cache_mixin
+    def test_mag_cache_reset_stateful_cache(self):
+        self._test_reset_stateful_cache()
+
+
+@is_cache
+class TaylorSeerCacheConfigMixin:
+    """
+    Base mixin providing TaylorSeerCache config.
+
+    Expected class attributes:
+    - model_class: The model class to test (must use CacheMixin)
+    """
+
+    # Default TaylorSeerCache config - can be overridden by subclasses.
+    # Uses a low cache_interval and disable_cache_before_step=1 so the second
+    # inference step is always predicted, which is required by _test_cache_inference.
+    TAYLORSEER_CACHE_CONFIG = {
+        "cache_interval": 3,
+        "disable_cache_before_step": 1,
+        "max_order": 1,
+    }
+
+    def _get_cache_config(self):
+        return TaylorSeerCacheConfig(**self.TAYLORSEER_CACHE_CONFIG)
+
+    def _get_hook_names(self):
+        return [_TAYLORSEER_CACHE_HOOK]
+
+
+@is_cache
+class TaylorSeerCacheTesterMixin(TaylorSeerCacheConfigMixin, CacheTesterMixin):
+    """
+    Mixin class for testing TaylorSeerCache on models.
+
+    Expected class attributes:
+    - model_class: The model class to test (must use CacheMixin)
+
+    Expected methods to be implemented by subclasses:
+    - get_init_dict(): Returns dict of arguments to initialize the model
+    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
+
+    Pytest mark: cache
+    Use `pytest -m "not cache"` to skip these tests
+    """
+
+    @torch.no_grad()
+    def _test_cache_inference(self):
+        """Test that model can run inference with TaylorSeer cache enabled (requires cache_context)."""
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
+
+        config = self._get_cache_config()
+        model.enable_cache(config)
+
+        # TaylorSeer requires cache_context to be set for inference
+        with model.cache_context("taylorseer_test"):
+            # First pass populates the cache
+            _ = model(**inputs_dict, return_dict=False)[0]
+
+            # Create modified inputs for second pass
+            inputs_dict_step2 = inputs_dict.copy()
+            if self.cache_input_key in inputs_dict_step2:
+                inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
+                    inputs_dict_step2[self.cache_input_key]
+                )
+
+            # Second pass - TaylorSeer should use cached Taylor series predictions
+            output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
+
+        assert output_with_cache is not None, "Model output should not be None with cache enabled."
+        assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
+
+        # Run same inputs without cache to compare
+        model.disable_cache()
+        output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
+
+        # Cached output should be different from non-cached output (due to approximation)
+        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
+            "Cached output should be different from non-cached output due to cache approximation."
+        )
+
+    @torch.no_grad()
+    def _test_reset_stateful_cache(self):
+        """Test that _reset_stateful_cache resets the TaylorSeer cache state (requires cache_context)."""
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
+
+        config = self._get_cache_config()
+        model.enable_cache(config)
+
+        with model.cache_context("taylorseer_test"):
+            _ = model(**inputs_dict, return_dict=False)[0]
+
+        model._reset_stateful_cache()
+
+        model.disable_cache()
+
+    @require_cache_mixin
+    def test_taylorseer_cache_enable_disable_state(self):
+        self._test_cache_enable_disable_state()
+
+    @require_cache_mixin
+    def test_taylorseer_cache_double_enable_raises_error(self):
+        self._test_cache_double_enable_raises_error()
+
+    @require_cache_mixin
+    def test_taylorseer_cache_hooks_registered(self):
+        self._test_cache_hooks_registered()
+
+    @require_cache_mixin
+    def test_taylorseer_cache_inference(self):
+        self._test_cache_inference()
+
+    @require_cache_mixin
+    def test_taylorseer_cache_context_manager(self):
+        self._test_cache_context_manager()
+
+    @require_cache_mixin
+    def test_taylorseer_cache_reset_stateful_cache(self):
+        self._test_reset_stateful_cache()
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 2d39dadfcad1..b80377eb2875 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -37,6 +37,7 @@
     IPAdapterTesterMixin,
     LoraHotSwappingForModelTesterMixin,
     LoraTesterMixin,
+    MagCacheTesterMixin,
     MemoryTesterMixin,
     ModelOptCompileTesterMixin,
     ModelOptTesterMixin,
@@ -45,6 +46,7 @@
     QuantoCompileTesterMixin,
     QuantoTesterMixin,
     SingleFileTesterMixin,
+    TaylorSeerCacheTesterMixin,
     TorchAoCompileTesterMixin,
     TorchAoTesterMixin,
     TorchCompileTesterMixin,
@@ -430,3 +432,11 @@ class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTes
         "tensor_format": "BCHW",
         "is_guidance_distilled": True,
     }
+
+
+class TestFluxTransformerMagCache(FluxTransformerTesterConfig, MagCacheTesterMixin):
+    """MagCache tests for Flux Transformer."""
+
+
+class TestFluxTransformerTaylorSeerCache(FluxTransformerTesterConfig, TaylorSeerCacheTesterMixin):
+    """TaylorSeerCache tests for Flux Transformer."""