Skip to content

Commit b839590

Browse files
committed
omnigen pipeline
1 parent bbe2b98 commit b839590

File tree

15 files changed

+832
-451
lines changed

15 files changed

+832
-451
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import argparse
2+
import os
3+
4+
import torch
5+
from safetensors.torch import load_file
6+
from transformers import AutoModel, AutoTokenizer, AutoConfig
7+
8+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline
9+
10+
11+
def main(args):
    """Convert an original OmniGen checkpoint into a diffusers `OmniGenPipeline`.

    Loads the safetensors transformer weights, renames the keys that differ
    between the original layout and the diffusers `OmniGenTransformer2DModel`,
    assembles the pipeline (transformer + VAE + tokenizer + scheduler), and
    saves it to ``args.dump_path``.

    Args:
        args: Parsed CLI namespace with ``origin_ckpt_path`` (path/repo of the
            original checkpoint, e.g. https://huggingface.co/Shitao/OmniGen-v1)
            and ``dump_path`` (output directory for the converted pipeline).
    """
    # checkpoint from https://huggingface.co/Shitao/OmniGen-v1
    ckpt = load_file(args.origin_ckpt_path, device="cpu")

    # Keys present in the original checkpoint -> names expected by
    # OmniGenTransformer2DModel. Keys not listed here are kept as-is.
    mapping_dict = {
        "pos_embed": "patch_embedding.pos_embed",
        "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
        "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
        "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
        "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
        "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
        "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
        "final_layer.linear.weight": "proj_out.weight",
        "final_layer.linear.bias": "proj_out.bias",
    }

    # Rename mapped keys, pass everything else through unchanged.
    converted_state_dict = {mapping_dict.get(k, k): v for k, v in ckpt.items()}

    transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path)

    # OmniGen-v1 transformer hyper-parameters (patch size, latent channels,
    # maximum positional-embedding grid).
    transformer = OmniGenTransformer2DModel(
        transformer_config=transformer_config,
        patch_size=2,
        in_channels=4,
        pos_embed_max_size=192,
    )
    # strict=True: fail loudly if the mapping above missed or mistranslated a key.
    transformer.load_state_dict(converted_state_dict, strict=True)

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    scheduler = FlowMatchEulerDiscreteScheduler()

    vae = AutoencoderKL.from_pretrained(args.origin_ckpt_path, torch_dtype=torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)

    pipeline = OmniGenPipeline(
        tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler
    )
    pipeline.save_pretrained(args.dump_path)
62+
63+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        # required=True (was required=False): main() uses this path
        # unconditionally, so omitting it would otherwise crash later with an
        # opaque error instead of a clear argparse message. Also consistent
        # with --dump_path below.
        "--origin_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the checkpoint to convert.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")

    args = parser.parse_args()
    main(args)

src/diffusers/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
"MotionAdapter",
109109
"MultiAdapter",
110110
"MultiControlNetModel",
111-
"OmniGenTransformerModel",
111+
"OmniGenTransformer2DModel",
112112
"PixArtTransformer2DModel",
113113
"PriorTransformer",
114114
"SD3ControlNetModel",
@@ -321,6 +321,7 @@
321321
"MarigoldNormalsPipeline",
322322
"MochiPipeline",
323323
"MusicLDMPipeline",
324+
"OmniGenPipeline",
324325
"PaintByExamplePipeline",
325326
"PIAPipeline",
326327
"PixArtAlphaPipeline",
@@ -600,7 +601,7 @@
600601
MotionAdapter,
601602
MultiAdapter,
602603
MultiControlNetModel,
603-
OmniGenTransformerModel,
604+
OmniGenTransformer2DModel,
604605
PixArtTransformer2DModel,
605606
PriorTransformer,
606607
SD3ControlNetModel,
@@ -792,6 +793,7 @@
792793
MarigoldNormalsPipeline,
793794
MochiPipeline,
794795
MusicLDMPipeline,
796+
OmniGenPipeline,
795797
PaintByExamplePipeline,
796798
PIAPipeline,
797799
PixArtAlphaPipeline,

src/diffusers/models/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
6767
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
6868
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
69-
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformerModel"]
69+
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
7070
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
7171
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
7272
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -126,7 +126,7 @@
126126
LatteTransformer3DModel,
127127
LuminaNextDiT2DModel,
128128
MochiTransformer3DModel,
129-
OmniGenTransformerModel,
129+
OmniGenTransformer2DModel,
130130
PixArtTransformer2DModel,
131131
PriorTransformer,
132132
SD3Transformer2DModel,

src/diffusers/models/embeddings.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,20 @@ def patch_embeddings(self, latent, is_input_image: bool):
351351
latent = latent.flatten(2).transpose(1, 2)
352352
return latent
353353

354-
def forward(self, latent, is_input_image: bool, padding_latent=None):
354+
def forward(self,
355+
latent: torch.Tensor,
356+
is_input_image: bool,
357+
padding_latent: torch.Tensor = None
358+
):
359+
"""
360+
Args:
361+
latent:
362+
is_input_image:
363+
padding_latent: When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence length.
364+
365+
Returns: torch.Tensor
366+
367+
"""
355368
if isinstance(latent, list):
356369
if padding_latent is None:
357370
padding_latent = [None] * len(latent)
@@ -362,7 +375,7 @@ def forward(self, latent, is_input_image: bool, padding_latent=None):
362375
pos_embed = self.cropped_pos_embed(height, width)
363376
sub_latent = sub_latent + pos_embed
364377
if padding is not None:
365-
sub_latent = torch.cat([sub_latent, padding], dim=-2)
378+
sub_latent = torch.cat([sub_latent, padding.to(sub_latent.device)], dim=-2)
366379
patched_latents.append(sub_latent)
367380
else:
368381
height, width = latent.shape[-2:]

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@
2020
from .transformer_mochi import MochiTransformer3DModel
2121
from .transformer_sd3 import SD3Transformer2DModel
2222
from .transformer_temporal import TransformerTemporalModel
23-
from .transformer_omnigen import OmniGenTransformerModel
23+
from .transformer_omnigen import OmniGenTransformer2DModel

0 commit comments

Comments
 (0)