
Commit 0d04194

omnigen pipeline
1 parent b839590 commit 0d04194

File tree: 1 file changed, +16 −6 lines


scripts/convert_omnigen_to_diffusers.py

Lines changed: 16 additions & 6 deletions
@@ -4,13 +4,25 @@
 import torch
 from safetensors.torch import load_file
 from transformers import AutoModel, AutoTokenizer, AutoConfig
+from huggingface_hub import snapshot_download

 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline


 def main(args):
     # checkpoint from https://huggingface.co/Shitao/OmniGen-v1
-    ckpt = load_file(args.origin_ckpt_path, device="cpu")
+
+    if not os.path.exists(args.origin_ckpt_path):
+        print("Model not found, downloading...")
+        cache_folder = os.getenv('HF_HUB_CACHE')
+        args.origin_ckpt_path = snapshot_download(repo_id=args.origin_ckpt_path,
+                                                  cache_dir=cache_folder,
+                                                  ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5',
+                                                                   'model.pt'])
+        print(f"Downloaded model to {args.origin_ckpt_path}")
+
+    ckpt = os.path.join(args.origin_ckpt_path, 'model.safetensors')
+    ckpt = load_file(ckpt, device="cpu")

     mapping_dict = {
         "pos_embed": "patch_embedding.pos_embed",
@@ -27,15 +39,13 @@ def main(args):

     converted_state_dict = {}
     for k, v in ckpt.items():
-        # new_ckpt[k] = v
         if k in mapping_dict:
             converted_state_dict[mapping_dict[k]] = v
         else:
             converted_state_dict[k] = v

     transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path)

-    # Lumina-Next-SFT 2B
     transformer = OmniGenTransformer2DModel(
         transformer_config=transformer_config,
         patch_size=2,
@@ -49,7 +59,7 @@ def main(args):

     scheduler = FlowMatchEulerDiscreteScheduler()

-    vae = AutoencoderKL.from_pretrained(args.origin_ckpt_path, torch_dtype=torch.float32)
+    vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)

     tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)

@@ -64,10 +74,10 @@ def main(args):
     parser = argparse.ArgumentParser()

     parser.add_argument(
-        "--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
+        "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, help="Path to the checkpoint to convert."
     )

-    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
+    parser.add_argument("--dump_path", default="OmniGen-v1-diffusers", type=str, required=True, help="Path to the output pipeline.")

     args = parser.parse_args()
     main(args)
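With these defaults the converter no longer needs a local checkpoint: when --origin_ckpt_path does not exist on disk, it is treated as a Hub repo id and fetched with snapshot_download. A minimal standalone sketch of that resolution step (the helper name is illustrative, not part of the script):

import os

from huggingface_hub import snapshot_download


def resolve_checkpoint_dir(path_or_repo_id: str) -> str:
    # Hypothetical helper mirroring the fallback added in this commit.
    # Returns a local directory that the script expects to contain
    # model.safetensors and a vae/ subfolder.
    if os.path.exists(path_or_repo_id):
        return path_or_repo_id
    return snapshot_download(
        repo_id=path_or_repo_id,
        cache_dir=os.getenv("HF_HUB_CACHE"),  # falls back to the default cache when unset
        ignore_patterns=["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"],
    )


# e.g. resolve_checkpoint_dir("Shitao/OmniGen-v1")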

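Once conversion finishes, the folder written to --dump_path should be a regular diffusers pipeline directory. A loading sketch, assuming the script's final step (unchanged and not shown in this diff) saves the assembled OmniGenPipeline to args.dump_path:

import torch

from diffusers import OmniGenPipeline

# Load the converted pipeline from the default dump path used above.
pipe = OmniGenPipeline.from_pretrained("OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")  # optional, requires a CUDA-capable GPU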