@@ -1,11 +1,10 @@
 import argparse
 import os
-os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'
 
 import torch
-from safetensors.torch import load_file
-from transformers import AutoModel, AutoTokenizer, AutoConfig
 from huggingface_hub import snapshot_download
+from safetensors.torch import load_file
+from transformers import AutoTokenizer
 
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline
 
@@ -17,9 +16,9 @@ def main(args):
         print("Model not found, downloading...")
         cache_folder = os.getenv('HF_HUB_CACHE')
         args.origin_ckpt_path = snapshot_download(repo_id=args.origin_ckpt_path,
-                                                    cache_dir=cache_folder,
-                                                    ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5',
-                                                                     'model.pt'])
+                                                   cache_dir=cache_folder,
+                                                   ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5',
+                                                                    'model.pt'])
         print(f"Downloaded model to {args.origin_ckpt_path}")
 
     ckpt = os.path.join(args.origin_ckpt_path, 'model.safetensors')
@@ -48,7 +47,7 @@ def main(args):
     # transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path)
     # print(type(transformer_config.__dict__))
     # print(transformer_config.__dict__)
-    
+
     transformer_config = {
         "_name_or_path": "Phi-3-vision-128k-instruct",
         "architectures": [
@@ -70,104 +69,104 @@ def main(args):
         "rms_norm_eps": 1e-05,
         "rope_scaling": {
             "long_factor": [
-                    1.0299999713897705,
-                    1.0499999523162842,
-                    1.0499999523162842,
-                    1.0799999237060547,
-                    1.2299998998641968,
-                    1.2299998998641968,
-                    1.2999999523162842,
-                    1.4499999284744263,
-                    1.5999999046325684,
-                    1.6499998569488525,
-                    1.8999998569488525,
-                    2.859999895095825,
-                    3.68999981880188,
-                    5.419999599456787,
-                    5.489999771118164,
-                    5.489999771118164,
-                    9.09000015258789,
-                    11.579999923706055,
-                    15.65999984741211,
-                    15.769999504089355,
-                    15.789999961853027,
-                    18.360000610351562,
-                    21.989999771118164,
-                    23.079999923706055,
-                    30.009998321533203,
-                    32.35000228881836,
-                    32.590003967285156,
-                    35.56000518798828,
-                    39.95000457763672,
-                    53.840003967285156,
-                    56.20000457763672,
-                    57.95000457763672,
-                    59.29000473022461,
-                    59.77000427246094,
-                    59.920005798339844,
-                    61.190006256103516,
-                    61.96000671386719,
-                    62.50000762939453,
-                    63.3700065612793,
-                    63.48000717163086,
-                    63.48000717163086,
-                    63.66000747680664,
-                    63.850006103515625,
-                    64.08000946044922,
-                    64.760009765625,
-                    64.80001068115234,
-                    64.81001281738281,
-                    64.81001281738281
+                1.0299999713897705,
+                1.0499999523162842,
+                1.0499999523162842,
+                1.0799999237060547,
+                1.2299998998641968,
+                1.2299998998641968,
+                1.2999999523162842,
+                1.4499999284744263,
+                1.5999999046325684,
+                1.6499998569488525,
+                1.8999998569488525,
+                2.859999895095825,
+                3.68999981880188,
+                5.419999599456787,
+                5.489999771118164,
+                5.489999771118164,
+                9.09000015258789,
+                11.579999923706055,
+                15.65999984741211,
+                15.769999504089355,
+                15.789999961853027,
+                18.360000610351562,
+                21.989999771118164,
+                23.079999923706055,
+                30.009998321533203,
+                32.35000228881836,
+                32.590003967285156,
+                35.56000518798828,
+                39.95000457763672,
+                53.840003967285156,
+                56.20000457763672,
+                57.95000457763672,
+                59.29000473022461,
+                59.77000427246094,
+                59.920005798339844,
+                61.190006256103516,
+                61.96000671386719,
+                62.50000762939453,
+                63.3700065612793,
+                63.48000717163086,
+                63.48000717163086,
+                63.66000747680664,
+                63.850006103515625,
+                64.08000946044922,
+                64.760009765625,
+                64.80001068115234,
+                64.81001281738281,
+                64.81001281738281
             ],
             "short_factor": [
-                    1.05,
-                    1.05,
-                    1.05,
-                    1.1,
-                    1.1,
-                    1.1,
-                    1.2500000000000002,
-                    1.2500000000000002,
-                    1.4000000000000004,
-                    1.4500000000000004,
-                    1.5500000000000005,
-                    1.8500000000000008,
-                    1.9000000000000008,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.000000000000001,
-                    2.1000000000000005,
-                    2.1000000000000005,
-                    2.2,
-                    2.3499999999999996,
-                    2.3499999999999996,
-                    2.3499999999999996,
-                    2.3499999999999996,
-                    2.3999999999999995,
-                    2.3999999999999995,
-                    2.6499999999999986,
-                    2.6999999999999984,
-                    2.8999999999999977,
-                    2.9499999999999975,
-                    3.049999999999997,
-                    3.049999999999997,
-                    3.049999999999997
+                1.05,
+                1.05,
+                1.05,
+                1.1,
+                1.1,
+                1.1,
+                1.2500000000000002,
+                1.2500000000000002,
+                1.4000000000000004,
+                1.4500000000000004,
+                1.5500000000000005,
+                1.8500000000000008,
+                1.9000000000000008,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.1000000000000005,
+                2.1000000000000005,
+                2.2,
+                2.3499999999999996,
+                2.3499999999999996,
+                2.3499999999999996,
+                2.3499999999999996,
+                2.3999999999999995,
+                2.3999999999999995,
+                2.6499999999999986,
+                2.6999999999999984,
+                2.8999999999999977,
+                2.9499999999999975,
+                3.049999999999997,
+                3.049999999999997,
+                3.049999999999997
             ],
             "type": "su"
         },
@@ -179,7 +178,7 @@ def main(args):
         "use_cache": True,
         "vocab_size": 32064,
         "_attn_implementation": "sdpa"
-        }
+    }
     transformer = OmniGenTransformer2DModel(
         transformer_config=transformer_config,
         patch_size=2,
@@ -198,7 +197,6 @@ def main(args):
 
     tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)
 
-
     pipeline = OmniGenPipeline(
         tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler
     )
@@ -209,10 +207,12 @@ def main(args):
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
-        "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, help="Path to the checkpoint to convert."
+        "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False,
+        help="Path to the checkpoint to convert."
     )
 
-    parser.add_argument("--dump_path", default="/share/shitao/repos/OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline.")
+    parser.add_argument("--dump_path", default="OmniGen-v1-diffusers", type=str, required=False,
+                        help="Path to the output pipeline.")
 
     args = parser.parse_args()
     main(args)
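
Not part of the diff: a minimal sketch of how the converted output could be sanity-checked after running the script, assuming the pipeline is saved to the --dump_path directory (default OmniGen-v1-diffusers above) and that OmniGenPipeline follows the usual diffusers from_pretrained / text-to-image call signature. The prompt, resolution, and step count below are illustrative only.

    import torch

    from diffusers import OmniGenPipeline

    # Assumption: the script wrote the converted pipeline to the default --dump_path.
    pipe = OmniGenPipeline.from_pretrained("OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # Illustrative smoke test; arguments follow the standard diffusers text-to-image pattern.
    image = pipe(prompt="a photo of a cat", height=512, width=512, num_inference_steps=30).images[0]
    image.save("omnigen_smoke_test.png")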