|
24 | 24 | DDIMScheduler, |
25 | 25 | DiffusionPipeline, |
26 | 26 | KolorsPipeline, |
| 27 | + PyramidAttentionBroadcastConfig, |
27 | 28 | StableDiffusionPipeline, |
28 | 29 | StableDiffusionXLPipeline, |
29 | 30 | UNet2DConditionModel, |
| 31 | + apply_pyramid_attention_broadcast, |
30 | 32 | ) |
31 | 33 | from diffusers.image_processor import VaeImageProcessor |
32 | 34 | from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin |
|
36 | 38 | from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet |
37 | 39 | from diffusers.models.unets.unet_motion_model import UNetMotionModel |
38 | 40 | from diffusers.pipelines.pipeline_utils import StableDiffusionMixin |
| 41 | +from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastHook |
39 | 42 | from diffusers.schedulers import KarrasDiffusionSchedulers |
40 | 43 | from diffusers.utils import logging |
41 | 44 | from diffusers.utils.import_utils import is_xformers_available |
@@ -2271,6 +2274,109 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): |
2271 | 2274 | self.assertLess(max_diff, expected_max_difference) |
2272 | 2275 |
|
2273 | 2276 |
|
| 2277 | +class PyramidAttentionBroadcastTesterMixin: |
| 2278 | + pab_config = PyramidAttentionBroadcastConfig( |
| 2279 | + spatial_attention_block_skip_range=2, |
| 2280 | + spatial_attention_timestep_skip_range=(100, 800), |
| 2281 | + spatial_attention_block_identifiers=["transformer_blocks"], |
| 2282 | + ) |
| 2283 | + |
| 2284 | + def test_pyramid_attention_broadcast_layers(self): |
| 2285 | + device = "cpu" # ensure determinism for the device-dependent torch.Generator |
| 2286 | + num_layers = 2 |
| 2287 | + components = self.get_dummy_components(num_layers=num_layers) |
| 2288 | + pipe = self.pipeline_class(**components) |
| 2289 | + pipe.set_progress_bar_config(disable=None) |
| 2290 | + |
| 2291 | + apply_pyramid_attention_broadcast(pipe, self.pab_config) |
| 2292 | + |
| 2293 | + expected_hooks = 0 |
| 2294 | + if self.pab_config.spatial_attention_block_skip_range is not None: |
| 2295 | + expected_hooks += num_layers |
| 2296 | + if self.pab_config.temporal_attention_block_skip_range is not None: |
| 2297 | + expected_hooks += num_layers |
| 2298 | + if self.pab_config.cross_attention_block_skip_range is not None: |
| 2299 | + expected_hooks += num_layers |
| 2300 | + |
| 2301 | + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet |
| 2302 | + count = 0 |
| 2303 | + for module in denoiser.modules(): |
| 2304 | + if hasattr(module, "_diffusers_hook"): |
| 2305 | + count += 1 |
| 2306 | + self.assertTrue( |
| 2307 | + isinstance(module._diffusers_hook, PyramidAttentionBroadcastHook), |
| 2308 | + "Hook should be of type PyramidAttentionBroadcastHook.", |
| 2309 | + ) |
| 2310 | + self.assertTrue( |
| 2311 | + hasattr(module, "_pyramid_attention_broadcast_state"), |
| 2312 | + "PAB state should be initialized when enabled.", |
| 2313 | + ) |
| 2314 | + self.assertTrue( |
| 2315 | + module._pyramid_attention_broadcast_state.cache is None, "Cache should be None at initialization." |
| 2316 | + ) |
| 2317 | + self.assertEqual(count, expected_hooks, "Number of hooks should match the expected number.") |
| 2318 | + |
| 2319 | + # Perform dummy inference step to ensure state is updated |
| 2320 | + def pab_state_check_callback(pipe, i, t, kwargs): |
| 2321 | + for module in denoiser.modules(): |
| 2322 | + if hasattr(module, "_diffusers_hook"): |
| 2323 | + self.assertTrue( |
| 2324 | + module._pyramid_attention_broadcast_state.cache is not None, |
| 2325 | + "Cache should have updated during inference.", |
| 2326 | + ) |
| 2327 | + self.assertTrue( |
| 2328 | + module._pyramid_attention_broadcast_state.iteration == i + 1, |
| 2329 | + "Hook iteration state should have updated during inference.", |
| 2330 | + ) |
| 2331 | + return {} |
| 2332 | + |
| 2333 | + inputs = self.get_dummy_inputs(device) |
| 2334 | + inputs["num_inference_steps"] = 2 |
| 2335 | + inputs["callback_on_step_end"] = pab_state_check_callback |
| 2336 | + pipe(**inputs)[0] |
| 2337 | + |
| 2338 | + # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states |
| 2339 | + for module in denoiser.modules(): |
| 2340 | + if hasattr(module, "_diffusers_hook"): |
| 2341 | + self.assertTrue( |
| 2342 | + module._pyramid_attention_broadcast_state.cache is None, |
| 2343 | + "Cache should be reset to None after inference.", |
| 2344 | + ) |
| 2345 | + self.assertTrue( |
| 2346 | + module._pyramid_attention_broadcast_state.iteration == 0, |
| 2347 | + "Iteration should be reset to 0 after inference.", |
| 2348 | + ) |
| 2349 | + |
| 2350 | + def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2): |
| 2351 | + # We need to use higher tolerance because we are using a random model. With a converged/trained |
| 2352 | + # model, the tolerance can be lower. |
| 2353 | + |
| 2354 | + device = "cpu" # ensure determinism for the device-dependent torch.Generator |
| 2355 | + num_layers = 2 |
| 2356 | + components = self.get_dummy_components(num_layers=num_layers) |
| 2357 | + pipe = self.pipeline_class(**components) |
| 2358 | + pipe = pipe.to(device) |
| 2359 | + pipe.set_progress_bar_config(disable=None) |
| 2360 | + |
| 2361 | + inputs = self.get_dummy_inputs(device) |
| 2362 | + inputs["num_inference_steps"] = 4 |
| 2363 | + output = pipe(**inputs)[0] |
| 2364 | + original_image_slice = output.flatten() |
| 2365 | + original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:])) |
| 2366 | + |
| 2367 | + apply_pyramid_attention_broadcast(pipe, self.pab_config) |
| 2368 | + |
| 2369 | + inputs = self.get_dummy_inputs(device) |
| 2370 | + inputs["num_inference_steps"] = 4 |
| 2371 | + output = pipe(**inputs)[0] |
| 2372 | + image_slice_pab_enabled = output.flatten() |
| 2373 | + image_slice_pab_enabled = np.concatenate((image_slice_pab_enabled[:8], image_slice_pab_enabled[-8:])) |
| 2374 | + |
| 2375 | + assert np.allclose( |
| 2376 | + original_image_slice, image_slice_pab_enabled, atol=expected_atol |
| 2377 | + ), "PAB outputs should not differ much in specified timestep range." |
| 2378 | + |
| 2379 | + |
2274 | 2380 | # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. |
2275 | 2381 | # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a |
2276 | 2382 | # reference image. |
|
0 commit comments