183 changes: 183 additions & 0 deletions scripts/convert_cogview3_to_diffusers.py
@@ -0,0 +1,183 @@
"""
Convert a CogView3 checkpoint to the Diffusers format.

This script converts a CogView3 checkpoint to the Diffusers format; the converted model can
then be loaded and used with the Diffusers library.

Example usage:
python scripts/convert_cogview3_to_diffusers.py \
--original_state_dict_repo_id "THUDM/cogview3" \
--filename "cogview3.pt" \
--transformer \
--output_path "./cogview3_diffusers" \
--dtype "bf16"

Alternatively, if you have a local checkpoint:
python scripts/convert_cogview3_to_diffusers.py \
--checkpoint_path '/raid/.cache/huggingface/models--ZP2HF--CogView3-SAT/snapshots/ca86ce9ba94f9a7f2dd109e7a59e4c8ad04121be/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
--transformer \
--output_path "/raid/yiyi/cogview3_diffusers" \
--dtype "bf16"

Arguments:
--original_state_dict_repo_id: The Hugging Face repo ID containing the original checkpoint.
--filename: The filename of the checkpoint in the repo (default: "flux.safetensors").
--checkpoint_path: Path to a local checkpoint file (alternative to repo_id and filename).
--transformer: Flag to convert the transformer model.
--output_path: The path to save the converted model.
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32").

Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
"""

import argparse
from contextlib import nullcontext

import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from diffusers import CogView3PlusTransformer2DModel
from diffusers.utils.import_utils import is_accelerate_available


CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--transformer", action="store_true")
parser.add_argument("--output_path", type=str)
parser.add_argument("--dtype", type=str, default="bf16")

args = parser.parse_args()


def load_original_checkpoint(args):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

original_state_dict = torch.load(ckpt_path, map_location="cpu")
return original_state_dict


# This is specific to `AdaLayerNormContinuous`:
# the diffusers implementation splits the linear projection into (scale, shift), while CogView3 splits it into (shift, scale)
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=dim)
new_weight = torch.cat([scale, shift], dim=dim)
return new_weight
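
To make the reordering concrete, here is a toy check (illustrative values only, not part of the conversion), using the function defined above: a fused tensor laid out as (shift, scale) comes back as (scale, shift).

import torch

d = 3
fused = torch.cat([torch.zeros(d), torch.ones(d)])  # (shift, scale) layout
swapped = swap_scale_shift(fused, dim=0)
assert torch.equal(swapped, torch.cat([torch.ones(d), torch.zeros(d)]))  # now (scale, shift)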


def convert_cogview3_transformer_checkpoint_to_diffusers(original_state_dict):
new_state_dict = {}

# Convert pos_embed
new_state_dict["pos_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
new_state_dict["pos_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
new_state_dict["pos_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
new_state_dict["pos_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")

# Convert time_text_embed
new_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"time_embed.0.weight"
)
new_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_embed.0.bias")
new_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"time_embed.2.weight"
)
new_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_embed.2.bias")
new_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop("label_emb.0.0.weight")
new_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop("label_emb.0.0.bias")
new_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop("label_emb.0.2.weight")
new_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop("label_emb.0.2.bias")

# Convert transformer blocks
for i in range(30):
block_prefix = f"transformer_blocks.{i}."
old_prefix = f"transformer.layers.{i}."
adaln_prefix = f"mixins.adaln.adaln_modules.{i}."

new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")

qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
q, k, v = qkv_weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)

new_state_dict[block_prefix + "attn.to_q.weight"] = q
new_state_dict[block_prefix + "attn.to_q.bias"] = q_bias
new_state_dict[block_prefix + "attn.to_k.weight"] = k
new_state_dict[block_prefix + "attn.to_k.bias"] = k_bias
new_state_dict[block_prefix + "attn.to_v.weight"] = v
new_state_dict[block_prefix + "attn.to_v.bias"] = v_bias

new_state_dict[block_prefix + "attn.to_out.0.weight"] = original_state_dict.pop(
old_prefix + "attention.dense.weight"
)
new_state_dict[block_prefix + "attn.to_out.0.bias"] = original_state_dict.pop(
old_prefix + "attention.dense.bias"
)

new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
old_prefix + "mlp.dense_h_to_4h.weight"
)
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
old_prefix + "mlp.dense_h_to_4h.bias"
)
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
old_prefix + "mlp.dense_4h_to_h.weight"
)
new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")

# Convert final norm and projection
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
)
new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")

return new_state_dict
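
The fused `query_key_value` tensors popped in the loop above are laid out as [Q; K; V] along dim 0, which is why a plain 3-way chunk recovers the per-projection weights. A toy illustration with a hypothetical hidden size:

import torch

hidden = 8  # hypothetical hidden size
qkv_weight = torch.randn(3 * hidden, hidden)  # fused [Q; K; V] projection
q, k, v = qkv_weight.chunk(3, dim=0)
assert q.shape == k.shape == v.shape == (hidden, hidden)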


def main(args):
original_ckpt = load_original_checkpoint(args)
original_ckpt = original_ckpt["module"]
original_ckpt = {k.replace("model.diffusion_model.", ""): v for k, v in original_ckpt.items()}

original_dtype = next(iter(original_ckpt.values())).dtype
dtype = None
if args.dtype is None:
dtype = original_dtype
elif args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")

if args.transformer:
converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(original_ckpt)
transformer = CogView3PlusTransformer2DModel()
transformer.load_state_dict(converted_transformer_state_dict, strict=True)

print(f"Saving CogView3 Transformer in Diffusers format in {args.output_path}/transformer")
transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")

if len(original_ckpt) > 0:
print(f"Warning: {len(original_ckpt)} keys were not converted and will be saved as is: {original_ckpt.keys()}")


if __name__ == "__main__":
main(args)
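
Once the script has run, the converted weights load like any other Diffusers model. A minimal sketch, assuming the example `--output_path` and the `bf16` default from the docstring:

import torch
from diffusers import CogView3PlusTransformer2DModel

transformer = CogView3PlusTransformer2DModel.from_pretrained(
    "./cogview3_diffusers/transformer", torch_dtype=torch.bfloat16
)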
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -84,6 +84,7 @@
"AutoencoderOobleck",
"AutoencoderTiny",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSAdapter",
@@ -258,6 +259,7 @@
"CogVideoXImageToVideoPipeline",
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CogView3PlusPipeline",
"CycleDiffusionPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
@@ -558,6 +560,7 @@
AutoencoderOobleck,
AutoencoderTiny,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSAdapter,
@@ -710,6 +713,7 @@
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CogView3PlusPipeline,
CycleDiffusionPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
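
With these entries in place, both new classes resolve from the top-level namespace (diffusers resolves them lazily through its `_import_structure` map, so the heavy modules are only imported on first access):

from diffusers import CogView3PlusPipeline, CogView3PlusTransformer2DModel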
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -54,6 +54,7 @@
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -98,6 +99,7 @@
from .transformers import (
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
FluxTransformer2DModel,
5 changes: 3 additions & 2 deletions src/diffusers/models/attention_processor.py
@@ -122,6 +122,7 @@ def __init__(
out_dim: int = None,
context_pre_only=None,
pre_only=False,
layernorm_elementwise_affine: bool = True,
):
super().__init__()

@@ -179,8 +180,8 @@ def __init__(
self.norm_q = None
self.norm_k = None
elif qk_norm == "layer_norm":
- self.norm_q = nn.LayerNorm(dim_head, eps=eps)
- self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=layernorm_elementwise_affine)
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=layernorm_elementwise_affine)
elif qk_norm == "fp32_layer_norm":
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
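
A minimal sketch of the new flag, using the corrected keyword spelling `layernorm_elementwise_affine` from the hunk above; passing False gives Q/K LayerNorms without learnable affine parameters:

import torch
from diffusers.models.attention_processor import Attention

attn = Attention(
    query_dim=64,
    heads=4,
    dim_head=16,
    qk_norm="layer_norm",
    layernorm_elementwise_affine=False,  # new flag added by this PR
)
out = attn(torch.randn(1, 32, 64))  # (batch, seq_len, query_dim)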
79 changes: 76 additions & 3 deletions src/diffusers/models/embeddings.py
@@ -714,6 +714,58 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
return freqs_cos, freqs_sin


class CogView3PlusPatchEmbed(nn.Module):
def __init__(
self,
in_channels: int = 16,
hidden_size: int = 2560,
patch_size: int = 2,
text_hidden_size: int = 4096,
pos_embed_max_size: int = 128,
):
super().__init__()
self.in_channels = in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.text_hidden_size = text_hidden_size
self.pos_embed_max_size = pos_embed_max_size
# Linear projection for image patches
self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)

# Linear projection for text embeddings
self.text_proj = nn.Linear(text_hidden_size, hidden_size)

pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)

def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
if height % self.patch_size != 0 or width % self.patch_size != 0:
raise ValueError("Height and width must be divisible by patch size")
height = height // self.patch_size
width = width // self.patch_size
hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)

# Project the patches
hidden_states = self.proj(hidden_states)
encoder_hidden_states = self.text_proj(encoder_hidden_states)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

# Calculate text_length
text_length = encoder_hidden_states.shape[1]

image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
text_pos_embed = torch.zeros(
(text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
)
pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]

return (hidden_states + pos_embed).to(hidden_states.dtype)
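
Shape-wise, the module patchifies the latent, projects image patches and text embeddings to `hidden_size`, and prepends the text tokens (which receive zero positional embeddings). A sketch with small, hypothetical dimensions, assuming this branch is installed:

import torch
from diffusers.models.embeddings import CogView3PlusPatchEmbed

embed = CogView3PlusPatchEmbed(
    in_channels=4, hidden_size=64, patch_size=2, text_hidden_size=32, pos_embed_max_size=16
)
latent = torch.randn(1, 4, 8, 8)  # patchified into (8 // 2) * (8 // 2) = 16 image tokens
text = torch.randn(1, 5, 32)  # 5 text tokens
out = embed(latent, text)
assert out.shape == (1, 5 + 16, 64)  # text tokens first, then image tokens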


class TimestepEmbedding(nn.Module):
def __init__(
self,
@@ -1018,6 +1070,27 @@ def forward(self, image_embeds: torch.Tensor):
return self.norm(x)


class CogView3CombineTimestepLabelEmbedding(nn.Module):
def __init__(self, time_embed_dim, label_embed_dim, in_channels=2560):
super().__init__()

self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=in_channels, time_embed_dim=time_embed_dim)
self.label_embedder = nn.Sequential(
nn.Linear(label_embed_dim, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)

def forward(self, timestep, class_labels, hidden_dtype=None):
t_proj = self.time_proj(timestep)
t_emb = self.timestep_embedder(t_proj.to(dtype=hidden_dtype))
label_emb = self.label_embedder(class_labels)
emb = t_emb + label_emb

return emb
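
The sinusoidal timestep embedding and the projected label embedding are simply summed. A usage sketch with hypothetical dimensions, assuming this branch is installed:

import torch
from diffusers.models.embeddings import CogView3CombineTimestepLabelEmbedding

emb = CogView3CombineTimestepLabelEmbedding(time_embed_dim=32, label_embed_dim=16, in_channels=64)
out = emb(torch.tensor([10]), torch.randn(1, 16), hidden_dtype=torch.float32)
assert out.shape == (1, 32)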


class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
@@ -1038,11 +1111,11 @@ def forward(self, timestep, class_labels, hidden_dtype=None):


class CombinedTimestepTextProjEmbeddings(nn.Module):
- def __init__(self, embedding_dim, pooled_projection_dim):
+ def __init__(self, embedding_dim, pooled_projection_dim, timesteps_dim=256):
super().__init__()

- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

def forward(self, timestep, pooled_projection):
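
The new `timesteps_dim` argument generalizes the previously hard-coded 256-channel sinusoidal projection; existing callers keep the old behavior through the default. A sketch with hypothetical sizes:

import torch
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings

embedder = CombinedTimestepTextProjEmbeddings(embedding_dim=64, pooled_projection_dim=32, timesteps_dim=512)
conditioning = embedder(torch.tensor([999]), torch.randn(1, 32))
assert conditioning.shape == (1, 64)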