 import torch
 from safetensors.torch import load_file
 from transformers import AutoModel, AutoTokenizer, AutoConfig
+from huggingface_hub import snapshot_download

 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline


 def main(args):
     # checkpoint from https://huggingface.co/Shitao/OmniGen-v1
-    ckpt = load_file(args.origin_ckpt_path, device="cpu")
+
+    if not os.path.exists(args.origin_ckpt_path):
+        print("Model not found, downloading...")
+        cache_folder = os.getenv('HF_HUB_CACHE')
+        args.origin_ckpt_path = snapshot_download(
+            repo_id=args.origin_ckpt_path,
+            cache_dir=cache_folder,
+            ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'],
+        )
+        print(f"Downloaded model to {args.origin_ckpt_path}")
+
+    ckpt = os.path.join(args.origin_ckpt_path, 'model.safetensors')
+    ckpt = load_file(ckpt, device="cpu")

     mapping_dict = {
         "pos_embed": "patch_embedding.pos_embed",
@@ -27,15 +39,13 @@ def main(args):

     converted_state_dict = {}
     for k, v in ckpt.items():
-        # new_ckpt[k] = v
         if k in mapping_dict:
             converted_state_dict[mapping_dict[k]] = v
         else:
             converted_state_dict[k] = v

     transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path)

-    # Lumina-Next-SFT 2B
     transformer = OmniGenTransformer2DModel(
         transformer_config=transformer_config,
         patch_size=2,
@@ -49,7 +59,7 @@

     scheduler = FlowMatchEulerDiscreteScheduler()

-    vae = AutoencoderKL.from_pretrained(args.origin_ckpt_path, torch_dtype=torch.float32)
+    vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)

     tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)

@@ -64,10 +74,10 @@
     parser = argparse.ArgumentParser()

     parser.add_argument(
-        "--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
+        "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, help="Path to the checkpoint to convert."
     )

-    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
+    parser.add_argument("--dump_path", default="OmniGen-v1-diffusers", type=str, required=True, help="Path to the output pipeline.")

     args = parser.parse_args()
     main(args)
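For context, a minimal usage sketch of the converted pipeline, assuming the script above has been run with the default --dump_path of OmniGen-v1-diffusers. The prompt, resolution, and guidance values are illustrative and follow the standard diffusers text-to-image call signature; they are not taken from this diff.

import torch
from diffusers import OmniGenPipeline

# Load the converted pipeline from the conversion script's output directory
# (the path assumes the default --dump_path above).
pipe = OmniGenPipeline.from_pretrained("OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Text-to-image generation; prompt and sampling settings are illustrative only.
image = pipe(
    prompt="A photo of a cat sitting on a windowsill",
    height=1024,
    width=1024,
    guidance_scale=2.5,
    num_inference_steps=50,
).images[0]
image.save("omnigen_sample.png")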