 from videosys.core.pab_mgr import PABConfig, set_pab_manager
 from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
 from videosys.models.autoencoders.autoencoder_kl_open_sora import OpenSoraVAE_V1_2
-from videosys.models.transformers.open_sora_transformer_3d import STDiT3_XL_2
+from videosys.models.transformers.open_sora_transformer_3d import STDiT3
 from videosys.schedulers.scheduling_rflow_open_sora import RFLOW
 from videosys.utils.utils import save_video
 
@@ -175,10 +175,10 @@ class OpenSoraPipeline(VideoSysPipeline):
         tokenizer (`T5Tokenizer`):
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
-        transformer ([`Transformer2DModel`]):
-            A text conditioned `Transformer2DModel` to denoise the encoded image latents.
+        transformer ([`STDiT3`]):
+            A text conditioned `STDiT3` to denoise the encoded video latents.
         scheduler ([`SchedulerMixin`]):
-            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
     """
     bad_punct_regex = re.compile(
         r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
@@ -193,7 +193,7 @@ def __init__(
         text_encoder: Optional[T5EncoderModel] = None,
         tokenizer: Optional[AutoTokenizer] = None,
         vae: Optional[AutoencoderKL] = None,
-        transformer: Optional[STDiT3_XL_2] = None,
+        transformer: Optional[STDiT3] = None,
         scheduler: Optional[RFLOW] = None,
         device: torch.device = torch.device("cuda"),
         dtype: torch.dtype = torch.bfloat16,
@@ -215,14 +215,9 @@ def __init__(
                 micro_batch_size=config.tiling_size,
             ).to(dtype)
         if transformer is None:
-            transformer = STDiT3_XL_2(
-                from_pretrained=config.transformer,
-                qk_norm=True,
-                enable_flash_attn=config.enable_flash_attn,
-                in_channels=vae.out_channels,
-                caption_channels=text_encoder.config.d_model,
-                model_max_length=300,
-            ).to(device, dtype)
+            transformer = STDiT3.from_pretrained(config.transformer, enable_flash_attn=config.enable_flash_attn).to(
+                dtype
+            )
         if scheduler is None:
             scheduler = RFLOW(
                 use_timestep_transform=True, num_sampling_steps=config.num_sampling_steps, cfg_scale=config.cfg_scale
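For reference, a minimal sketch of the new loading path introduced in the last hunk. The checkpoint id and the `enable_flash_attn` value below are placeholder assumptions standing in for `config.transformer` and `config.enable_flash_attn`; only the `STDiT3.from_pretrained(...)` call and the dtype cast are taken from the diff itself.

```python
import torch

from videosys.models.transformers.open_sora_transformer_3d import STDiT3

# Assumed checkpoint id; in the pipeline this value comes from config.transformer.
checkpoint = "hpcai-tech/OpenSora-STDiT-v3"

# Mirrors the updated initialization: weights and architecture come from
# from_pretrained() instead of an explicit STDiT3_XL_2(...) constructor call,
# and the module is then cast to bfloat16 as in the pipeline default dtype.
transformer = STDiT3.from_pretrained(checkpoint, enable_flash_attn=False).to(torch.bfloat16)
```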