|
36 | 36 | from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook |
37 | 37 | from diffusers.hooks.first_block_cache import FirstBlockCacheConfig |
38 | 38 | from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook |
| 39 | +from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig |
39 | 40 | from diffusers.image_processor import VaeImageProcessor |
40 | 41 | from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin |
41 | 42 | from diffusers.models.attention import AttentionModuleMixin |
@@ -2923,6 +2924,56 @@ def run_forward(pipe): |
2923 | 2924 | "Outputs from normal inference and after disabling cache should not differ." |
2924 | 2925 | ) |
2925 | 2926 |
|
class TaylorSeerCacheTesterMixin:
    """Mixin exercising TaylorSeer caching on a pipeline's transformer.

    Checks two contracts:
      * with the cache enabled, outputs stay within ``expected_atol`` of the
        uncached baseline;
      * after ``disable_cache()``, outputs match the baseline almost exactly.
    """

    # Shared test config: refresh cached features every 5 steps, run the first
    # 10 steps uncached, keep a first-order Taylor expansion of the features
    # (stored in bfloat16), and use the lighter-weight cache variant.
    taylorseer_cache_config = TaylorSeerCacheConfig(
        cache_interval=5,
        disable_cache_before_step=10,
        max_order=1,
        taylor_factors_dtype=torch.bfloat16,
        use_lite_mode=True,
    )

    def test_taylorseer_cache_inference(self, expected_atol: float = 0.1):
        # Run on CPU so the device-dependent torch.Generator used by
        # get_dummy_inputs is deterministic across the three runs below.
        device = "cpu"

        def create_pipe():
            # Fresh, identically-seeded pipeline for each run so the only
            # difference between runs is the caching behavior under test.
            torch.manual_seed(0)
            num_layers = 2
            components = self.get_dummy_components(num_layers=num_layers)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(device)
            pipe.set_progress_bar_config(disable=None)
            return pipe

        def run_forward(pipe):
            # Re-seed before sampling inputs; 50 steps gives the cache enough
            # iterations to actually kick in past disable_cache_before_step.
            torch.manual_seed(0)
            inputs = self.get_dummy_inputs(device)
            inputs["num_inference_steps"] = 50
            return pipe(**inputs)[0]

        # Baseline: inference without TaylorSeerCache.
        pipe = create_pipe()
        output = run_forward(pipe).flatten()
        original_image_slice = np.concatenate((output[:8], output[-8:]))

        # Inference with TaylorSeerCache enabled.
        pipe = create_pipe()
        pipe.transformer.enable_cache(self.taylorseer_cache_config)
        output = run_forward(pipe).flatten()
        image_slice_cache_enabled = np.concatenate((output[:8], output[-8:]))

        # Inference on the same pipeline after disabling the cache; this must
        # restore the baseline behavior (up to numerical noise).
        pipe.transformer.disable_cache()
        output = run_forward(pipe).flatten()
        image_slice_cache_disabled = np.concatenate((output[:8], output[-8:]))

        assert np.allclose(original_image_slice, image_slice_cache_enabled, atol=expected_atol), (
            "TaylorSeerCache outputs should not differ much."
        )
        assert np.allclose(original_image_slice, image_slice_cache_disabled, atol=1e-4), (
            "Outputs from normal inference and after disabling cache should not differ."
        )
2926 | 2977 |
|
2927 | 2978 | # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. |
2928 | 2979 | # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a |
|
0 commit comments