Skip to content

Commit e2dae7e

Browse files
author
toilaluan
committed
add tests
1 parent 716dfe1 commit e2dae7e

File tree

6 files changed

+61
-1
lines changed

6 files changed

+61
-1
lines changed

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
FluxIPAdapterTesterMixin,
3030
PipelineTesterMixin,
3131
PyramidAttentionBroadcastTesterMixin,
32+
TaylorSeerCacheTesterMixin,
3233
check_qkv_fused_layers_exist,
3334
)
3435

@@ -39,6 +40,7 @@ class FluxPipelineFastTests(
3940
PyramidAttentionBroadcastTesterMixin,
4041
FasterCacheTesterMixin,
4142
FirstBlockCacheTesterMixin,
43+
TaylorSeerCacheTesterMixin,
4244
unittest.TestCase,
4345
):
4446
pipeline_class = FluxPipeline

tests/pipelines/flux/test_pipeline_flux_kontext.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
FluxIPAdapterTesterMixin,
2020
PipelineTesterMixin,
2121
PyramidAttentionBroadcastTesterMixin,
22+
TaylorSeerCacheTesterMixin,
2223
)
2324

2425

@@ -28,6 +29,7 @@ class FluxKontextPipelineFastTests(
2829
FluxIPAdapterTesterMixin,
2930
PyramidAttentionBroadcastTesterMixin,
3031
FasterCacheTesterMixin,
32+
TaylorSeerCacheTesterMixin,
3133
):
3234
pipeline_class = FluxKontextPipeline
3335
params = frozenset(

tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
FluxIPAdapterTesterMixin,
2020
PipelineTesterMixin,
2121
PyramidAttentionBroadcastTesterMixin,
22+
TaylorSeerCacheTesterMixin,
2223
)
2324

2425

@@ -28,6 +29,7 @@ class FluxKontextInpaintPipelineFastTests(
2829
FluxIPAdapterTesterMixin,
2930
PyramidAttentionBroadcastTesterMixin,
3031
FasterCacheTesterMixin,
32+
TaylorSeerCacheTesterMixin,
3133
):
3234
pipeline_class = FluxKontextInpaintPipeline
3335
params = frozenset(

tests/pipelines/flux2/test_pipeline_flux2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
)
1717
from ..test_pipelines_common import (
1818
PipelineTesterMixin,
19+
TaylorSeerCacheTesterMixin,
1920
check_qkv_fused_layers_exist,
2021
)
2122

2223

23-
class Flux2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
24+
class Flux2PipelineFastTests(PipelineTesterMixin, TaylorSeerCacheTesterMixin, unittest.TestCase):
2425
pipeline_class = Flux2Pipeline
2526
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
2627
batch_params = frozenset(["prompt"])

tests/pipelines/hunyuan_video/test_hunyuan_video.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
FirstBlockCacheTesterMixin,
3434
PipelineTesterMixin,
3535
PyramidAttentionBroadcastTesterMixin,
36+
TaylorSeerCacheTesterMixin,
3637
to_np,
3738
)
3839

@@ -45,6 +46,7 @@ class HunyuanVideoPipelineFastTests(
4546
PyramidAttentionBroadcastTesterMixin,
4647
FasterCacheTesterMixin,
4748
FirstBlockCacheTesterMixin,
49+
TaylorSeerCacheTesterMixin,
4850
unittest.TestCase,
4951
):
5052
pipeline_class = HunyuanVideoPipeline

tests/pipelines/test_pipelines_common.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
3737
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
3838
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
39+
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig
3940
from diffusers.image_processor import VaeImageProcessor
4041
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
4142
from diffusers.models.attention import AttentionModuleMixin
@@ -2923,6 +2924,56 @@ def run_forward(pipe):
29232924
"Outputs from normal inference and after disabling cache should not differ."
29242925
)
29252926

2927+
class TaylorSeerCacheTesterMixin:
2928+
taylorseer_cache_config = TaylorSeerCacheConfig(
2929+
cache_interval=5,
2930+
disable_cache_before_step=10,
2931+
max_order=1,
2932+
taylor_factors_dtype=torch.bfloat16,
2933+
use_lite_mode=True,
2934+
)
2935+
2936+
def test_taylorseer_cache_inference(self, expected_atol: float = 0.1):
2937+
device = "cpu" # ensure determinism for the device-dependent torch.Generator
2938+
2939+
def create_pipe():
2940+
torch.manual_seed(0)
2941+
num_layers = 2
2942+
components = self.get_dummy_components(num_layers=num_layers)
2943+
pipe = self.pipeline_class(**components)
2944+
pipe = pipe.to(device)
2945+
pipe.set_progress_bar_config(disable=None)
2946+
return pipe
2947+
2948+
def run_forward(pipe):
2949+
torch.manual_seed(0)
2950+
inputs = self.get_dummy_inputs(device)
2951+
inputs["num_inference_steps"] = 50
2952+
return pipe(**inputs)[0]
2953+
2954+
# Run inference without TaylorSeerCache
2955+
pipe = create_pipe()
2956+
output = run_forward(pipe).flatten()
2957+
original_image_slice = np.concatenate((output[:8], output[-8:]))
2958+
2959+
# Run inference with TaylorSeerCache enabled
2960+
pipe = create_pipe()
2961+
pipe.transformer.enable_cache(self.taylorseer_cache_config)
2962+
output = run_forward(pipe).flatten()
2963+
image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
2964+
2965+
# Run inference with TaylorSeerCache disabled
2966+
pipe.transformer.disable_cache()
2967+
output = run_forward(pipe).flatten()
2968+
image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
2969+
2970+
assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
2971+
"TaylorSeerCache outputs should not differ much."
2972+
)
2973+
assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
2974+
"Outputs from normal inference and after disabling cache should not differ."
2975+
)
2976+
29262977

29272978
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
29282979
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a

0 commit comments

Comments
 (0)