Skip to content

Commit 256fb47

Browse files
oahzxl, HaiShuangFan, LiewFeng, Lexarymade
authored
fix opensora init (#202)
* update (#201) --------- Co-authored-by: HaishuangFan <[email protected]> Co-authored-by: LiewFeng <[email protected]> Co-authored-by: Lexarymade <[email protected]>
1 parent 7f6ab94 commit 256fb47

File tree

2 files changed

+8
-23
lines changed

2 files changed

+8
-23
lines changed

videosys/models/transformers/open_sora_transformer_3d.py

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
# --------------------------------------------------------
99

1010

11-
import os
1211
from collections.abc import Iterable
1312
from functools import partial
1413

@@ -635,12 +634,3 @@ def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w):
635634
# unpad
636635
x = x[:, :, :R_t, :R_h, :R_w]
637636
return x
638-
639-
640-
def STDiT3_XL_2(from_pretrained=None, **kwargs):
641-
if from_pretrained is not None and not os.path.isdir(from_pretrained):
642-
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
643-
else:
644-
config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
645-
model = STDiT3(config)
646-
return model

videosys/pipelines/open_sora/pipeline_open_sora.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
from videosys.core.pab_mgr import PABConfig, set_pab_manager
1313
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
1414
from videosys.models.autoencoders.autoencoder_kl_open_sora import OpenSoraVAE_V1_2
15-
from videosys.models.transformers.open_sora_transformer_3d import STDiT3_XL_2
15+
from videosys.models.transformers.open_sora_transformer_3d import STDiT3
1616
from videosys.schedulers.scheduling_rflow_open_sora import RFLOW
1717
from videosys.utils.utils import save_video
1818

@@ -175,10 +175,10 @@ class OpenSoraPipeline(VideoSysPipeline):
175175
tokenizer (`T5Tokenizer`):
176176
Tokenizer of class
177177
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
178-
transformer ([`Transformer2DModel`]):
179-
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
178+
transformer ([`STDiT3`]):
179+
A text conditioned `STDiT3` to denoise the encoded video latents.
180180
scheduler ([`SchedulerMixin`]):
181-
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
181+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
182182
"""
183183
bad_punct_regex = re.compile(
184184
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
@@ -193,7 +193,7 @@ def __init__(
193193
text_encoder: Optional[T5EncoderModel] = None,
194194
tokenizer: Optional[AutoTokenizer] = None,
195195
vae: Optional[AutoencoderKL] = None,
196-
transformer: Optional[STDiT3_XL_2] = None,
196+
transformer: Optional[STDiT3] = None,
197197
scheduler: Optional[RFLOW] = None,
198198
device: torch.device = torch.device("cuda"),
199199
dtype: torch.dtype = torch.bfloat16,
@@ -215,14 +215,9 @@ def __init__(
215215
micro_batch_size=config.tiling_size,
216216
).to(dtype)
217217
if transformer is None:
218-
transformer = STDiT3_XL_2(
219-
from_pretrained=config.transformer,
220-
qk_norm=True,
221-
enable_flash_attn=config.enable_flash_attn,
222-
in_channels=vae.out_channels,
223-
caption_channels=text_encoder.config.d_model,
224-
model_max_length=300,
225-
).to(device, dtype)
218+
transformer = STDiT3.from_pretrained(config.transformer, enable_flash_attn=config.enable_flash_attn).to(
219+
dtype
220+
)
226221
if scheduler is None:
227222
scheduler = RFLOW(
228223
use_timestep_transform=True, num_sampling_steps=config.num_sampling_steps, cfg_scale=config.cfg_scale

0 commit comments

Comments (0)