Commit 446e6ea: formatting
1 parent 1e674b4

5 files changed: 61 additions, 33 deletions

scripts/convert_cosmos_to_diffusers.py

Lines changed: 25 additions & 10 deletions
@@ -159,9 +159,11 @@
 def remove_keys_(key: str, state_dict: Dict[str, Any]):
     state_dict.pop(key)
 
+
 def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
     state_dict[new_key] = state_dict.pop(old_key)
 
+
 def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
     block_index = int(key.split(".")[1].removeprefix("block"))
     new_key = key
@@ -459,9 +461,7 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
 }
 
 
-CONTROLNET_SPECIAL_KEYS_REMAP = {
-    **TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
-}
+CONTROLNET_SPECIAL_KEYS_REMAP = {**TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0}
 
 VAE_KEYS_RENAME_DICT = {
     "down.0": "down_blocks.0",
@@ -553,7 +553,9 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
 
 
 def convert_transformer(
-    transformer_type: str, state_dict: Optional[Dict[str, Any]] = None, weights_only: bool = True,
+    transformer_type: str,
+    state_dict: Optional[Dict[str, Any]] = None,
+    weights_only: bool = True,
 ):
     PREFIX_KEY = "net."
 
@@ -613,7 +615,12 @@ def convert_transformer(
     return transformer
 
 
-def convert_controlnet(transformer_type: str, control_state_dict: Dict[str, Any], base_state_dict: Dict[str, Any], weights_only: bool = True):
+def convert_controlnet(
+    transformer_type: str,
+    control_state_dict: Dict[str, Any],
+    base_state_dict: Dict[str, Any],
+    weights_only: bool = True,
+):
     """
     Convert controlnet weights.
 
@@ -657,7 +664,7 @@ def convert_controlnet(transformer_type: str, control_state_dict: Dict[str, Any]
     for key in list(base_state_dict.keys()):
         for transformer_prefix, controlnet_prefix in shared_module_mappings.items():
             if key.startswith(transformer_prefix):
-                controlnet_key = controlnet_prefix + key[len(transformer_prefix):]
+                controlnet_key = controlnet_prefix + key[len(transformer_prefix) :]
                 control_state_dict[controlnet_key] = base_state_dict[key].clone()
                 print(f"Copied shared weight: {key} -> {controlnet_key}", flush=True)
                 break
@@ -864,7 +871,9 @@ def get_args():
     raw_state_dict = None
     if args.transformer_ckpt_path is not None:
         weights_only = "Cosmos-1.0" in args.transformer_type
-        raw_state_dict = get_state_dict(torch.load(args.transformer_ckpt_path, map_location="cpu", weights_only=weights_only))
+        raw_state_dict = get_state_dict(
+            torch.load(args.transformer_ckpt_path, map_location="cpu", weights_only=weights_only)
+        )
 
     if raw_state_dict is not None:
         if "Transfer" in args.transformer_type:
@@ -879,14 +888,18 @@ def get_args():
             assert len(base_state_dict.keys() & control_state_dict.keys()) == 0
 
             # Convert transformer first to get the processed base state dict
-            transformer = convert_transformer(args.transformer_type, state_dict=base_state_dict, weights_only=weights_only)
+            transformer = convert_transformer(
+                args.transformer_type, state_dict=base_state_dict, weights_only=weights_only
+            )
             transformer = transformer.to(dtype=dtype)
 
             # Get converted transformer state dict to copy shared weights to controlnet
             converted_base_state_dict = transformer.state_dict()
 
             # Convert controlnet with both control-specific and shared weights from transformer
-            controlnet = convert_controlnet(args.transformer_type, control_state_dict, converted_base_state_dict, weights_only=weights_only)
+            controlnet = convert_controlnet(
+                args.transformer_type, control_state_dict, converted_base_state_dict, weights_only=weights_only
+            )
             controlnet = controlnet.to(dtype=dtype)
 
             if not args.save_pipeline:
@@ -895,7 +908,9 @@ def get_args():
                 pathlib.Path(args.output_path) / "controlnet", safe_serialization=True, max_shard_size="5GB"
             )
         else:
-            transformer = convert_transformer(args.transformer_type, state_dict=raw_state_dict, weights_only=weights_only)
+            transformer = convert_transformer(
+                args.transformer_type, state_dict=raw_state_dict, weights_only=weights_only
+            )
             transformer = transformer.to(dtype=dtype)
             if not args.save_pipeline:
                 transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
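
Note: the hunk at old line 657 above is the shared-weight copy in convert_controlnet. Below is a minimal standalone sketch of that prefix-remap pattern; the mapping entry and tensor shape are illustrative placeholders, not the real module names from the script.

    import torch

    # Hypothetical mapping entry for illustration; the real shared_module_mappings
    # in scripts/convert_cosmos_to_diffusers.py uses the actual module names.
    shared_module_mappings = {"patch_embed.": "patch_embed_base."}

    base_state_dict = {"patch_embed.proj.weight": torch.randn(4, 4)}
    control_state_dict = {}

    for key in list(base_state_dict.keys()):
        for transformer_prefix, controlnet_prefix in shared_module_mappings.items():
            if key.startswith(transformer_prefix):
                # Re-prefix the key and clone so the controlnet owns an independent copy
                controlnet_key = controlnet_prefix + key[len(transformer_prefix) :]
                control_state_dict[controlnet_key] = base_state_dict[key].clone()
                break

    assert "patch_embed_base.proj.weight" in control_state_dict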

src/diffusers/models/controlnets/controlnet_cosmos.py

Lines changed: 5 additions & 3 deletions
@@ -41,8 +41,8 @@ class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     ControlNet for Cosmos Transfer2.5.
 
     This model duplicates the shared embedding modules from the transformer (patch_embed, time_embed,
-    learnable_pos_embed, img_context_proj) to enable proper CPU offloading. The forward() method
-    computes everything internally from raw inputs.
+    learnable_pos_embed, img_context_proj) to enable proper CPU offloading. The forward() method computes everything
+    internally from raw inputs.
     """
 
     _supports_gradient_checkpointing = True
@@ -184,7 +184,9 @@ def forward(
             control_hidden_states = torch.cat(
                 [
                     control_hidden_states,
-                    torch.zeros((B, pad_C, T, H, W), dtype=control_hidden_states.dtype, device=control_hidden_states.device),
+                    torch.zeros(
+                        (B, pad_C, T, H, W), dtype=control_hidden_states.dtype, device=control_hidden_states.device
+                    ),
                 ],
                 dim=1,
            )
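
Note: a minimal sketch of the channel zero-padding reformatted in the second hunk above, assuming a [B, C, T, H, W] control tensor; the target channel count here is a toy value, not the model's actual configuration.

    import torch

    control_hidden_states = torch.randn(1, 3, 2, 8, 8)  # [B, C, T, H, W], toy sizes
    target_channels = 5  # hypothetical; the model derives this from its config
    B, C, T, H, W = control_hidden_states.shape
    pad_C = target_channels - C
    if pad_C > 0:
        # Append zero-filled channels so the tensor matches the expected channel width
        control_hidden_states = torch.cat(
            [
                control_hidden_states,
                torch.zeros((B, pad_C, T, H, W), dtype=control_hidden_states.dtype, device=control_hidden_states.device),
            ],
            dim=1,
        )
    assert control_hidden_states.shape[1] == target_channels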

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 16 additions & 11 deletions
@@ -225,13 +225,11 @@ def __call__(
 
         return hidden_states
 
+
 class CosmosAttnProcessor2_5(CosmosAttnProcessor2_0):
     def __init__(self):
         if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-            raise ImportError(
-                "CosmosAttnProcessor2_5 requires PyTorch 2.0. "
-                "Please upgrade PyTorch to 2.0 or newer."
-            )
+            raise ImportError("CosmosAttnProcessor2_5 requires PyTorch 2.0. Please upgrade PyTorch to 2.0 or newer.")
 
     def compute_attn_i2v(
         self,
@@ -302,6 +300,7 @@ def __call__(
         hidden_states = attn.to_out[1](hidden_states)
         return hidden_states
 
+
 class CosmosAttention(Attention):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -400,7 +399,9 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
-        encoder_hidden_states: Union[Optional[torch.Tensor], Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]],
+        encoder_hidden_states: Union[
+            Optional[torch.Tensor], Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]
+        ],
         embedded_timestep: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
@@ -581,11 +582,11 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         img_context_dim_in (`int`, *optional*):
             The dimension of the input image context feature vector, i.e. it is the D in [B, N, D].
         img_context_num_tokens (`int`):
-            The number of tokens in the image context feature vector, i.e. it is
-            the N in [B, N, D]. If `img_context_dim_in` is not provided, then this parameter is ignored.
-        img_context_dim_out (`int`):
-            The output dimension of the image context projection layer. If
+            The number of tokens in the image context feature vector, i.e. it is the N in [B, N, D]. If
             `img_context_dim_in` is not provided, then this parameter is ignored.
+        img_context_dim_out (`int`):
+            The output dimension of the image context projection layer. If `img_context_dim_in` is not provided, then
+            this parameter is ignored.
     """
 
     _supports_gradient_checkpointing = True
@@ -739,14 +740,18 @@ def forward(
             raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")
 
         # 5. Process encoder hidden states
-        text_context, img_context = encoder_hidden_states if isinstance(encoder_hidden_states, tuple) else (encoder_hidden_states, None)
+        text_context, img_context = (
+            encoder_hidden_states if isinstance(encoder_hidden_states, tuple) else (encoder_hidden_states, None)
+        )
         if self.config.use_crossattn_projection:
             text_context = self.crossattn_proj(text_context)
 
         if img_context is not None and self.config.img_context_dim_in:
             img_context = self.img_context_proj(img_context)
 
-        processed_encoder_hidden_states = (text_context, img_context) if isinstance(encoder_hidden_states, tuple) else text_context
+        processed_encoder_hidden_states = (
+            (text_context, img_context) if isinstance(encoder_hidden_states, tuple) else text_context
+        )
 
         # 6. Build controlnet block index map
         controlnet_block_index_map = {}
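
Note: the forward() hunks above normalize an encoder_hidden_states argument that may be either a plain text-context tensor or a (text, image) tuple. Below is a standalone sketch of that pattern with toy shapes; split_contexts is a hypothetical helper introduced only for illustration.

    from typing import Optional, Tuple, Union

    import torch

    def split_contexts(
        encoder_hidden_states: Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Accept either a bare text-context tensor or a (text, image) tuple
        text_context, img_context = (
            encoder_hidden_states if isinstance(encoder_hidden_states, tuple) else (encoder_hidden_states, None)
        )
        return text_context, img_context

    text_only = torch.randn(1, 77, 16)
    text, img = split_contexts(text_only)          # img is None
    text, img = split_contexts((text_only, None))  # explicit tuple form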

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py

Lines changed: 15 additions & 8 deletions
@@ -53,13 +53,15 @@ def __init__(self, *args, **kwargs):
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 def _maybe_pad_video(video: torch.Tensor, num_frames: int):
     n_pad_frames = num_frames - video.shape[2]
     if n_pad_frames > 0:
         last_frame = video[:, :, -1:, :, :]
         video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
     return video
 
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -73,6 +75,7 @@ def retrieve_latents(
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
 
+
 def transfer2_5_forward(
     transformer: CosmosTransformer3DModel,
     controlnet: CosmosControlNetModel,
@@ -87,9 +90,9 @@ def transfer2_5_forward(
     """
     Forward pass for Transfer2.5 pipeline.
 
-    This function calls both transformer and controlnet's forward() methods directly,
-    enabling proper CPU offloading. The controlnet computes its own embeddings internally
-    using duplicated modules (patch_embed_base, time_embed, etc.).
+    This function calls both transformer and controlnet's forward() methods directly, enabling proper CPU offloading.
+    The controlnet computes its own embeddings internally using duplicated modules (patch_embed_base, time_embed,
+    etc.).
 
     Args:
         transformer: The CosmosTransformer3DModel
@@ -130,6 +133,7 @@ def transfer2_5_forward(
     )[0]
     return noise_pred
 
+
 DEFAULT_NEGATIVE_PROMPT = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
 
 EXAMPLE_DOC_STRING = """
@@ -501,7 +505,9 @@ def _encode_controls(
             control_video = _maybe_pad_video(control_video, num_frames)
 
         control_video = control_video.to(device=device, dtype=self.vae.dtype)
-        control_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video]
+        control_latents = [
+            retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video
+        ]
         control_latents = torch.cat(control_latents, dim=0).to(dtype)
 
         latents_mean = self.latents_mean.to(device=device, dtype=dtype)
@@ -611,7 +617,8 @@ def __call__(
            height (`int`, defaults to `704`):
                The height in pixels of the generated image.
            width (`int`, *optional*):
-               The width in pixels of the generated image. If not provided, this will be determined based on the aspect ratio of the input and the provided height.
+               The width in pixels of the generated image. If not provided, this will be determined based on the
+               aspect ratio of the input and the provided height.
            num_frames (`int`, defaults to `93`):
                Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
            num_inference_steps (`int`, defaults to `35`):
@@ -684,7 +691,7 @@ def __call__(
            frame = controls[0]
 
            if frame is None:
-                width = int((height + 16) * (1280/720))
+                width = int((height + 16) * (1280 / 720))
            elif isinstance(frame, PIL.Image.Image):
                width = int((height + 16) * (frame.width / frame.height))
            else:
@@ -839,7 +846,7 @@ def __call__(
                    in_timestep=in_timestep,
                    encoder_hidden_states=encoder_hidden_states,
                    cond_mask=cond_mask,
-                    padding_mask=padding_mask
+                    padding_mask=padding_mask,
                )
                noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
 
@@ -853,7 +860,7 @@ def __call__(
                        in_timestep=in_timestep,
                        encoder_hidden_states=neg_encoder_hidden_states,  # NOTE: negative prompt
                        cond_mask=cond_mask,
-                        padding_mask=padding_mask
+                        padding_mask=padding_mask,
                    )
                    # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
                    noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
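
Note: a usage sketch for _maybe_pad_video above. A clip shorter than num_frames is extended by repeating its last frame along the time axis (dim=2 of a [B, C, T, H, W] tensor); the function body is copied verbatim from the diff, the toy shapes are illustrative.

    import torch

    def _maybe_pad_video(video: torch.Tensor, num_frames: int):
        n_pad_frames = num_frames - video.shape[2]
        if n_pad_frames > 0:
            last_frame = video[:, :, -1:, :, :]
            video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
        return video

    video = torch.randn(1, 3, 5, 16, 16)  # only 5 frames
    padded = _maybe_pad_video(video, num_frames=9)
    assert padded.shape[2] == 9
    # The trailing frames are copies of the original last frame (index 4)
    assert torch.equal(padded[:, :, 4], padded[:, :, 8])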

tests/pipelines/cosmos/test_cosmos2_5_transfer.py

Lines changed: 0 additions & 1 deletion
@@ -384,4 +384,3 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
     )
     def test_encode_prompt_works_in_isolation(self):
         pass
-