
Commit f6f737c

fix chunk_idxs + subtle latents bug
1 parent e5e6a36 commit f6f737c

Showing 2 changed files with 25 additions and 15 deletions.

src/diffusers/models/controlnets/controlnet_cosmos.py

Lines changed: 3 additions & 1 deletion
@@ -194,7 +194,9 @@ def forward(
         if condition_mask is not None:
             control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
         else:
-            control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
+            control_hidden_states = torch.cat(
+                [control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
+            )
 
         padding_mask_resized = transforms.functional.resize(
             padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py

Lines changed: 22 additions & 14 deletions
@@ -521,7 +521,7 @@ def __call__(
         num_frames_per_chunk: int = 93,
         num_inference_steps: int = 36,
         guidance_scale: float = 3.0,
-        num_videos_per_prompt: Optional[int] = 1,
+        num_videos_per_prompt: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
@@ -721,15 +721,23 @@ def __call__(
         vae_dtype = self.vae.dtype
         transformer_dtype = self.transformer.dtype
 
-        img_context = torch.zeros(
-            batch_size,
-            self.transformer.config.img_context_num_tokens,
-            self.transformer.config.img_context_dim_in,
-            device=prompt_embeds.device,
-            dtype=transformer_dtype,
-        )
-        encoder_hidden_states = (prompt_embeds, img_context)
-        neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
+        if getattr(self.transformer.config, "img_context_dim_in", None):
+            img_context = torch.zeros(
+                batch_size,
+                self.transformer.config.img_context_num_tokens,
+                self.transformer.config.img_context_dim_in,
+                device=prompt_embeds.device,
+                dtype=transformer_dtype,
+            )
+
+            if num_videos_per_prompt > 1:
+                img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0)
+
+            encoder_hidden_states = (prompt_embeds, img_context)
+            neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
+        else:
+            encoder_hidden_states = prompt_embeds
+            neg_encoder_hidden_states = negative_prompt_embeds
 
         if controls is not None and self.controlnet is None:
             logger.warning("`controls` was provided but `controlnet` is None; ignoring `controls`.")
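
The image-context placeholder is now built only when the transformer config actually defines `img_context_dim_in`, and it is tiled along the batch dimension so it lines up with prompt embeddings that were already expanded for `num_videos_per_prompt`. A minimal standalone sketch of that tiling, using assumed example shapes rather than the real config values:

```python
# Standalone sketch (hypothetical shapes, not taken from the pipeline config):
# the zero image-context placeholder must match the expanded batch size
# batch_size * num_videos_per_prompt used by the prompt embeddings.
import torch

batch_size, num_videos_per_prompt = 1, 2
img_context_num_tokens, img_context_dim_in = 256, 1024  # assumed example values

img_context = torch.zeros(batch_size, img_context_num_tokens, img_context_dim_in)
if num_videos_per_prompt > 1:
    img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0)

print(img_context.shape)  # torch.Size([2, 256, 1024])
```
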
@@ -798,7 +806,7 @@ def __call__(
         chunk_stride = num_frames_per_chunk - num_conditional_frames
         chunk_idxs = [
             (start_idx, min(start_idx + num_frames_per_chunk, num_frames_out))
-            for start_idx in range(0, num_frames_out, chunk_stride)
+            for start_idx in range(0, num_frames_out - num_conditional_frames, chunk_stride)
         ]
 
         video_chunks = []
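
This is the `chunk_idxs` part of the fix: capping the range at `num_frames_out - num_conditional_frames` drops a degenerate trailing chunk whose frames are all already covered by the previous chunk's conditioning overlap. A small worked example with hypothetical frame counts:

```python
# Worked example with assumed values (93 frames per chunk, 5 conditional frames,
# 181 output frames); only the range bound differs between the two variants.
num_frames_per_chunk = 93
num_conditional_frames = 5
num_frames_out = 181

chunk_stride = num_frames_per_chunk - num_conditional_frames  # 88

before = [
    (s, min(s + num_frames_per_chunk, num_frames_out))
    for s in range(0, num_frames_out, chunk_stride)
]
after = [
    (s, min(s + num_frames_per_chunk, num_frames_out))
    for s in range(0, num_frames_out - num_conditional_frames, chunk_stride)
]

print(before)  # [(0, 93), (88, 181), (176, 181)] -> 5-frame tail chunk, all overlap
print(after)   # [(0, 93), (88, 181)]             -> tail already covered
```
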
@@ -810,6 +818,7 @@ def decode_latents(latents):
             video = self.vae.decode(latents.to(dtype=self.vae.dtype, device=device), return_dict=False)[0]
             return video
 
+        latents_arg = latents
         initial_num_cond_latent_frames = 0 if video is None or controls is not None else num_cond_latent_frames
         latent_chunks = []
         num_chunks = len(chunk_idxs)
@@ -844,7 +853,7 @@ def decode_latents(latents):
                 num_cond_latent_frames=initial_num_cond_latent_frames
                 if chunk_idx == 0
                 else num_cond_latent_frames,
-                # latents=latents,
+                latents=latents_arg,
            )
            cond_mask = cond_mask.to(transformer_dtype)
            cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
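
Together with the `latents_arg = latents` snapshot above, this is presumably the "subtle latents bug" from the commit title: the caller-supplied `latents` argument is captured once before the chunk loop, so each per-chunk call receives the original argument even if the local `latents` name is reused as loop state. A hedged, simplified sketch of the shadowing pattern being avoided (the `run_chunk` helper is hypothetical):

```python
# Simplified, hypothetical sketch of the name-shadowing hazard: `run_chunk`
# stands in for the per-chunk preparation call in the pipeline.
def run_chunk(latents=None):
    # None means "sample fresh noise for this chunk"
    return "fresh noise" if latents is None else latents

latents = None            # the caller did not pass latents
latents_arg = latents     # snapshot the argument before the loop

for chunk_idx in range(3):
    chunk_latents = run_chunk(latents=latents_arg)  # always the caller's value
    latents = chunk_latents  # local `latents` is reused as per-chunk state
```
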
@@ -866,7 +875,6 @@ def decode_latents(latents):
             latents_std = self.latents_std.to(device=device, dtype=transformer_dtype)
             controls_latents = (controls_latents - latents_mean) / latents_std
 
-            # breakpoint()
             # Denoising loop
             self.scheduler.set_timesteps(num_inference_steps, device=device)
             timesteps = self.scheduler.timesteps
@@ -980,7 +988,7 @@ def decode_latents(latents):
             video = (video * 255).astype(np.uint8)
             video_batch = []
             for vid in video:
-                # vid = self.safety_checker.check_video_safety(vid)
+                vid = self.safety_checker.check_video_safety(vid)
                 if vid is None:
                     video_batch.append(np.zeros_like(video[0]))
                 else:
