|
23 | 23 | ConsistencyDecoderVAE, |
24 | 24 | DDIMScheduler, |
25 | 25 | DiffusionPipeline, |
| 26 | + FasterCacheConfig, |
26 | 27 | KolorsPipeline, |
27 | 28 | StableDiffusionPipeline, |
28 | 29 | StableDiffusionXLPipeline, |
29 | 30 | UNet2DConditionModel, |
| 31 | + apply_faster_cache, |
30 | 32 | ) |
31 | 33 | from diffusers.image_processor import VaeImageProcessor |
32 | 34 | from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin |
|
35 | 37 | from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel |
36 | 38 | from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet |
37 | 39 | from diffusers.models.unets.unet_motion_model import UNetMotionModel |
| 40 | +from diffusers.pipelines.faster_cache_utils import FasterCacheBlockHook, FasterCacheDenoiserHook |
38 | 41 | from diffusers.pipelines.pipeline_utils import StableDiffusionMixin |
39 | 42 | from diffusers.schedulers import KarrasDiffusionSchedulers |
40 | 43 | from diffusers.utils import logging |
@@ -2271,6 +2274,167 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): |
2271 | 2274 | self.assertLess(max_diff, expected_max_difference) |
2272 | 2275 |
|
2273 | 2276 |
|
| 2277 | +class FasterCacheTesterMixin: |
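| | + # Shared FasterCache configuration for the tests below: attention states are |
| | + # recomputed every second block call (and reused otherwise), caching applies |
| | + # only to timesteps within (-1, 901), the unconditional branch is recomputed |
| | + # every second step, and cached attention is weighted by a constant 0.5. |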
| 2278 | + fastercache_config = FasterCacheConfig( |
| 2279 | + spatial_attention_block_skip_range=2, |
| 2280 | + spatial_attention_timestep_skip_range=(-1, 901), |
| 2281 | + unconditional_batch_skip_range=2, |
| 2282 | + attention_weight_callback=lambda _: 0.5, |
| 2283 | + ) |
| 2284 | + |
| 2285 | + def test_fastercache_basic_warning_or_errors_raised(self): |
| 2286 | + components = self.get_dummy_components() |
| 2287 | + |
| 2288 | + logger = logging.get_logger("diffusers.pipelines.faster_cache_utils") |
| 2289 | + logger.setLevel(logging.INFO) |
| 2290 | + |
| 2291 | + # Check that a warning is raised when no FasterCacheConfig is provided |
| 2292 | + pipe = self.pipeline_class(**components) |
| 2293 | + with CaptureLogger(logger) as cap_logger: |
| 2294 | + apply_faster_cache(pipe) |
| 2295 | + self.assertTrue("No FasterCacheConfig provided" in cap_logger.out) |
| 2296 | + |
| 2297 | + # Check that a warning is raised when no attention_weight_callback is provided |
| 2298 | + pipe = self.pipeline_class(**components) |
| 2299 | + with CaptureLogger(logger) as cap_logger: |
| 2300 | + config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None) |
| 2301 | + apply_faster_cache(pipe, config) |
| 2302 | + self.assertTrue("No `attention_weight_callback` provided when enabling FasterCache" in cap_logger.out) |
| 2303 | + |
| 2304 | + # Check that an error is raised when an unsupported tensor format is used |
| 2305 | + pipe = self.pipeline_class(**components) |
| 2306 | + with self.assertRaises(ValueError): |
| 2307 | + config = FasterCacheConfig(spatial_attention_block_skip_range=2, tensor_format="BFHWC") |
| 2308 | + apply_faster_cache(pipe, config) |
| 2309 | + |
| 2310 | + def test_fastercache_inference(self, expected_atol: float = 0.1): |
| 2311 | + device = "cpu" # ensure determinism for the device-dependent torch.Generator |
| 2312 | + num_layers = 2 |
| 2313 | + components = self.get_dummy_components(num_layers=num_layers) |
| 2314 | + pipe = self.pipeline_class(**components) |
| 2315 | + pipe = pipe.to(device) |
| 2316 | + pipe.set_progress_bar_config(disable=None) |
| 2317 | + |
| 2318 | + inputs = self.get_dummy_inputs(device) |
| 2319 | + inputs["num_inference_steps"] = 4 |
| 2320 | + output = pipe(**inputs)[0] |
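| | + # Keep only the first and last 8 values of the flattened output as a cheap |
| | + # fingerprint for comparing the two runs. |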
| 2321 | + original_image_slice = output.flatten() |
| 2322 | + original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:])) |
| 2323 | + |
| 2324 | + apply_faster_cache(pipe, self.fastercache_config) |
| 2325 | + |
| 2326 | + inputs = self.get_dummy_inputs(device) |
| 2327 | + inputs["num_inference_steps"] = 4 |
| 2328 | + output = pipe(**inputs)[0] |
| 2329 | + image_slice_fastercache_enabled = output.flatten() |
| 2330 | + image_slice_fastercache_enabled = np.concatenate( |
| 2331 | + (image_slice_fastercache_enabled[:8], image_slice_fastercache_enabled[-8:]) |
| 2332 | + ) |
| 2333 | + |
| 2334 | + assert np.allclose( |
| 2335 | + original_image_slice, image_slice_fastercache_enabled, atol=expected_atol |
| 2336 | + ), "FasterCache outputs should not differ much within the specified timestep range." |
| 2337 | + |
| 2338 | + def test_fastercache_state(self): |
| 2339 | + device = "cpu" # ensure determinism for the device-dependent torch.Generator |
| 2340 | + |
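| | + # Probe the dummy-component factory for layer-count kwargs so the test knows |
| | + # how many attention blocks (and hence how many block hooks) to expect. |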
| 2341 | + num_layers = 0 |
| 2342 | + num_single_layers = 0 |
| 2343 | + dummy_component_kwargs = {} |
| 2344 | + dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters |
| 2345 | + if "num_layers" in dummy_component_parameters: |
| 2346 | + num_layers = 2 |
| 2347 | + dummy_component_kwargs["num_layers"] = num_layers |
| 2348 | + if "num_single_layers" in dummy_component_parameters: |
| 2349 | + num_single_layers = 2 |
| 2350 | + dummy_component_kwargs["num_single_layers"] = num_single_layers |
| 2351 | + |
| 2352 | + components = self.get_dummy_components(**dummy_component_kwargs) |
| 2353 | + pipe = self.pipeline_class(**components) |
| | + pipe = pipe.to(device) |
| 2354 | + pipe.set_progress_bar_config(disable=None) |
| 2355 | + |
| 2356 | + apply_faster_cache(pipe, self.fastercache_config) |
| 2357 | + |
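| | + # Each enabled attention type (spatial and/or temporal) attaches one |
| | + # FasterCacheBlockHook per transformer block. |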
| 2358 | + expected_hooks = 0 |
| 2359 | + if self.fastercache_config.spatial_attention_block_skip_range is not None: |
| 2360 | + expected_hooks += num_layers + num_single_layers |
| 2361 | + if self.fastercache_config.temporal_attention_block_skip_range is not None: |
| 2362 | + expected_hooks += num_layers + num_single_layers |
| 2363 | + |
| 2364 | + # Check that the FasterCache denoiser hook is attached |
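| | + # (DiT-style pipelines expose the denoiser as `transformer`, UNet-based ones as `unet`) |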
| 2365 | + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet |
| 2366 | + self.assertTrue( |
| 2367 | + hasattr(denoiser, "_diffusers_hook") and isinstance(denoiser._diffusers_hook, FasterCacheDenoiserHook), |
| 2368 | + "Hook should be of type FasterCacheDenoiserHook.", |
| 2369 | + ) |
| 2370 | + |
| 2371 | + # Check that all blocks have the FasterCache block hook attached |
| 2372 | + count = 0 |
| 2373 | + for name, module in denoiser.named_modules(): |
| 2374 | + if hasattr(module, "_diffusers_hook"): |
| 2375 | + if name == "": |
| 2376 | + # Skip the root denoiser module |
| 2377 | + continue |
| 2378 | + count += 1 |
| 2379 | + self.assertTrue( |
| 2380 | + isinstance(module._diffusers_hook, FasterCacheBlockHook), |
| 2381 | + "Hook should be of type FasterCacheBlockHook.", |
| 2382 | + ) |
| 2383 | + self.assertEqual(count, expected_hooks, "Number of hooks should match expected number.") |
| 2384 | + |
| 2385 | + # Perform inference to ensure that states are updated correctly |
| 2386 | + def fastercache_state_check_callback(pipe, i, t, kwargs): |
| 2387 | + for name, module in denoiser.named_modules(): |
| 2388 | + if not hasattr(module, "_diffusers_hook"): |
| 2389 | + continue |
| 2390 | + |
| 2391 | + state = module._fastercache_state |
| 2392 | + |
| 2393 | + if name == "": |
| 2394 | + # Root denoiser module |
| 2395 | + self.assertTrue(state.low_frequency_delta is not None, "Low frequency delta should be set.") |
| 2396 | + self.assertTrue(state.high_frequency_delta is not None, "High frequency delta should be set.") |
| 2397 | + else: |
| 2398 | + # Internal blocks |
| 2399 | + self.assertTrue(state.cache is not None and len(state.cache) == 2, "Cache should be set.") |
| 2400 | + |
| 2401 | + self.assertTrue(state.iteration == i + 1, "Hook iteration state should have been updated during inference.") |
| 2402 | + self.assertTrue( |
| 2403 | + state.is_guidance_distilled is not None, |
| 2404 | + "`is_guidance_distilled` should be set to either True or False.", |
| 2405 | + ) |
| 2406 | + |
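| | + # `callback_on_step_end` must return a dict of pipeline kwargs to update; |
| | + # nothing needs to be updated here. |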
| 2407 | + return {} |
| 2408 | + |
| 2409 | + inputs = self.get_dummy_inputs(device) |
| 2410 | + inputs["num_inference_steps"] = 4 |
| 2411 | + inputs["callback_on_step_end"] = fastercache_state_check_callback |
| 2412 | + _ = pipe(**inputs)[0] |
| 2413 | + |
| 2414 | + # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states |
| 2415 | + for name, module in denoiser.named_modules(): |
| 2416 | + if not hasattr(module, "_diffusers_hook"): |
| 2417 | + continue |
| 2418 | + |
| 2419 | + state = module._fastercache_state |
| 2420 | + |
| 2421 | + if name == "": |
| 2422 | + # Root denoiser module |
| 2423 | + self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.") |
| 2424 | + self.assertTrue(state.low_frequency_delta is None, "Low frequency delta should be reset to None.") |
| 2425 | + self.assertTrue(state.high_frequency_delta is None, "High frequency delta should be reset to None.") |
| 2426 | + self.assertTrue( |
| 2427 | + state.is_guidance_distilled is None, "`is_guidance_distilled` should be reset to None." |
| 2428 | + ) |
| 2429 | + else: |
| 2430 | + self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.") |
| 2431 | + self.assertTrue(state.batch_size is None, "Batch size should be reset to None.") |
| 2432 | + self.assertTrue(state.cache is None, "Cache should be reset to None.") |
| 2433 | + self.assertTrue( |
| 2434 | + state.is_guidance_distilled is None, "`is_guidance_distilled` should be reset to None." |
| 2435 | + ) |
| 2436 | + |
| 2437 | + |
2274 | 2438 | # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. |
2275 | 2439 | # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a |
2276 | 2440 | # reference image. |
|