Commit 4fef9c8

reformat

1 parent 6b52547 commit 4fef9c8

5 files changed: +113 -173 lines changed
Lines changed: 107 additions & 107 deletions
@@ -1,11 +1,10 @@
 import argparse
 import os
-os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'
 
 import torch
-from safetensors.torch import load_file
-from transformers import AutoModel, AutoTokenizer, AutoConfig
 from huggingface_hub import snapshot_download
+from safetensors.torch import load_file
+from transformers import AutoTokenizer
 
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline
 
@@ -17,9 +16,9 @@ def main(args):
         print("Model not found, downloading...")
         cache_folder = os.getenv('HF_HUB_CACHE')
         args.origin_ckpt_path = snapshot_download(repo_id=args.origin_ckpt_path,
-            cache_dir=cache_folder,
-            ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5',
-                             'model.pt'])
+                                                  cache_dir=cache_folder,
+                                                  ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5',
+                                                                   'model.pt'])
         print(f"Downloaded model to {args.origin_ckpt_path}")
 
     ckpt = os.path.join(args.origin_ckpt_path, 'model.safetensors')
@@ -48,7 +47,7 @@ def main(args):
     # transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path)
     # print(type(transformer_config.__dict__))
     # print(transformer_config.__dict__)
-
+
     transformer_config = {
         "_name_or_path": "Phi-3-vision-128k-instruct",
         "architectures": [
@@ -70,104 +69,104 @@ def main(args):
         "rms_norm_eps": 1e-05,
         "rope_scaling": {
             "long_factor": [
-            1.0299999713897705,
-            1.0499999523162842,
-            1.0499999523162842,
-            1.0799999237060547,
-            1.2299998998641968,
-            1.2299998998641968,
-            1.2999999523162842,
-            1.4499999284744263,
-            1.5999999046325684,
-            1.6499998569488525,
-            1.8999998569488525,
-            2.859999895095825,
-            3.68999981880188,
-            5.419999599456787,
-            5.489999771118164,
-            5.489999771118164,
-            9.09000015258789,
-            11.579999923706055,
-            15.65999984741211,
-            15.769999504089355,
-            15.789999961853027,
-            18.360000610351562,
-            21.989999771118164,
-            23.079999923706055,
-            30.009998321533203,
-            32.35000228881836,
-            32.590003967285156,
-            35.56000518798828,
-            39.95000457763672,
-            53.840003967285156,
-            56.20000457763672,
-            57.95000457763672,
-            59.29000473022461,
-            59.77000427246094,
-            59.920005798339844,
-            61.190006256103516,
-            61.96000671386719,
-            62.50000762939453,
-            63.3700065612793,
-            63.48000717163086,
-            63.48000717163086,
-            63.66000747680664,
-            63.850006103515625,
-            64.08000946044922,
-            64.760009765625,
-            64.80001068115234,
-            64.81001281738281,
-            64.81001281738281
+                1.0299999713897705,
+                1.0499999523162842,
+                1.0499999523162842,
+                1.0799999237060547,
+                1.2299998998641968,
+                1.2299998998641968,
+                1.2999999523162842,
+                1.4499999284744263,
+                1.5999999046325684,
+                1.6499998569488525,
+                1.8999998569488525,
+                2.859999895095825,
+                3.68999981880188,
+                5.419999599456787,
+                5.489999771118164,
+                5.489999771118164,
+                9.09000015258789,
+                11.579999923706055,
+                15.65999984741211,
+                15.769999504089355,
+                15.789999961853027,
+                18.360000610351562,
+                21.989999771118164,
+                23.079999923706055,
+                30.009998321533203,
+                32.35000228881836,
+                32.590003967285156,
+                35.56000518798828,
+                39.95000457763672,
+                53.840003967285156,
+                56.20000457763672,
+                57.95000457763672,
+                59.29000473022461,
+                59.77000427246094,
+                59.920005798339844,
+                61.190006256103516,
+                61.96000671386719,
+                62.50000762939453,
+                63.3700065612793,
+                63.48000717163086,
+                63.48000717163086,
+                63.66000747680664,
+                63.850006103515625,
+                64.08000946044922,
+                64.760009765625,
+                64.80001068115234,
+                64.81001281738281,
+                64.81001281738281
             ],
             "short_factor": [
-            1.05,
-            1.05,
-            1.05,
-            1.1,
-            1.1,
-            1.1,
-            1.2500000000000002,
-            1.2500000000000002,
-            1.4000000000000004,
-            1.4500000000000004,
-            1.5500000000000005,
-            1.8500000000000008,
-            1.9000000000000008,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.000000000000001,
-            2.1000000000000005,
-            2.1000000000000005,
-            2.2,
-            2.3499999999999996,
-            2.3499999999999996,
-            2.3499999999999996,
-            2.3499999999999996,
-            2.3999999999999995,
-            2.3999999999999995,
-            2.6499999999999986,
-            2.6999999999999984,
-            2.8999999999999977,
-            2.9499999999999975,
-            3.049999999999997,
-            3.049999999999997,
-            3.049999999999997
+                1.05,
+                1.05,
+                1.05,
+                1.1,
+                1.1,
+                1.1,
+                1.2500000000000002,
+                1.2500000000000002,
+                1.4000000000000004,
+                1.4500000000000004,
+                1.5500000000000005,
+                1.8500000000000008,
+                1.9000000000000008,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.1000000000000005,
+                2.1000000000000005,
+                2.2,
+                2.3499999999999996,
+                2.3499999999999996,
+                2.3499999999999996,
+                2.3499999999999996,
+                2.3999999999999995,
+                2.3999999999999995,
+                2.6499999999999986,
+                2.6999999999999984,
+                2.8999999999999977,
+                2.9499999999999975,
+                3.049999999999997,
+                3.049999999999997,
+                3.049999999999997
             ],
             "type": "su"
         },
@@ -179,7 +178,7 @@ def main(args):
         "use_cache": True,
         "vocab_size": 32064,
         "_attn_implementation": "sdpa"
-        }
+    }
     transformer = OmniGenTransformer2DModel(
         transformer_config=transformer_config,
         patch_size=2,
@@ -198,7 +197,6 @@ def main(args):
 
     tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)
 
-
     pipeline = OmniGenPipeline(
         tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler
     )
@@ -209,10 +207,12 @@ def main(args):
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
-        "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, help="Path to the checkpoint to convert."
+        "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False,
+        help="Path to the checkpoint to convert."
     )
 
-    parser.add_argument("--dump_path", default="/share/shitao/repos/OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline.")
+    parser.add_argument("--dump_path", default="OmniGen-v1-diffusers", type=str, required=False,
+                        help="Path to the output pipeline.")
 
     args = parser.parse_args()
     main(args)
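
With the hardcoded cache path and site-specific defaults removed, the script above runs without any environment-specific setup. A minimal sketch of invoking it programmatically, mirroring the two argparse defaults (`main` is the function defined in this conversion script):

import argparse

# Equivalent to running the script with no flags; both values are the
# argparse defaults introduced in this commit.
args = argparse.Namespace(
    origin_ckpt_path="Shitao/OmniGen-v1",  # Hub repo id (or a local checkpoint dir)
    dump_path="OmniGen-v1-diffusers",      # output folder for the diffusers pipeline
)
main(args)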

src/diffusers/models/embeddings.py

Lines changed: 2 additions & 2 deletions
@@ -358,8 +358,8 @@ def forward(self,
     ):
         """
         Args:
-            latent:
-            is_input_image:
+            latent: encoded image latents
+            is_input_image: use input_image_proj or output_image_proj
             padding_latent: When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence length.
 
         Returns: torch.Tensor
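
The updated docstring says `is_input_image` selects between `input_image_proj` and `output_image_proj`. A minimal sketch of that branching, with hypothetical layer names and shapes; the real module's signature and projections may differ:

import torch
import torch.nn as nn

class TwoStreamPatchEmbed(nn.Module):
    # Hypothetical illustration: two separate patch projections, one for
    # condition (input) images and one for the image being generated.
    def __init__(self, in_channels=4, embed_dim=3072, patch_size=2):
        super().__init__()
        self.input_image_proj = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)
        self.output_image_proj = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)

    def forward(self, latent, is_input_image):
        proj = self.input_image_proj if is_input_image else self.output_image_proj
        x = proj(latent)                     # (B, embed_dim, H/p, W/p)
        return x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)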

src/diffusers/pipelines/omnigen/pipeline_omnigen.py

Lines changed: 3 additions & 3 deletions
@@ -49,13 +49,13 @@
         >>> import torch
         >>> from diffusers import OmniGenPipeline
 
-        >>> pipe = OmniGenPipeline.from_pretrained("****", torch_dtype=torch.bfloat16)
+        >>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
         >>> prompt = "A cat holding a sign that says hello world"
         >>> # Depending on the variant being used, the pipeline call will slightly vary.
         >>> # Refer to the pipeline documentation for more details.
-        >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
-        >>> image.save("flux.png")
+        >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
+        >>> image.save("t2i.png")
         ```
 """
 
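
The docstring example now points at the converted repository id and swaps the leftover Flux settings for OmniGen ones. Assuming the conversion script earlier in this commit was run with its default `--dump_path`, the same pipeline should also load from the local output folder (a sketch; the folder name is only the script's default, not a guaranteed location):

import torch
from diffusers import OmniGenPipeline

# "OmniGen-v1-diffusers" is the conversion script's default --dump_path.
pipe = OmniGenPipeline.from_pretrained("OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=50,  # settings from the updated docstring example
    guidance_scale=2.5,
).images[0]
image.save("t2i.png")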

src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 1 addition & 2 deletions
@@ -212,7 +212,7 @@ def set_timesteps(
         else:
             sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
 
-            self.timesteps = timesteps.to(device=device)
+        self.timesteps = timesteps.to(device=device)
         self.sigmas = sigmas
         self._step_index = None
         self._begin_index = None
@@ -300,7 +300,6 @@ def step(
 
         sigma = self.sigmas[self.step_index]
         sigma_next = self.sigmas[self.step_index + 1]
-
         prev_sample = sample + (sigma_next - sigma) * model_output
 
         # Cast sample back to model compatible dtype
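
The edit above only removes a blank line; the `prev_sample` update it touches is the plain first-order Euler step of the flow-matching ODE. Shown standalone for clarity (a sketch of just this line, not the scheduler's full `step`, which also handles dtype casting and index bookkeeping):

import torch

def euler_flow_match_step(
    sample: torch.Tensor, model_output: torch.Tensor, sigma: float, sigma_next: float
) -> torch.Tensor:
    # x_{i+1} = x_i + (sigma_{i+1} - sigma_i) * v_theta(x_i, sigma_i)
    return sample + (sigma_next - sigma) * model_output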

test.py

Lines changed: 0 additions & 59 deletions
This file was deleted.
