Commit 5c6dd86

wip

1 parent 65eab52 commit 5c6dd86

7 files changed: +1010 -49 lines changed

scripts/convert_cosmos_to_diffusers.py

Lines changed: 6 additions & 0 deletions

@@ -594,14 +594,20 @@ def convert_controlnet(transformer_type: str, state_dict: Dict[str, Any], weight
         raise AssertionError(f"{transformer_type} does not define a ControlNet config")

     PREFIX_KEY = "net."
+    old2new = {}
+    new2old = {}
     for key in list(state_dict.keys()):
         new_key = key[:]
         if new_key.startswith(PREFIX_KEY):
             new_key = new_key.removeprefix(PREFIX_KEY)
         for replace_key, rename_key in CONTROLNET_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
+        old2new[key] = new_key
+        new2old[new_key] = key
         update_state_dict_(state_dict, key, new_key)

+    breakpoint()
+
     for key in list(state_dict.keys()):
         for special_key, handler_fn_inplace in CONTROLNET_SPECIAL_KEYS_REMAP.items():
             if special_key not in key:
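
A standalone sketch of the rename bookkeeping this hunk introduces: the `old2new`/`new2old` dicts record each key translation so the conversion can be inspected (e.g. at the `breakpoint()` above). The rename table below is a stand-in for illustration, not the script's real `CONTROLNET_KEYS_RENAME_DICT`.

from typing import Any, Dict, Tuple

PREFIX_KEY = "net."
RENAME_DICT = {"blocks.": "control_blocks."}  # stand-in rename table

def rename_keys(state_dict: Dict[str, Any]) -> Tuple[Dict[str, str], Dict[str, str]]:
    old2new: Dict[str, str] = {}
    new2old: Dict[str, str] = {}
    for key in list(state_dict.keys()):
        new_key = key.removeprefix(PREFIX_KEY)  # drop the "net." wrapper, if present
        for old, new in RENAME_DICT.items():
            new_key = new_key.replace(old, new)
        old2new[key] = new_key
        new2old[new_key] = key  # reverse map, handy when auditing a conversion
        state_dict[new_key] = state_dict.pop(key)  # rename in place
    return old2new, new2old

mapping, reverse = rename_keys({"net.blocks.0.weight": 0.0})
# mapping == {"net.blocks.0.weight": "control_blocks.0.weight"}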

src/diffusers/models/controlnets/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@

 if is_torch_available():
     from .controlnet import ControlNetModel, ControlNetOutput
-    from .controlnet_cosmos import CosmosControlNetModel, CosmosControlNetOutput
+    from .controlnet_cosmos import CosmosControlNetModel
     from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
     from .controlnet_hunyuan import (
         HunyuanControlNetOutput,

src/diffusers/models/controlnets/controlnet_cosmos.py

Lines changed: 5 additions & 12 deletions

@@ -8,18 +8,15 @@
 from ...loaders import FromOriginalModelMixin
 from ...utils import BaseOutput, logging
 from ..modeling_utils import ModelMixin
-from ..transformers.transformer_cosmos import CosmosPatchEmbed
+from ..transformers.transformer_cosmos import (
+    CosmosPatchEmbed,
+)
 from .controlnet import zero_module


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-@dataclass
-class CosmosControlNetOutput(BaseOutput):
-    block_controlnet_hidden_states: Tuple[torch.Tensor]
-
-
 class CosmosControlNetBlock(nn.Module):
     def __init__(self, hidden_size: int):
         super().__init__()
@@ -82,16 +79,12 @@ def forward(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         conditioning_scale: Union[float, List[float]] = 1.0,
         return_dict: bool = True,
-    ) -> Union[Tuple[Tuple[torch.Tensor, ...]], CosmosControlNetOutput]:
+    ) -> List[torch.Tensor]:
         del hidden_states, timestep, encoder_hidden_states  # not used in this minimal control path

         control_hidden_states = self.patch_embed(controlnet_cond)
         control_hidden_states = control_hidden_states.flatten(1, 3)

         scales = self._expand_conditioning_scale(conditioning_scale)
         control_residuals = tuple(block(control_hidden_states) * scale for block, scale in zip(self.control_blocks, scales))
-
-        if not return_dict:
-            return (control_residuals,)
-
-        return CosmosControlNetOutput(block_controlnet_hidden_states=control_residuals)
+        return control_residuals
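
With `CosmosControlNetOutput` removed, `forward` now returns the tuple of scaled per-block residuals directly. A minimal sketch, under assumed shapes, of the scale broadcasting and the zero initialization that `zero_module` provides (so an untrained ControlNet contributes nothing); `ControlBlockSketch` and `expand_conditioning_scale` are illustrative stand-ins:

import torch
import torch.nn as nn

class ControlBlockSketch(nn.Module):
    # Hypothetical stand-in for CosmosControlNetBlock: a zero-initialized
    # projection, so every residual starts out as exactly zero.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        nn.init.zeros_(self.proj.weight)  # zero_module-style init
        nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

def expand_conditioning_scale(scale, n_blocks: int):
    # A single float broadcasts to all control blocks; a list must match in length.
    scales = [scale] * n_blocks if isinstance(scale, (int, float)) else list(scale)
    assert len(scales) == n_blocks, "one scale per control block"
    return scales

blocks = nn.ModuleList(ControlBlockSketch(16) for _ in range(3))
control_hidden_states = torch.randn(2, 10, 16)  # (batch, tokens, hidden)
scales = expand_conditioning_scale(1.0, len(blocks))
control_residuals = tuple(b(control_hidden_states) * s for b, s in zip(blocks, scales))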

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 1 addition & 2 deletions

@@ -740,9 +740,8 @@ def forward(
         encoder_hidden_states = self.crossattn_proj(encoder_hidden_states)

         controlnet_block_index_map = {}
-        if block_controlnet_hidden_states:
+        if block_controlnet_hidden_states is not None:
             n_blocks = len(self.transformer_blocks)
-            # TODO: don't use a dict?
             controlnet_block_index_map = {
                 block_idx: block_controlnet_hidden_states[idx]
                 for idx, block_idx in list(enumerate(range(0, n_blocks, self.config.controlnet_block_every_n)))[0:self.config.n_controlnet_blocks]
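
The comprehension above maps the idx-th control residual onto every `controlnet_block_every_n`-th transformer block, capped at `n_controlnet_blocks` entries. A runnable distillation with placeholder values (the config names mirror the diff; the counts are made up):

# Placeholder values; in the model these come from the config.
n_transformer_blocks = 28
controlnet_block_every_n = 7
n_controlnet_blocks = 4
residuals = [f"residual_{i}" for i in range(n_controlnet_blocks)]  # placeholders

index_map = {
    block_idx: residuals[idx]
    for idx, block_idx in list(
        enumerate(range(0, n_transformer_blocks, controlnet_block_every_n))
    )[:n_controlnet_blocks]
}
# index_map == {0: 'residual_0', 7: 'residual_1', 14: 'residual_2', 21: 'residual_3'}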

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py

Lines changed: 0 additions & 3 deletions

@@ -433,9 +433,6 @@ def prepare_latents(
         else:
             if video is None:
                 raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.")
-            needs_preprocessing = not (isinstance(video, torch.Tensor) and video.ndim == 5 and video.shape[1] == 3)
-            if needs_preprocessing:
-                video = self.video_processor.preprocess_video(video, height, width)
             video = video.to(device=device, dtype=self.vae.dtype)
             if isinstance(generator, list):
                 cond_latents = [
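
With this auto-preprocessing deleted, `prepare_latents` appears to assume the caller already supplies `video` as a 5-D channels-first tensor. A self-contained sketch of doing that explicitly with the same `VideoProcessor` API (the scale factor and frame tensors below are assumed placeholders):

import torch
from diffusers.video_processor import VideoProcessor

# Mirrors what the deleted lines did implicitly inside prepare_latents.
video_processor = VideoProcessor(vae_scale_factor=8)   # assumed value
raw = torch.rand(9, 3, 480, 832)                       # placeholder frames (F, C, H, W), values in [0, 1]
video = video_processor.preprocess_video(raw, height=480, width=832)
print(video.shape)  # -> (1, 3, 9, 480, 832): ready to pass as the pipeline's `video` input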

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py

Lines changed: 36 additions & 31 deletions

@@ -199,7 +199,7 @@ def __init__(
         transformer: CosmosTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: UniPCMultistepScheduler,
-        controlnet: Optional[CosmosControlNetModel] = None,
+        controlnet: CosmosControlNetModel,
         safety_checker: CosmosSafetyChecker = None,
     ):
         super().__init__()
@@ -474,23 +474,25 @@ def prepare_latents(
             cond_indicator,
         )

-    def _encode_controlnet_image(
+    def _encode_controls(
         self,
-        control_image: Optional[torch.Tensor],
+        controls: Optional[torch.Tensor],
         height: int,
         width: int,
         num_frames: int,
         dtype: torch.dtype,
         device: torch.device,
     ) -> Optional[torch.Tensor]:
-        if control_image is None:
+        if controls is None:
             return None

-        control_video = self.video_processor.preprocess_video(control_image, height, width)
-        if control_video.shape[2] < num_frames:
-            n_pad_frames = num_frames - control_video.shape[2]
-            last_frame = control_video[:, :, -1:, :, :]
-            control_video = torch.cat((control_video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
+        # TODO: handle image differently?
+        control_video = self.video_processor.preprocess_video(controls, height, width)
+        # TODO: is this needed?
+        # if control_video.shape[2] < num_frames:
+        #     n_pad_frames = num_frames - control_video.shape[2]
+        #     last_frame = control_video[:, :, -1:, :, :]
+        #     control_video = torch.cat((control_video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)

         control_video = control_video.to(device=device, dtype=self.vae.dtype)
         control_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0))) for vid in control_video]
@@ -568,8 +570,8 @@ def __call__(
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
-        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
-        controlnet_conditioning_image: Optional[PipelineImageInput] = None,
+        controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None,
+        controls_conditioning_scale: Union[float, List[float]] = 1.0,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
@@ -623,10 +625,10 @@ def __call__(
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
-            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
-                The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
-            controlnet_conditioning_image (`PipelineImageInput`, *optional*):
+            controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
                 Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
+            controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
+                The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
@@ -765,19 +767,20 @@ def __call__(
         cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
         cond_mask = cond_mask.to(transformer_dtype)

-        controlnet_latents = None
-        if self.controlnet is not None and controlnet_conditioning_image is not None:
-            controlnet_latents = self._encode_controlnet_image(
-                control_image=controlnet_conditioning_image,
+        controls_latents = None
+        if controls is not None:
+            controls_latents = self._encode_controls(
+                controls,
                 height=height,
                 width=width,
                 num_frames=num_frames,
                 dtype=torch.float32,
                 device=device,
             )
-            if controlnet_latents.shape[0] != latents.shape[0]:
-                repeat_count = latents.shape[0] // controlnet_latents.shape[0]
-                controlnet_latents = controlnet_latents.repeat_interleave(repeat_count, dim=0)
+            # TODO: checkme?
+            # if controls_latents.shape[0] != latents.shape[0]:
+            #     repeat_count = latents.shape[0] // controls_latents.shape[0]
+            #     controls_latents = controls_latents.repeat_interleave(repeat_count, dim=0)

         padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)

@@ -805,24 +808,24 @@ def __call__(
             in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
             in_latents = in_latents.to(transformer_dtype)
             in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
-            control_block_samples = None
-            if self.controlnet is not None and controlnet_latents is not None:
-                control_block_samples = self.controlnet(
+            control_blocks = None
+            if controls is not None:
+                control_blocks = self.controlnet(
                     hidden_states=in_latents,
-                    controlnet_cond=controlnet_latents.to(dtype=transformer_dtype),
+                    controlnet_cond=controls_latents.to(dtype=transformer_dtype),
                     timestep=in_timestep,
                     encoder_hidden_states=prompt_embeds,
-                    conditioning_scale=controlnet_conditioning_scale,
+                    conditioning_scale=controls_conditioning_scale,
                     return_dict=True,
-                ).block_controlnet_hidden_states
-                control_block_samples = tuple(residual.to(dtype=transformer_dtype) for residual in control_block_samples)
+                )
+
             noise_pred = self.transformer(
                 hidden_states=in_latents,
                 condition_mask=cond_mask,
                 timestep=in_timestep,
                 encoder_hidden_states=prompt_embeds,
                 padding_mask=padding_mask,
-                block_controlnet_hidden_states=control_block_samples,
+                block_controlnet_hidden_states=control_blocks,
                 return_dict=False,
             )[0]
             # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
@@ -835,7 +838,7 @@ def __call__(
                 timestep=in_timestep,
                 encoder_hidden_states=negative_prompt_embeds,
                 padding_mask=padding_mask,
-                block_controlnet_hidden_states=control_block_samples,
+                block_controlnet_hidden_states=control_blocks,
                 return_dict=False,
             )[0]
             # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
@@ -868,7 +871,8 @@ def __call__(
         latents_std = self.latents_std.to(latents.device, latents.dtype)
         latents = latents * latents_std + latents_mean
         video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
-        video = self._match_num_frames(video, num_frames)
+        # TODO: checkme
+        # video = self._match_num_frames(video, num_frames)

         assert self.safety_checker is not None
         self.safety_checker.to(device)
@@ -892,6 +896,7 @@ def __call__(

         return CosmosPipelineOutput(frames=video)

+    # TODO: checkme - this seems like a hack
     def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
         if target_num_frames <= 0 or video.shape[2] == target_num_frames:
             return video
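
An end-to-end usage sketch of the renamed call signature. The pipeline class name and checkpoint id below are illustrative assumptions; only the `controls`/`controls_conditioning_scale` argument names and the `.frames` output come from this diff.

import torch
from diffusers.utils import load_video

# Hypothetical export; adjust to the actual class once released.
from diffusers import Cosmos2_5TransferPipeline

pipe = Cosmos2_5TransferPipeline.from_pretrained(
    "org/cosmos-2.5-transfer",  # placeholder checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

control_video = load_video("control.mp4")  # control input: image or video

output = pipe(
    prompt="a robot arm assembling a toy",
    controls=control_video,               # was: controlnet_conditioning_image
    controls_conditioning_scale=1.0,      # was: controlnet_conditioning_scale
)
frames = output.frames  # CosmosPipelineOutput.frames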
