
Commit 775bb8c

fix
1 parent 347dd17 commit 775bb8c

2 files changed (+31, -6 lines)


src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 12 additions & 4 deletions
```diff
@@ -389,14 +389,18 @@ def do_classifier_free_guidance(self):
     def num_timesteps(self):
         return self._num_timesteps
 
-    @property
-    def interrupt(self):
-        return self._interrupt
-
     @property
     def attention_kwargs(self):
         return self._attention_kwargs
 
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -533,6 +537,7 @@ def __call__(
         )
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
         self._interrupt = False
 
         # Default call parameters
@@ -610,6 +615,7 @@ def __call__(
             if self.interrupt:
                 continue
 
+            self._current_timestep = t
             latent_model_input = latents.to(transformer_dtype)
 
             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -661,6 +667,8 @@ def __call__(
             if XLA_AVAILABLE:
                 xm.mark_step()
 
+        self._current_timestep = None
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
             image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
```
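The net effect in `pipeline_cogview4.py`: `_current_timestep` is initialized to `None` before the loop, set to `t` at the top of each denoising step, and reset to `None` before decoding, so a step callback can observe it through the new `current_timestep` property. Below is a minimal usage sketch, assuming the standard diffusers callback convention `callback(pipeline, step, timestep, callback_kwargs)`; the base checkpoint name, prompt, and logging body are illustrative assumptions, not part of this commit:

```python
import torch
from diffusers import CogView4Pipeline

# Checkpoint name assumed by analogy with the control checkpoint in this commit.
pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

def log_timestep(pipeline, step, timestep, callback_kwargs):
    # The property added above exposes the timestep of the step in flight;
    # it reads None before the denoising loop starts and after it ends.
    print(f"step {step}: current_timestep = {pipeline.current_timestep}")
    return callback_kwargs

image = pipe(
    prompt="a photo of a corgi wearing a wizard hat",  # illustrative prompt
    callback_on_step_end=log_timestep,
).images[0]
```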

src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py

Lines changed: 19 additions & 2 deletions
```diff
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import inspect
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union, Any
 
 import numpy as np
 import torch
@@ -43,7 +43,7 @@
     Examples:
         ```python
        >>> import torch
-       >>> from diffusers import CogView4Pipeline
+       >>> from diffusers import CogView4ControlPipeline
 
        >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16)
        >>> control_image = load_image(
@@ -420,6 +420,14 @@ def do_classifier_free_guidance(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def interrupt(self):
         return self._interrupt
@@ -446,6 +454,7 @@ def __call__(
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         output_type: str = "pil",
         return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
@@ -559,6 +568,8 @@ def __call__(
             negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
         self._interrupt = False
 
         # Default call parameters
@@ -652,6 +663,8 @@ def __call__(
         for i, t in enumerate(timesteps):
             if self.interrupt:
                 continue
+
+            self._current_timestep = t
             latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype)
 
             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -664,6 +677,7 @@ def __call__(
                 original_size=original_size,
                 target_size=target_size,
                 crop_coords=crops_coords_top_left,
+                attention_kwargs=attention_kwargs,
                 return_dict=False,
             )[0]
 
@@ -676,6 +690,7 @@ def __call__(
                 original_size=original_size,
                 target_size=target_size,
                 crop_coords=crops_coords_top_left,
+                attention_kwargs=attention_kwargs,
                 return_dict=False,
             )[0]
 
@@ -700,6 +715,8 @@ def __call__(
             if XLA_AVAILABLE:
                 xm.mark_step()
 
+        self._current_timestep = None
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
             image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
```
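In the control pipeline the same `current_timestep` bookkeeping is added, and `attention_kwargs` is threaded from the `__call__` signature into both transformer invocations (the conditional and unconditional passes). Below is a sketch of exercising the new argument; the checkpoint name and the `control_image` argument follow the docstring hunk above, while the image URL, prompt, and the `{"scale": ...}` payload (a key diffusers attention processors commonly use to scale LoRA layers) are assumptions:

```python
import torch
from diffusers import CogView4ControlPipeline
from diffusers.utils import load_image

pipe = CogView4ControlPipeline.from_pretrained(
    "THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

control_image = load_image("https://example.com/canny_edges.png")  # placeholder URL

image = pipe(
    prompt="a red vintage car on a coastal road",  # illustrative prompt
    control_image=control_image,
    # Forwarded to self.transformer(...) on every denoising step per the
    # hunks above; which keys the attention processors honor is an assumption.
    attention_kwargs={"scale": 1.0},
).images[0]
```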
