Commit d68977d

add tests
1 parent 436b772 commit d68977d

6 files changed: +197, -20 lines

src/diffusers/pipelines/faster_cache_utils.py

Lines changed: 6 additions & 5 deletions
@@ -33,11 +33,11 @@
 _ATTENTION_CLASSES = (Attention,)

 _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
-    "blocks.*attn1",
-    "transformer_blocks.*attn1",
-    "single_transformer_blocks.*attn1",
+    "blocks.*attn",
+    "transformer_blocks.*attn",
+    "single_transformer_blocks.*attn",
 )
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks.*attn1",)
+_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks.*attn",)
 _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
     "hidden_states",
     "encoder_hidden_states",
@@ -263,6 +263,7 @@ def apply_faster_cache(
     """

     if config is None:
+        logger.warning("No FasterCacheConfig provided. Using default configuration.")
         config = FasterCacheConfig()

     if config.attention_weight_callback is None:
@@ -271,7 +272,7 @@ def apply_faster_cache(
         # this varies from model to model. The user must provide a weight callback if they want to
         # use a different weight function. Defaulting to 0.5 works well in practice for most cases.
         logger.warning(
-            "FasterCache requires an `attention_weight_callback` to be set. Defaulting to using a weight of 0.5 for all timesteps."
+            "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
         )
         config.attention_weight_callback = lambda _: 0.5

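Two changes land in this file: the spatial and temporal block identifier patterns are broadened from `attn1` to `attn` (presumably so attention modules named plain `attn`, as in the Flux transformer blocks whose tests are added below, are matched as well), and `apply_faster_cache` now warns when called without a config or without a weight callback. A minimal sketch of the two warning paths, mirroring the new mixin tests at the bottom of this commit; the checkpoint name and dtype are illustrative assumptions:

import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache

# Illustrative checkpoint; any pipeline whose denoiser matches the identifiers above works.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)

# No config: logs "No FasterCacheConfig provided. Using default configuration."
apply_faster_cache(pipe)

# Config without a callback: logs the second new warning and falls back to a
# constant weight of 0.5 for all timesteps. (A fresh pipeline would normally be
# used for each call; shown back-to-back here for brevity.)
config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None)
apply_faster_cache(pipe, config)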

tests/pipelines/cogvideo/test_cogvideox.py

Lines changed: 4 additions & 3 deletions
@@ -31,6 +31,7 @@

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import (
+    FasterCacheTesterMixin,
     PipelineTesterMixin,
     check_qkv_fusion_matches_attn_procs_length,
     check_qkv_fusion_processors_exist,
@@ -41,7 +42,7 @@
 enable_full_determinism()


-class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class CogVideoXPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
     pipeline_class = CogVideoXPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -59,7 +60,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     )
     test_xformers_attention = False

-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = CogVideoXTransformer3DModel(
             # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
@@ -71,7 +72,7 @@ def get_dummy_components(self):
             out_channels=4,
             time_embed_dim=2,
             text_embed_dim=32,  # Must match with tiny-random-t5
-            num_layers=1,
+            num_layers=num_layers,
             sample_width=2,  # latent width: 2 -> final width: 16
             sample_height=2,  # latent height: 2 -> final height: 16
             sample_frames=9,  # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
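With `FasterCacheTesterMixin` in the bases, this tester now runs the three `test_fastercache_*` methods defined in tests/pipelines/test_pipelines_common.py at the bottom of this commit, and the new `num_layers` passthrough lets those tests build a deeper dummy transformer. The core equivalence check they perform looks roughly like the sketch below; the helper name is hypothetical, and it assumes `inputs` requests NumPy output:

import numpy as np
from diffusers import FasterCacheConfig, apply_faster_cache

def check_fastercache_equivalence(pipe, inputs, atol=0.1):
    # Run once without caching and once with, then compare output slices --
    # the same comparison test_fastercache_inference makes with atol=0.1.
    baseline = pipe(**inputs)[0].flatten()
    config = FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 901),
        unconditional_batch_skip_range=2,
        attention_weight_callback=lambda _: 0.5,
    )
    apply_faster_cache(pipe, config)
    cached = pipe(**inputs)[0].flatten()
    baseline_slice = np.concatenate((baseline[:8], baseline[-8:]))
    cached_slice = np.concatenate((cached[:8], cached[-8:]))
    return np.allclose(baseline_slice, cached_slice, atol=atol)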

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 5 additions & 4 deletions
@@ -16,28 +16,29 @@
 )

 from ..test_pipelines_common import (
+    FasterCacheTesterMixin,
     FluxIPAdapterTesterMixin,
     PipelineTesterMixin,
     check_qkv_fusion_matches_attn_procs_length,
     check_qkv_fusion_processors_exist,
 )


-class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
+class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, FasterCacheTesterMixin):
     pipeline_class = FluxPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
     batch_params = frozenset(["prompt"])

     # there is no xformers processor for Flux
     test_xformers_attention = False

-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
         torch.manual_seed(0)
         transformer = FluxTransformer2DModel(
             patch_size=1,
             in_channels=4,
-            num_layers=1,
-            num_single_layers=1,
+            num_layers=num_layers,
+            num_single_layers=num_single_layers,
             attention_head_dim=16,
             num_attention_heads=2,
             joint_attention_dim=32,
tests/pipelines/latte/test_latte.py

Lines changed: 14 additions & 4 deletions
@@ -25,6 +25,7 @@
 from diffusers import (
     AutoencoderKL,
     DDIMScheduler,
+    FasterCacheConfig,
     LattePipeline,
     LatteTransformer3DModel,
 )
@@ -38,13 +39,13 @@
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np


 enable_full_determinism()


-class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class LattePipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
     pipeline_class = LattePipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -53,11 +54,20 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):

     required_optional_params = PipelineTesterMixin.required_optional_params

-    def get_dummy_components(self):
+    fastercache_config = FasterCacheConfig(
+        spatial_attention_block_skip_range=2,
+        temporal_attention_block_skip_range=2,
+        spatial_attention_timestep_skip_range=(-1, 901),
+        temporal_attention_timestep_skip_range=(-1, 901),
+        unconditional_batch_skip_range=2,
+        attention_weight_callback=lambda _: 0.5,
+    )
+
+    def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = LatteTransformer3DModel(
             sample_size=8,
-            num_layers=1,
+            num_layers=num_layers,
             patch_size=2,
             attention_head_dim=8,
             num_attention_heads=3,
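Latte is the only pipeline in this commit with temporal attention blocks, which is why it shadows the mixin's spatial-only `fastercache_config` with one that also sets the temporal skip ranges. Each transformer layer then contributes one hooked spatial attention block and one hooked temporal attention block, so the hook count in `test_fastercache_state` doubles. A sketch of that arithmetic, assuming a two-layer dummy model:

from diffusers import FasterCacheConfig

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    temporal_attention_block_skip_range=2,
)
num_layers, num_single_layers = 2, 0
expected_hooks = 0
if config.spatial_attention_block_skip_range is not None:
    expected_hooks += num_layers + num_single_layers  # spatial block hooks
if config.temporal_attention_block_skip_range is not None:
    expected_hooks += num_layers + num_single_layers  # temporal block hooks
assert expected_hooks == 4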

tests/pipelines/mochi/test_mochi.py

Lines changed: 4 additions & 4 deletions
@@ -30,13 +30,13 @@
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np


 enable_full_determinism()


-class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
     pipeline_class = MochiPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -54,13 +54,13 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     )
     test_xformers_attention = False

-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 2):
         torch.manual_seed(0)
         transformer = MochiTransformer3DModel(
             patch_size=2,
             num_attention_heads=2,
             attention_head_dim=8,
-            num_layers=2,
+            num_layers=num_layers,
             pooled_projection_dim=16,
             in_channels=12,
             out_channels=None,

tests/pipelines/test_pipelines_common.py

Lines changed: 164 additions & 0 deletions
@@ -23,10 +23,12 @@
     ConsistencyDecoderVAE,
     DDIMScheduler,
     DiffusionPipeline,
+    FasterCacheConfig,
     KolorsPipeline,
     StableDiffusionPipeline,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
+    apply_faster_cache,
 )
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
@@ -35,6 +37,7 @@
 from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
 from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
 from diffusers.models.unets.unet_motion_model import UNetMotionModel
+from diffusers.pipelines.faster_cache_utils import FasterCacheBlockHook, FasterCacheDenoiserHook
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import logging
@@ -2271,6 +2274,167 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4):
         self.assertLess(max_diff, expected_max_difference)


+class FasterCacheTesterMixin:
+    fastercache_config = FasterCacheConfig(
+        spatial_attention_block_skip_range=2,
+        spatial_attention_timestep_skip_range=(-1, 901),
+        unconditional_batch_skip_range=2,
+        attention_weight_callback=lambda _: 0.5,
+    )
+
+    def test_fastercache_basic_warning_or_errors_raised(self):
+        components = self.get_dummy_components()
+
+        logger = logging.get_logger("diffusers.pipelines.faster_cache_utils")
+        logger.setLevel(logging.INFO)
+
+        # Check if a warning is raised when no FasterCacheConfig is provided
+        pipe = self.pipeline_class(**components)
+        with CaptureLogger(logger) as cap_logger:
+            apply_faster_cache(pipe)
+        self.assertTrue("No FasterCacheConfig provided" in cap_logger.out)
+
+        # Check if a warning is raised when no attention_weight_callback is provided
+        pipe = self.pipeline_class(**components)
+        with CaptureLogger(logger) as cap_logger:
+            config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None)
+            apply_faster_cache(pipe, config)
+        self.assertTrue("No `attention_weight_callback` provided when enabling FasterCache" in cap_logger.out)
+
+        # Check if an error is raised when an unsupported tensor format is used
+        pipe = self.pipeline_class(**components)
+        with self.assertRaises(ValueError):
+            config = FasterCacheConfig(spatial_attention_block_skip_range=2, tensor_format="BFHWC")
+            apply_faster_cache(pipe, config)
+
+    def test_fastercache_inference(self, expected_atol: float = 0.1):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        num_layers = 2
+        components = self.get_dummy_components(num_layers=num_layers)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 4
+        output = pipe(**inputs)[0]
+        original_image_slice = output.flatten()
+        original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:]))
+
+        apply_faster_cache(pipe, self.fastercache_config)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 4
+        output = pipe(**inputs)[0]
+        image_slice_fastercache_enabled = output.flatten()
+        image_slice_fastercache_enabled = np.concatenate(
+            (image_slice_fastercache_enabled[:8], image_slice_fastercache_enabled[-8:])
+        )
+
+        assert np.allclose(
+            original_image_slice, image_slice_fastercache_enabled, atol=expected_atol
+        ), "FasterCache outputs should not differ much in the specified timestep range."
+
+    def test_fastercache_state(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        num_layers = 0
+        num_single_layers = 0
+        dummy_component_kwargs = {}
+        dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters
+        if "num_layers" in dummy_component_parameters:
+            num_layers = 2
+            dummy_component_kwargs["num_layers"] = num_layers
+        if "num_single_layers" in dummy_component_parameters:
+            num_single_layers = 2
+            dummy_component_kwargs["num_single_layers"] = num_single_layers
+
+        components = self.get_dummy_components(**dummy_component_kwargs)
+        pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        apply_faster_cache(pipe, self.fastercache_config)
+
+        expected_hooks = 0
+        if self.fastercache_config.spatial_attention_block_skip_range is not None:
+            expected_hooks += num_layers + num_single_layers
+        if self.fastercache_config.temporal_attention_block_skip_range is not None:
+            expected_hooks += num_layers + num_single_layers
+
+        # Check if the FasterCache denoiser hook is attached
+        denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+        self.assertTrue(
+            hasattr(denoiser, "_diffusers_hook") and isinstance(denoiser._diffusers_hook, FasterCacheDenoiserHook),
+            "Hook should be of type FasterCacheDenoiserHook.",
+        )
+
+        # Check if all blocks have the FasterCache block hook attached
+        count = 0
+        for name, module in denoiser.named_modules():
+            if hasattr(module, "_diffusers_hook"):
+                if name == "":
+                    # Skip the root denoiser module
+                    continue
+                count += 1
+                self.assertTrue(
+                    isinstance(module._diffusers_hook, FasterCacheBlockHook),
+                    "Hook should be of type FasterCacheBlockHook.",
+                )
+        self.assertEqual(count, expected_hooks, "Number of hooks should match expected number.")
+
+        # Perform inference to ensure that states are updated correctly
+        def fastercache_state_check_callback(pipe, i, t, kwargs):
+            for name, module in denoiser.named_modules():
+                if not hasattr(module, "_diffusers_hook"):
+                    continue
+
+                state = module._fastercache_state
+
+                if name == "":
+                    # Root denoiser module
+                    self.assertTrue(state.low_frequency_delta is not None, "Low frequency delta should be set.")
+                    self.assertTrue(state.high_frequency_delta is not None, "High frequency delta should be set.")
+                else:
+                    # Internal blocks
+                    self.assertTrue(state.cache is not None and len(state.cache) == 2, "Cache should be set.")
+
+                self.assertTrue(state.iteration == i + 1, "Hook iteration state should have updated during inference.")
+                self.assertTrue(
+                    state.is_guidance_distilled is not None,
+                    "`is_guidance_distilled` should be set to either True or False.",
+                )
+
+            return {}
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 4
+        inputs["callback_on_step_end"] = fastercache_state_check_callback
+        _ = pipe(**inputs)[0]
+
+        # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states
+        for name, module in denoiser.named_modules():
+            if not hasattr(module, "_diffusers_hook"):
+                continue

+            state = module._fastercache_state
+
+            if name == "":
+                # Root denoiser module
+                self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.")
+                self.assertTrue(state.low_frequency_delta is None, "Low frequency delta should be reset to None.")
+                self.assertTrue(state.high_frequency_delta is None, "High frequency delta should be reset to None.")
+                self.assertTrue(
+                    state.is_guidance_distilled is None, "`is_guidance_distilled` should be reset to None."
+                )
+            else:
+                self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.")
+                self.assertTrue(state.batch_size is None, "Batch size should be reset to None.")
+                self.assertTrue(state.cache is None, "Cache should be reset to None.")
+                self.assertTrue(
+                    state.is_guidance_distilled is None, "`is_guidance_distilled` should be reset to None."
+                )
+
+
 # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
 # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
 # reference image.
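The state test drives inference through the pipelines' standard `callback_on_step_end` hook so it can assert on each module's `_fastercache_state` between denoising steps, then verifies that `reset_stateful_hooks` has cleared everything once the call returns. Outside the test harness, the same inspection pattern looks roughly like this (a sketch; assumes `apply_faster_cache` has already been applied to `pipe`, and the prompt is illustrative):

def check_state(pipe, i, t, callback_kwargs):
    # Inspect the root denoiser hook's state after each denoising step;
    # the attribute names follow the hooks attached by apply_faster_cache.
    denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
    state = denoiser._fastercache_state
    assert state.iteration == i + 1, "hook state should advance every step"
    return callback_kwargs

result = pipe(prompt="a cat", num_inference_steps=4, callback_on_step_end=check_state)

From the repository root, something like `pytest tests/pipelines -k fastercache` should then exercise all three new tests across the four opted-in pipelines.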
