
Commit 88d917d

update code example
1 parent 1b92b1d commit 88d917d

2 files changed: 7 additions & 4 deletions

src/diffusers/hooks/hooks.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import functools
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 
@@ -141,9 +141,9 @@ def new_forward(module, *args, **kwargs):
         self.hooks[name] = hook
         self._hook_order.append(name)
 
-    def get_hook(self, name: str) -> ModelHook:
+    def get_hook(self, name: str) -> Optional[ModelHook]:
         if name not in self.hooks.keys():
-            raise ValueError(f"Hook with name {name} not found.")
+            return None
         return self.hooks[name]
 
     def remove_hook(self, name: str) -> None:
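
With this change, `HookRegistry.get_hook` returns `None` for an unregistered name instead of raising `ValueError`, so callers need to check the return value. Below is a minimal sketch of that calling pattern; the `HookRegistry.check_if_exists_or_initialize` call and the hook name are assumptions used only for illustration, not part of this commit.

```python
import torch
from diffusers.hooks.hooks import HookRegistry

# Hedged sketch of the new get_hook contract: a missing name now yields None
# instead of raising ValueError. The registry-creation call and the hook name
# below are assumptions for illustration, not part of this commit.
module = torch.nn.Linear(4, 4)
registry = HookRegistry.check_if_exists_or_initialize(module)

hook = registry.get_hook("pyramid_attention_broadcast")  # nothing registered under this name
if hook is None:
    print("No hook registered under that name; nothing to do.")
else:
    print(f"Found hook: {hook}")
```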

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 4 additions & 1 deletion
@@ -169,12 +169,15 @@ def apply_pyramid_attention_broadcast(
     ```python
     >>> import torch
     >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
+    >>> from diffusers.utils import export_to_video
 
     >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
     >>> pipe.to("cuda")
 
     >>> config = PyramidAttentionBroadcastConfig(
-    ...     spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800)
+    ...     spatial_attention_block_skip_range=2,
+    ...     spatial_attention_timestep_skip_range=(100, 800),
+    ...     current_timestep_callback=lambda: pipe._current_timestep,
     ... )
     >>> apply_pyramid_attention_broadcast(pipe.transformer, config)
     ```
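
The newly imported `export_to_video` is not used in the hunk shown above, which suggests the rest of the docstring example saves the generated video. A hedged sketch of how such an example typically continues; the prompt, frame count, step count, and fps values are illustrative assumptions, not taken from this commit.

```python
# Hedged continuation of the docstring example above (values are assumptions).
>>> prompt = "A panda playing a guitar in a bamboo forest"
>>> video = pipe(prompt=prompt, num_frames=49, num_inference_steps=50).frames[0]
>>> export_to_video(video, "pab_output.mp4", fps=8)
```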
