
Commit 0b2629d

improve log message; fix latte test
1 parent 2b558ff commit 0b2629d

2 files changed: +20 additions, -24 deletions


src/diffusers/pipelines/pyramid_attention_broadcast_utils.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -220,7 +220,7 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
     is_cross_attention = (
         any(f"{identifier}." in name or identifier == name for identifier in config.cross_attention_block_identifiers)
         and config.cross_attention_block_skip_range is not None
-        and not module.is_cross_attention
+        and module.is_cross_attention
     )

     block_skip_range, timestep_skip_range, block_type = None, None, None
@@ -238,7 +238,13 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
         block_type = "cross"

     if block_skip_range is None or timestep_skip_range is None:
-        logger.warning(f"Unable to apply Pyramid Attention Broadcast to the selected layer: {name}.")
+        logger.info(
+            f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does '
+            f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, "
+            f"however, that this layer may still be valid for applying PAB. Please specify the correct "
+            f"block identifiers in the configuration or use the specialized `apply_pyramid_attention_broadcast_on_module` "
+            f"function to apply PAB to this layer."
+        )
         return

     def skip_callback(module: nnModulePAB) -> bool:
```
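
For context, the first hunk fixes an inverted condition: a layer now counts as a cross-attention block only when `module.is_cross_attention` is True. Below is a simplified, self-contained sketch of that matching logic; `DummyConfig`, `DummyAttention`, and the identifier strings are illustrative stand-ins, not actual diffusers classes.

```python
# Simplified sketch of the cross-attention matching logic from the hunk above.
# DummyConfig / DummyAttention are illustrative stand-ins for the real objects.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class DummyConfig:
    cross_attention_block_identifiers: Tuple[str, ...] = ("transformer_blocks",)
    cross_attention_block_skip_range: Optional[int] = 2


@dataclass
class DummyAttention:
    is_cross_attention: bool = True


def matches_cross_attention(name: str, module: DummyAttention, config: DummyConfig) -> bool:
    # Mirrors the fixed condition: the layer name must match a configured identifier,
    # a skip range must be set, and the module itself must be a cross-attention layer
    # (the previous code required `not module.is_cross_attention`, which was inverted).
    return (
        any(f"{identifier}." in name or identifier == name for identifier in config.cross_attention_block_identifiers)
        and config.cross_attention_block_skip_range is not None
        and module.is_cross_attention
    )


print(matches_cross_attention("transformer_blocks.0.attn2", DummyAttention(True), DummyConfig()))   # True
print(matches_cross_attention("transformer_blocks.0.attn1", DummyAttention(False), DummyConfig()))  # False
```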

tests/pipelines/latte/test_latte.py

Lines changed: 12 additions & 22 deletions
```diff
@@ -22,11 +22,10 @@
 import torch
 from transformers import AutoTokenizer, T5EncoderModel

-from diffusers import (
-    AutoencoderKL,
-    DDIMScheduler,
-    LattePipeline,
-    LatteTransformer3DModel,
+from diffusers import AutoencoderKL, DDIMScheduler, LattePipeline, LatteTransformer3DModel
+from diffusers.pipelines.pyramid_attention_broadcast_utils import (
+    PyramidAttentionBroadcastConfig,
+    apply_pyramid_attention_broadcast,
 )
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
@@ -277,33 +276,24 @@ def test_pyramid_attention_broadcast(self):
         frames = pipe(**inputs).frames  # [B, F, C, H, W]
         original_image_slice = frames[0, -2:, -1, -3:, -3:]

-        pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800))
-        assert pipe.pyramid_attention_broadcast_enabled
+        config = PyramidAttentionBroadcastConfig(
+            spatial_attention_block_skip_range=2,
+            temporal_attention_block_skip_range=3,
+            spatial_attention_timestep_skip_range=(100, 800),
+            temporal_attention_timestep_skip_range=(100, 800),
+        )
+        apply_pyramid_attention_broadcast(pipe, config)

         inputs = self.get_dummy_inputs(device)
         inputs["num_inference_steps"] = 4
         frames = pipe(**inputs).frames
         image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:]

-        pipe.disable_pyramid_attention_broadcast()
-        assert not pipe.pyramid_attention_broadcast_enabled
-
-        inputs = self.get_dummy_inputs(device)
-        frames = pipe(**inputs).frames
-        image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:]
-
         # We need to use higher tolerance because we are using a random model. With a converged/trained
         # model, the tolerance can be lower.
         assert np.allclose(
-            original_image_slice, image_slice_pab_enabled, atol=0.25
+            original_image_slice, image_slice_pab_enabled, atol=0.2
         ), "PAB outputs should not differ much in specified timestep range."
-        print((image_slice_pab_disabled - image_slice_pab_enabled).abs().max())
-        assert np.allclose(
-            image_slice_pab_enabled, image_slice_pab_disabled, atol=0.25
-        ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range."
-        assert np.allclose(
-            original_image_slice, image_slice_pab_disabled, atol=0.25
-        ), "Original outputs should match when PAB is disabled."


 @slow
```
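
The updated test exercises the config-driven entry point rather than the removed `enable_pyramid_attention_broadcast` / `disable_pyramid_attention_broadcast` pipeline methods. Below is a minimal usage sketch of that API outside the test harness; the checkpoint id, prompt, and inference settings are illustrative assumptions, and only `PyramidAttentionBroadcastConfig` and `apply_pyramid_attention_broadcast` come from the diff above.

```python
import torch

from diffusers import LattePipeline
from diffusers.pipelines.pyramid_attention_broadcast_utils import (
    PyramidAttentionBroadcastConfig,
    apply_pyramid_attention_broadcast,
)

# Illustrative checkpoint id; any Latte checkpoint compatible with LattePipeline would do.
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to("cuda")

# Same configuration values the updated test passes for the spatial and temporal
# attention skip ranges and their timestep windows.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    temporal_attention_block_skip_range=3,
    spatial_attention_timestep_skip_range=(100, 800),
    temporal_attention_timestep_skip_range=(100, 800),
)
apply_pyramid_attention_broadcast(pipe, config)

# Illustrative prompt; within the configured timestep windows, PAB reuses cached
# attention outputs instead of recomputing them on every step.
frames = pipe("a cat playing with a ball of yarn", num_inference_steps=50).frames
```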
