Skip to content

Commit 1b92b1d

Browse files
committed
update tests
1 parent ffbabb5 commit 1b92b1d

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

tests/pipelines/test_pipelines_common.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
UNet2DConditionModel,
3131
apply_pyramid_attention_broadcast,
3232
)
33+
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
3334
from diffusers.image_processor import VaeImageProcessor
3435
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
3536
from diffusers.models.attention_processor import AttnProcessor
@@ -38,7 +39,6 @@
3839
from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
3940
from diffusers.models.unets.unet_motion_model import UNetMotionModel
4041
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
41-
from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastHook
4242
from diffusers.schedulers import KarrasDiffusionSchedulers
4343
from diffusers.utils import logging
4444
from diffusers.utils.import_utils import is_xformers_available
@@ -2298,7 +2298,9 @@ def test_pyramid_attention_broadcast_layers(self):
22982298
pipe = self.pipeline_class(**components)
22992299
pipe.set_progress_bar_config(disable=None)
23002300

2301-
apply_pyramid_attention_broadcast(pipe, self.pab_config)
2301+
self.pab_config.current_timestep_callback = lambda: pipe._current_timestep
2302+
denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
2303+
apply_pyramid_attention_broadcast(denoiser, self.pab_config)
23022304

23032305
expected_hooks = 0
23042306
if self.pab_config.spatial_attention_block_skip_range is not None:
@@ -2312,30 +2314,30 @@ def test_pyramid_attention_broadcast_layers(self):
23122314
count = 0
23132315
for module in denoiser.modules():
23142316
if hasattr(module, "_diffusers_hook"):
2317+
hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
2318+
if hook is None:
2319+
continue
23152320
count += 1
23162321
self.assertTrue(
2317-
isinstance(module._diffusers_hook, PyramidAttentionBroadcastHook),
2322+
isinstance(hook, PyramidAttentionBroadcastHook),
23182323
"Hook should be of type PyramidAttentionBroadcastHook.",
23192324
)
2320-
self.assertTrue(
2321-
hasattr(module, "_pyramid_attention_broadcast_state"),
2322-
"PAB state should be initialized when enabled.",
2323-
)
2324-
self.assertTrue(
2325-
module._pyramid_attention_broadcast_state.cache is None, "Cache should be None at initialization."
2326-
)
2325+
self.assertTrue(hook.state.cache is None, "Cache should be None at initialization.")
23272326
self.assertEqual(count, expected_hooks, "Number of hooks should match the expected number.")
23282327

23292328
# Perform dummy inference step to ensure state is updated
23302329
def pab_state_check_callback(pipe, i, t, kwargs):
23312330
for module in denoiser.modules():
23322331
if hasattr(module, "_diffusers_hook"):
2332+
hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
2333+
if hook is None:
2334+
continue
23332335
self.assertTrue(
2334-
module._pyramid_attention_broadcast_state.cache is not None,
2336+
hook.state.cache is not None,
23352337
"Cache should have updated during inference.",
23362338
)
23372339
self.assertTrue(
2338-
module._pyramid_attention_broadcast_state.iteration == i + 1,
2340+
hook.state.iteration == i + 1,
23392341
"Hook iteration state should have updated during inference.",
23402342
)
23412343
return {}
@@ -2348,12 +2350,15 @@ def pab_state_check_callback(pipe, i, t, kwargs):
23482350
# After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states
23492351
for module in denoiser.modules():
23502352
if hasattr(module, "_diffusers_hook"):
2353+
hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
2354+
if hook is None:
2355+
continue
23512356
self.assertTrue(
2352-
module._pyramid_attention_broadcast_state.cache is None,
2357+
hook.state.cache is None,
23532358
"Cache should be reset to None after inference.",
23542359
)
23552360
self.assertTrue(
2356-
module._pyramid_attention_broadcast_state.iteration == 0,
2361+
hook.state.iteration == 0,
23572362
"Iteration should be reset to 0 after inference.",
23582363
)
23592364

@@ -2374,7 +2379,9 @@ def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2)
23742379
original_image_slice = output.flatten()
23752380
original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:]))
23762381

2377-
apply_pyramid_attention_broadcast(pipe, self.pab_config)
2382+
self.pab_config.current_timestep_callback = lambda: pipe._current_timestep
2383+
denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
2384+
apply_pyramid_attention_broadcast(denoiser, self.pab_config)
23782385

23792386
inputs = self.get_dummy_inputs(device)
23802387
inputs["num_inference_steps"] = 4

0 commit comments

Comments
 (0)