Commit bbe2b98

update OmniGenTransformerModel

1 parent: 36eee40

File tree: 5 files changed, +60 -2 lines changed


src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -108,6 +108,7 @@
     "MotionAdapter",
     "MultiAdapter",
     "MultiControlNetModel",
+    "OmniGenTransformerModel",
     "PixArtTransformer2DModel",
     "PriorTransformer",
     "SD3ControlNetModel",
@@ -599,6 +600,7 @@
     MotionAdapter,
     MultiAdapter,
     MultiControlNetModel,
+    OmniGenTransformerModel,
     PixArtTransformer2DModel,
     PriorTransformer,
     SD3ControlNetModel,
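The first hunk registers the name in the lazy `_import_structure` map and the second adds it to the static import block, so the class resolves from the package root. A minimal import check, assuming a package install from this branch:

# Should resolve via the lazy import machinery without error.
from diffusers import OmniGenTransformerModel
print(OmniGenTransformerModel)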

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,7 @@
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
+    _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformerModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
     _import_structure["unets.unet_2d"] = ["UNet2DModel"]
     _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -125,6 +126,7 @@
     LatteTransformer3DModel,
     LuminaNextDiT2DModel,
     MochiTransformer3DModel,
+    OmniGenTransformerModel,
     PixArtTransformer2DModel,
     PriorTransformer,
     SD3Transformer2DModel,
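Registration in the models subpackage mirrors the top-level pattern: a string entry in `_import_structure` defers the actual import, while the name in the type-checking block keeps static analysis working. A rough sketch of how such a mapping is consumed, assuming diffusers' internal `_LazyModule` helper (the details here are an assumption, not this commit's code):

import sys
from diffusers.utils import _LazyModule  # internal helper; exact location is an assumption

_import_structure = {
    "transformers.transformer_omnigen": ["OmniGenTransformerModel"],
}

# Swapping in a _LazyModule defers the heavy import until the attribute
# is first accessed, keeping `import diffusers` fast.
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)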

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@
     from .transformer_mochi import MochiTransformer3DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
+    from .transformer_omnigen import OmniGenTransformerModel

src/diffusers/models/transformers/transformer_omnigen.py

Lines changed: 3 additions & 2 deletions
@@ -26,7 +26,7 @@
 from ...loaders import PeftAdapterMixin
 from ...utils import logging
 from ..attention_processor import AttentionProcessor
-from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
+from ..normalization import AdaLayerNorm
 from ..embeddings import OmniGenPatchEmbed, OmniGenTimestepEmbed
 from ..modeling_utils import ModelMixin
 
@@ -162,7 +162,7 @@ def forward(
     )
 
 
-class OmniGenTransformer(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class OmniGenTransformerModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
     The Transformer model introduced in OmniGen.
 
@@ -343,3 +343,4 @@ def forward(self,
 
 
 
+
test.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+import os
+os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'
+
+from huggingface_hub import snapshot_download
+
+from diffusers.models import OmniGenTransformerModel
+from transformers import Phi3Model, Phi3Config
+
+
+from safetensors.torch import load_file
+
+model_name = "Shitao/OmniGen-v1"
+config = Phi3Config.from_pretrained("Shitao/OmniGen-v1")
+model = OmniGenTransformerModel(transformer_config=config)
+cache_folder = os.getenv('HF_HUB_CACHE')
+model_name = snapshot_download(repo_id=model_name,
+                               cache_dir=cache_folder,
+                               ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+print(model_name)
+model_path = os.path.join(model_name, 'model.safetensors')
+ckpt = load_file(model_path, 'cpu')
+
+
+mapping_dict = {
+    "pos_embed": "patch_embedding.pos_embed",
+    "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
+    "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
+    "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
+    "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
+    "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
+    "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
+    "final_layer.linear.weight": "proj_out.weight",
+    "final_layer.linear.bias": "proj_out.bias",
+
+}
+
+new_ckpt = {}
+for k, v in ckpt.items():
+    # new_ckpt[k] = v
+    if k in mapping_dict:
+        new_ckpt[mapping_dict[k]] = v
+    else:
+        new_ckpt[k] = v
+
+
+
+model.load_state_dict(new_ckpt)
+
+
+
+
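test.py renames the original OmniGen checkpoint keys to the names the diffusers module expects, then loads them with the default strict matching. While iterating on a mapping like mapping_dict, a non-strict load surfaces leftover mismatches instead of raising; a minimal sketch using standard PyTorch behavior:

# strict=False reports the keys that failed to line up instead of raising,
# which is useful for debugging the key remapping.
result = model.load_state_dict(new_ckpt, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)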
