
Commit 85abe5e: update omnigen_pipeline
1 parent: 0d04194

File tree: 11 files changed, +570 -135 lines

scripts/convert_omnigen_to_diffusers.py

Lines changed: 138 additions & 3 deletions
@@ -1,5 +1,6 @@
 import argparse
 import os
+os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'

 import torch
 from safetensors.torch import load_file
@@ -44,15 +45,149 @@ def main(args):
         else:
             converted_state_dict[k] = v

-    transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path)
-
+    # transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path)
+    # print(type(transformer_config.__dict__))
+    # print(transformer_config.__dict__)
+
+    transformer_config = {
+        "_name_or_path": "Phi-3-vision-128k-instruct",
+        "architectures": [
+            "Phi3ForCausalLM"
+        ],
+        "attention_dropout": 0.0,
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "hidden_act": "silu",
+        "hidden_size": 3072,
+        "initializer_range": 0.02,
+        "intermediate_size": 8192,
+        "max_position_embeddings": 131072,
+        "model_type": "phi3",
+        "num_attention_heads": 32,
+        "num_hidden_layers": 32,
+        "num_key_value_heads": 32,
+        "original_max_position_embeddings": 4096,
+        "rms_norm_eps": 1e-05,
+        "rope_scaling": {
+            "long_factor": [
+                1.0299999713897705,
+                1.0499999523162842,
+                1.0499999523162842,
+                1.0799999237060547,
+                1.2299998998641968,
+                1.2299998998641968,
+                1.2999999523162842,
+                1.4499999284744263,
+                1.5999999046325684,
+                1.6499998569488525,
+                1.8999998569488525,
+                2.859999895095825,
+                3.68999981880188,
+                5.419999599456787,
+                5.489999771118164,
+                5.489999771118164,
+                9.09000015258789,
+                11.579999923706055,
+                15.65999984741211,
+                15.769999504089355,
+                15.789999961853027,
+                18.360000610351562,
+                21.989999771118164,
+                23.079999923706055,
+                30.009998321533203,
+                32.35000228881836,
+                32.590003967285156,
+                35.56000518798828,
+                39.95000457763672,
+                53.840003967285156,
+                56.20000457763672,
+                57.95000457763672,
+                59.29000473022461,
+                59.77000427246094,
+                59.920005798339844,
+                61.190006256103516,
+                61.96000671386719,
+                62.50000762939453,
+                63.3700065612793,
+                63.48000717163086,
+                63.48000717163086,
+                63.66000747680664,
+                63.850006103515625,
+                64.08000946044922,
+                64.760009765625,
+                64.80001068115234,
+                64.81001281738281,
+                64.81001281738281
+            ],
+            "short_factor": [
+                1.05,
+                1.05,
+                1.05,
+                1.1,
+                1.1,
+                1.1,
+                1.2500000000000002,
+                1.2500000000000002,
+                1.4000000000000004,
+                1.4500000000000004,
+                1.5500000000000005,
+                1.8500000000000008,
+                1.9000000000000008,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.000000000000001,
+                2.1000000000000005,
+                2.1000000000000005,
+                2.2,
+                2.3499999999999996,
+                2.3499999999999996,
+                2.3499999999999996,
+                2.3499999999999996,
+                2.3999999999999995,
+                2.3999999999999995,
+                2.6499999999999986,
+                2.6999999999999984,
+                2.8999999999999977,
+                2.9499999999999975,
+                3.049999999999997,
+                3.049999999999997,
+                3.049999999999997
+            ],
+            "type": "su"
+        },
+        "rope_theta": 10000.0,
+        "sliding_window": 131072,
+        "tie_word_embeddings": False,
+        "torch_dtype": "bfloat16",
+        "transformers_version": "4.38.1",
+        "use_cache": True,
+        "vocab_size": 32064,
+        "_attn_implementation": "sdpa"
+    }
     transformer = OmniGenTransformer2DModel(
         transformer_config=transformer_config,
         patch_size=2,
         in_channels=4,
         pos_embed_max_size=192,
     )
     transformer.load_state_dict(converted_state_dict, strict=True)
+    transformer.to(torch.bfloat16)

     num_model_params = sum(p.numel() for p in transformer.parameters())
     print(f"Total number of transformer parameters: {num_model_params}")
@@ -77,7 +212,7 @@ def main(args):
         "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, help="Path to the checkpoint to convert."
     )

-    parser.add_argument("--dump_path", default="OmniGen-v1-diffusers", type=str, required=True, help="Path to the output pipeline.")
+    parser.add_argument("--dump_path", default="/share/shitao/repos/OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline.")

     args = parser.parse_args()
     main(args)
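
The main change in this script is that the Phi-3 backbone configuration is now inlined as a plain Python dict instead of being read with AutoConfig.from_pretrained, and the converted transformer is cast to bfloat16 before saving. A minimal sketch of the idea, assuming a transformers version with Phi-3 support is installed (only a few of the keys above are shown, so this is illustrative, not the full config):

import torch
from transformers import Phi3Config

# A plain dict like the one inlined above round-trips through Phi3Config,
# so the script no longer needs the original checkpoint's config files.
config_dict = {
    "hidden_size": 3072,
    "intermediate_size": 8192,
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 32,
    "rms_norm_eps": 1e-05,
    "vocab_size": 32064,
}
config = Phi3Config(**config_dict)
print(config.hidden_size)        # 3072
print(torch.bfloat16)            # the dtype the converted transformer is cast to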

src/diffusers/models/embeddings.py

Lines changed: 2 additions & 2 deletions
@@ -381,9 +381,9 @@ def forward(self,
         height, width = latent.shape[-2:]
         pos_embed = self.cropped_pos_embed(height, width)
         latent = self.patch_embeddings(latent, is_input_image)
-        latent = latent + pos_embed
+        patched_latents = latent + pos_embed

-        return latent
+        return patched_latents


 class LuminaPatchEmbed(nn.Module):

src/diffusers/models/transformers/transformer_omnigen.py

Lines changed: 5 additions & 4 deletions
@@ -125,7 +125,7 @@ def forward(
                 )
                 use_cache = False

-        # kept for BC (non `Cache` `past_key_values` inputs)
+        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
@@ -240,7 +240,7 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     @register_to_config
     def __init__(
         self,
-        transformer_config: Phi3Config,
+        transformer_config: Dict,
         patch_size=2,
         in_channels=4,
         pos_embed_max_size: int = 192,
@@ -251,6 +251,7 @@ def __init__(
         self.patch_size = patch_size
         self.pos_embed_max_size = pos_embed_max_size

+        transformer_config = Phi3Config(**transformer_config)
         hidden_size = transformer_config.hidden_size

         self.patch_embedding = OmniGenPatchEmbed(patch_size=patch_size,
@@ -386,7 +387,7 @@ def forward(self,
                 "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
             )

-        height, width = hidden_states.size(-2)
+        height, width = hidden_states.size()[-2:]
         hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
         num_tokens_for_output_image = hidden_states.size(1)

@@ -405,7 +406,7 @@ def forward(self,

         image_embedding = output[:, -num_tokens_for_output_image:]
         time_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
-        x = self.final_layer(image_embedding, time_emb)
+        x = self.proj_out(self.norm_out(image_embedding, temb=time_emb))
         output = self.unpatchify(x, height, width)

         if not return_dict:
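
One of the fixes here is worth spelling out: tensor.size(-2) returns a single int, so it cannot be unpacked into height and width, while tensor.size()[-2:] yields the last two dimensions. A quick check of that behavior (assumes only torch; the shapes are illustrative):

import torch

x = torch.randn(1, 4, 64, 64)      # (batch, channels, height, width)
print(x.size(-2))                  # 64, a plain int
print(x.size()[-2:])               # torch.Size([64, 64])
height, width = x.size()[-2:]      # unpacks cleanly
# height, width = x.size(-2)       # would raise: cannot unpack non-iterable int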

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -254,6 +254,7 @@
     )
     _import_structure["mochi"] = ["MochiPipeline"]
     _import_structure["musicldm"] = ["MusicLDMPipeline"]
+    _import_structure["omnigen"] = ["OmniGenPipeline"]
     _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
     _import_structure["pia"] = ["PIAPipeline"]
     _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
@@ -584,6 +585,7 @@
         )
         from .mochi import MochiPipeline
         from .musicldm import MusicLDMPipeline
+        from .omnigen import OmniGenPipeline
         from .pag import (
             AnimateDiffPAGPipeline,
             HunyuanDiTPAGPipeline,
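
With the pipeline registered in both the lazy _import_structure map and the eager import branch, OmniGenPipeline becomes importable from the package root. A usage sketch, assuming this branch of diffusers is installed and the conversion script above has produced a local dump (the path is the script's former default and is purely illustrative):

import torch
from diffusers import OmniGenPipeline  # resolved lazily via _import_structure

pipe = OmniGenPipeline.from_pretrained("OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")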

src/diffusers/pipelines/omnigen/kvcache_omnigen.py

Lines changed: 26 additions & 19 deletions
@@ -1,15 +1,20 @@
+from tqdm import tqdm
 from typing import Optional, Dict, Any, Tuple, List
+import gc

 import torch
-from transformers.cache_utils import DynamicCache
+from transformers.cache_utils import Cache, DynamicCache, OffloadedCache
+


 class OmniGenCache(DynamicCache):
-    def __init__(self,
-                 num_tokens_for_img: int, offload_kv_cache: bool = False) -> None:
+    def __init__(self,
+                 num_tokens_for_img: int,
+                 offload_kv_cache: bool=False) -> None:
         if not torch.cuda.is_available():
-            raise RuntimeError(
-                "OmniGenCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!")
+            # print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!")
+            # offload_kv_cache = False
+            raise RuntimeError("OffloadedCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!")
         super().__init__()
         self.original_device = []
         self.prefetch_stream = torch.cuda.Stream()
@@ -25,17 +30,19 @@ def prefetch_layer(self, layer_idx: int):
             self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
             self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)

+
     def evict_previous_layer(self, layer_idx: int):
         "Moves the previous layer cache to the CPU"
         if len(self) > 2:
             # We do it on the default stream so it occurs after all earlier computations on these tensors are done
-            if layer_idx == 0:
+            if layer_idx == 0:
                 prev_layer_idx = -1
             else:
                 prev_layer_idx = (layer_idx - 1) % len(self)
             self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
             self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)

+
     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
         "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
         if layer_idx < len(self):
@@ -44,12 +51,12 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
             torch.cuda.current_stream().synchronize()
             self.evict_previous_layer(layer_idx)
             # Load current layer cache to its original device if not already there
-            # original_device = self.original_device[layer_idx]
+            original_device = self.original_device[layer_idx]
             # self.prefetch_stream.synchronize(original_device)
-            self.prefetch_stream.synchronize()
+            torch.cuda.synchronize(self.prefetch_stream)
             key_tensor = self.key_cache[layer_idx]
             value_tensor = self.value_cache[layer_idx]
-
+
             # Prefetch the next layer
             self.prefetch_layer((layer_idx + 1) % len(self))
         else:
@@ -58,13 +65,13 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
             return (key_tensor, value_tensor)
         else:
             raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
-
+
     def update(
-            self,
-            key_states: torch.Tensor,
-            value_states: torch.Tensor,
-            layer_idx: int,
-            cache_kwargs: Optional[Dict[str, Any]] = None,
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -85,13 +92,13 @@ def update(
             raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
         elif len(self.key_cache) == layer_idx:
             # only cache the states for condition tokens
-            key_states = key_states[..., :-(self.num_tokens_for_img + 1), :]
-            value_states = value_states[..., :-(self.num_tokens_for_img + 1), :]
+            key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
+            value_states = value_states[..., :-(self.num_tokens_for_img+1), :]

-            # Update the number of seen tokens
+            # Update the number of seen tokens
             if layer_idx == 0:
                 self._seen_tokens += key_states.shape[-2]
-
+
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
             self.original_device.append(key_states.device)