diff --git a/scripts/convert_flashvideo_to_diffusers.py b/scripts/convert_flashvideo_to_diffusers.py
new file mode 100644
index 000000000000..6e70216278df
--- /dev/null
+++ b/scripts/convert_flashvideo_to_diffusers.py
@@ -0,0 +1,349 @@
+import argparse
+from typing import Any, Dict
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers import (
+    AutoencoderKLCogVideoX,
+    CogVideoXDDIMScheduler,
+    CogVideoXImageToVideoPipeline,
+    CogVideoXPipeline,
+    CogVideoXTransformer3DModel,
+)
+
+
+# NOTE: currently copied from convert_cogvideox_to_diffusers.py
+
+
+def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
+    to_q_key = key.replace("query_key_value", "to_q")
+    to_k_key = key.replace("query_key_value", "to_k")
+    to_v_key = key.replace("query_key_value", "to_v")
+    to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
+    state_dict[to_q_key] = to_q
+    state_dict[to_k_key] = to_k
+    state_dict[to_v_key] = to_v
+    state_dict.pop(key)
+
+
+def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
+    layer_id, weight_or_bias = key.split(".")[-2:]
+
+    if "query" in key:
+        new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
+    elif "key" in key:
+        new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"
+
+    state_dict[new_key] = state_dict.pop(key)
+
+
+def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
+    layer_id, _, weight_or_bias = key.split(".")[-3:]
+
+    weights_or_biases = state_dict[key].chunk(12, dim=0)
+    norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
+    norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
+
+    norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
+    state_dict[norm1_key] = norm1_weights_or_biases
+
+    norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
+    state_dict[norm2_key] = norm2_weights_or_biases
+
+    state_dict.pop(key)
+
+
+def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
+    state_dict.pop(key)
+
+
+def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
+    key_split = key.split(".")
+    layer_index = int(key_split[2])
+    replace_layer_index = 4 - 1 - layer_index
+
+    key_split[1] = "up_blocks"
+    key_split[2] = str(replace_layer_index)
+    new_key = ".".join(key_split)
+
+    state_dict[new_key] = state_dict.pop(key)
+
+
+TRANSFORMER_KEYS_RENAME_DICT = {
+    "transformer.final_layernorm": "norm_final",
+    "transformer": "transformer_blocks",
+    "attention": "attn1",
+    "mlp": "ff.net",
+    "dense_h_to_4h": "0.proj",
+    "dense_4h_to_h": "2",
+    ".layers": "",
+    "dense": "to_out.0",
+    "input_layernorm": "norm1.norm",
+    "post_attn1_layernorm": "norm2.norm",
+    "time_embed.0": "time_embedding.linear_1",
+    "time_embed.2": "time_embedding.linear_2",
+    "ofs_embed.0": "ofs_embedding.linear_1",
+    "ofs_embed.2": "ofs_embedding.linear_2",
+    "mixins.patch_embed": "patch_embed",
+    "mixins.final_layer.norm_final": "norm_out.norm",
+    "mixins.final_layer.linear": "proj_out",
+    "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
+    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",  # Specific to CogVideoX-5b-I2V
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {
+    "query_key_value": reassign_query_key_value_inplace,
+    "query_layernorm_list": reassign_query_key_layernorm_inplace,
+    "key_layernorm_list": reassign_query_key_layernorm_inplace,
+    "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
+    "embed_tokens": remove_keys_inplace,
+    "freqs_sin": remove_keys_inplace,
+    "freqs_cos": remove_keys_inplace,
+    "position_embedding": remove_keys_inplace,
+}
+
+VAE_KEYS_RENAME_DICT = {
+    "block.": "resnets.",
+    "down.": "down_blocks.",
+    "downsample": "downsamplers.0",
+    "upsample": "upsamplers.0",
+    "nin_shortcut": "conv_shortcut",
+    "encoder.mid.block_1": "encoder.mid_block.resnets.0",
+    "encoder.mid.block_2": "encoder.mid_block.resnets.1",
+    "decoder.mid.block_1": "decoder.mid_block.resnets.0",
+    "decoder.mid.block_2": "decoder.mid_block.resnets.1",
+}
+
+VAE_SPECIAL_KEYS_REMAP = {
+    "loss": remove_keys_inplace,
+    "up.": replace_up_keys_inplace,
+}
+
+TOKENIZER_MAX_LENGTH = 226
+
+
+def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
+    state_dict = saved_dict
+    if "model" in saved_dict.keys():
+        state_dict = state_dict["model"]
+    if "module" in saved_dict.keys():
+        state_dict = state_dict["module"]
+    if "state_dict" in saved_dict.keys():
+        state_dict = state_dict["state_dict"]
+    return state_dict
+
+
+def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
+    state_dict[new_key] = state_dict.pop(old_key)
+
+
+def convert_transformer(
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    use_rotary_positional_embeddings: bool,
+    i2v: bool,
+    dtype: torch.dtype,
+    init_kwargs: Dict[str, Any],
+):
+    PREFIX_KEY = "model.diffusion_model."
+
+    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
+    transformer = CogVideoXTransformer3DModel(
+        in_channels=32 if i2v else 16,
+        num_layers=num_layers,
+        num_attention_heads=num_attention_heads,
+        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+        ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None,  # CogVideoX1.5-5B-I2V
+        use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
+        **init_kwargs,
+    ).to(dtype=dtype)
+
+    for key in list(original_state_dict.keys()):
+        new_key = key[len(PREFIX_KEY) :]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_inplace(original_state_dict, key, new_key)
+
+    for key in list(original_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, original_state_dict)
+
+    transformer.load_state_dict(original_state_dict, strict=True)
+    return transformer
+
+
+def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
+    init_kwargs = {"scaling_factor": scaling_factor}
+    if version == "1.5":
+        init_kwargs.update({"invert_scale_latents": True})
+
+    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
+    vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)
+
+    for key in list(original_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_inplace(original_state_dict, key, new_key)
+
+    for key in list(original_state_dict.keys()):
+        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, original_state_dict)
+
+    vae.load_state_dict(original_state_dict, strict=True)
+    return vae
+
+
+def get_transformer_init_kwargs(version: str):
+    if version == "1.0":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": None,
+            "patch_bias": True,
+            "sample_height": 480 // vae_scale_factor_spatial,
+            "sample_width": 720 // vae_scale_factor_spatial,
+            "sample_frames": 49,
+        }
+
+    elif version == "1.5":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": 2,
+            "patch_bias": False,
+            "sample_height": 300,
+            "sample_width": 300,
+            "sample_frames": 81,
+        }
+    else:
+        raise ValueError("Unsupported version of CogVideoX.")
+
+    return init_kwargs
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
+    )
+    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
+    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
+    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
+    parser.add_argument(
+        "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
+    )
+    parser.add_argument(
+        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
+    )
+    parser.add_argument(
+        "--typecast_text_encoder",
+        action="store_true",
+        default=False,
+        help="Whether or not to apply fp16/bf16 precision to text_encoder",
+    )
+    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
+    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
+    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
+    parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
+    # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
+    parser.add_argument(
+        "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
+    )
+    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
+    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
+    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
+    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
+    parser.add_argument(
+        "--i2v",
+        action="store_true",
+        default=False,
+        help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
+    )
+    parser.add_argument(
+        "--version",
+        choices=["1.0", "1.5"],
+        default="1.0",
+        help="Which version of CogVideoX to use for initializing default modeling parameters.",
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    transformer = None
+    vae = None
+
+    if args.fp16 and args.bf16:
+        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
+
+    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
+
+    if args.transformer_ckpt_path is not None:
+        init_kwargs = get_transformer_init_kwargs(args.version)
+        transformer = convert_transformer(
+            args.transformer_ckpt_path,
+            args.num_layers,
+            args.num_attention_heads,
+            args.use_rotary_positional_embeddings,
+            args.i2v,
+            dtype,
+            init_kwargs,
+        )
+    if args.vae_ckpt_path is not None:
+        # Keep VAE in float32 for better quality
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)
+
+    text_encoder_id = "google/t5-v1_1-xxl"
+    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+    if args.typecast_text_encoder:
+        text_encoder = text_encoder.to(dtype=dtype)
+
+    # Apparently, the conversion does not work anymore without this :shrug:
+    for param in text_encoder.parameters():
+        param.data = param.data.contiguous()
+
+    scheduler = CogVideoXDDIMScheduler.from_config(
+        {
+            "snr_shift_scale": args.snr_shift_scale,
+            "beta_end": 0.012,
+            "beta_schedule": "scaled_linear",
+            "beta_start": 0.00085,
+            "clip_sample": False,
+            "num_train_timesteps": 1000,
+            "prediction_type": "v_prediction",
+            "rescale_betas_zero_snr": True,
+            "set_alpha_to_one": True,
+            "timestep_spacing": "trailing",
+        }
+    )
+    if args.i2v:
+        pipeline_cls = CogVideoXImageToVideoPipeline
+    else:
+        pipeline_cls = CogVideoXPipeline
+
+    pipe = pipeline_cls(
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        vae=vae,
+        transformer=transformer,
+        scheduler=scheduler,
+    )
+
+    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
+    # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
+    # is either fp16/bf16 here).
+
+    # This is necessary This is necessary for users with insufficient memory,
+    # such as those using Colab and notebooks, as it can save some memory used for model loading.
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 30d497892ff5..95e3765c6996 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -192,6 +192,7 @@
             "CosmosTransformer3DModel",
             "DiTTransformer2DModel",
             "EasyAnimateTransformer3DModel",
+            "FlashVideoTransformer3DModel",
             "FluxControlNetModel",
             "FluxMultiControlNetModel",
             "FluxTransformer2DModel",
@@ -409,6 +410,8 @@
             "EasyAnimateControlPipeline",
             "EasyAnimateInpaintPipeline",
             "EasyAnimatePipeline",
+            "FlashVideoPipeline",
+            "FlashVideoVideoToVideoPipeline",
             "FluxControlImg2ImgPipeline",
             "FluxControlInpaintPipeline",
             "FluxControlNetImg2ImgPipeline",
@@ -846,6 +849,7 @@
             CosmosTransformer3DModel,
             DiTTransformer2DModel,
             EasyAnimateTransformer3DModel,
+            FlashVideoTransformer3DModel,
             FluxControlNetModel,
             FluxMultiControlNetModel,
             FluxTransformer2DModel,
@@ -1038,6 +1042,8 @@
             EasyAnimateControlPipeline,
             EasyAnimateInpaintPipeline,
             EasyAnimatePipeline,
+            FlashVideoPipeline,
+            FlashVideoVideoToVideoPipeline,
             FluxControlImg2ImgPipeline,
             FluxControlInpaintPipeline,
             FluxControlNetImg2ImgPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index cd1df3667a18..a67d7515158e 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -80,6 +80,7 @@
     _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
     _import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
     _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
+    _import_structure["transformers.transformer_flashvideo"] = ["FlashVideoTransformer3DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
     _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
     _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
@@ -164,6 +165,7 @@
             DiTTransformer2DModel,
             DualTransformer2DModel,
             EasyAnimateTransformer3DModel,
+            FlashVideoTransformer3DModel,
             FluxTransformer2DModel,
             HiDreamImageTransformer2DModel,
             HunyuanDiT2DModel,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index dd8813369b5d..c80b05ee0372 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -22,6 +22,7 @@
     from .transformer_cogview4 import CogView4Transformer2DModel
     from .transformer_cosmos import CosmosTransformer3DModel
     from .transformer_easyanimate import EasyAnimateTransformer3DModel
+    from .transformer_flashvideo import FlashVideoTransformer3DModel
     from .transformer_flux import FluxTransformer2DModel
     from .transformer_hidream_image import HiDreamImageTransformer2DModel
     from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
diff --git a/src/diffusers/models/transformers/transformer_flashvideo.py b/src/diffusers/models/transformers/transformer_flashvideo.py
new file mode 100644
index 000000000000..515bd23569d0
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_flashvideo.py
@@ -0,0 +1,588 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import Attention, FeedForward
+from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from ..cache_utils import CacheMixin
+from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+# Copied from diffusers.models.transformers.cogvideox_transformer_3d.CogVideoXBlock with CogVideoX->FlashVideo
+@maybe_allow_in_graph
+class FlashVideoBlock(nn.Module):
+    r"""
+    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+    Parameters:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        time_embed_dim (`int`):
+            The number of channels in timestep embedding.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability to use.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to be used in feed-forward.
+        attention_bias (`bool`, defaults to `False`):
+            Whether or not to use bias in attention projection layers.
+        qk_norm (`bool`, defaults to `True`):
+            Whether or not to use normalization after query and key projections in Attention.
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_eps (`float`, defaults to `1e-5`):
+            Epsilon value for normalization layers.
+        final_dropout (`bool` defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
+        ff_inner_dim (`int`, *optional*, defaults to `None`):
+            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+        ff_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in Feed-forward layer.
+        attention_out_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in Attention output projection layer.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        time_embed_dim: int,
+        dropout: float = 0.0,
+        activation_fn: str = "gelu-approximate",
+        attention_bias: bool = False,
+        qk_norm: bool = True,
+        norm_elementwise_affine: bool = True,
+        norm_eps: float = 1e-5,
+        final_dropout: bool = True,
+        ff_inner_dim: Optional[int] = None,
+        ff_bias: bool = True,
+        attention_out_bias: bool = True,
+    ):
+        super().__init__()
+
+        # 1. Self Attention
+        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            qk_norm="layer_norm" if qk_norm else None,
+            eps=1e-6,
+            bias=attention_bias,
+            out_bias=attention_out_bias,
+            processor=CogVideoXAttnProcessor2_0(),
+        )
+
+        # 2. Feed Forward
+        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+            inner_dim=ff_inner_dim,
+            bias=ff_bias,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+        attention_kwargs = attention_kwargs or {}
+
+        # norm & modulate
+        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+            hidden_states, encoder_hidden_states, temb
+        )
+
+        # attention
+        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            **attention_kwargs,
+        )
+
+        hidden_states = hidden_states + gate_msa * attn_hidden_states
+        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+        # norm & modulate
+        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+            hidden_states, encoder_hidden_states, temb
+        )
+
+        # feed-forward
+        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+        ff_output = self.ff(norm_hidden_states)
+
+        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+        return hidden_states, encoder_hidden_states
+
+
+# Copied from diffusers.models.transformers.cogvideox_transformer_3d.CogVideoXTransformer3DModel with CogVideoX->FlashVideo
+class FlashVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
+    """
+    A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+    Parameters:
+        num_attention_heads (`int`, defaults to `30`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `16`):
+            The number of channels in the output.
+        flip_sin_to_cos (`bool`, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+        time_embed_dim (`int`, defaults to `512`):
+            Output dimension of timestep embeddings.
+        ofs_embed_dim (`int`, defaults to `512`):
+            Output dimension of "ofs" embeddings used in CogVideoX-5b-I2B in version 1.5
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        num_layers (`int`, defaults to `30`):
+            The number of layers of Transformer blocks to use.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability to use.
+        attention_bias (`bool`, defaults to `True`):
+            Whether to use bias in the attention projection layers.
+        sample_width (`int`, defaults to `90`):
+            The width of the input latents.
+        sample_height (`int`, defaults to `60`):
+            The height of the input latents.
+        sample_frames (`int`, defaults to `49`):
+            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
+            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
+            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+        patch_size (`int`, defaults to `2`):
+            The size of the patches to use in the patch embedding layer.
+        temporal_compression_ratio (`int`, defaults to `4`):
+            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+        max_text_seq_length (`int`, defaults to `226`):
+            The maximum sequence length of the input text embeddings.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to use in feed-forward.
+        timestep_activation_fn (`str`, defaults to `"silu"`):
+            Activation function to use when generating the timestep embeddings.
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether to use elementwise affine in normalization layers.
+        norm_eps (`float`, defaults to `1e-5`):
+            The epsilon value to use in normalization layers.
+        spatial_interpolation_scale (`float`, defaults to `1.875`):
+            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+        temporal_interpolation_scale (`float`, defaults to `1.0`):
+            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+        use_rotary_positional_embeddings (`bool`, defaults to `False`):
+            Whether to use rotary positional embeddings (RoPE) for the attention layers.
+        use_learned_positional_embeddings (`bool`, defaults to `False`):
+            Whether to use learned positional embeddings for the attention layers. Note that at the time of writing
+            FlashVideo checkpoints do not use learned positional embeddings.
+        patch_bias (`bool`, defaults to `True`):
+            Whether to use a bias in the convolutional projection of the patch embedding layer.
+        use_time_size_embedding (`bool`, defaults to `True`):
+            Whether to use a new timestep embedding which adds information about the number of latent frames to
+            the timestep embedding. This is used in FlashVideo Stage 2 (but not Stage 1) checkpoints, and not used by
+            CogVideoX models.
+        use_noise_timestep_embedding (`bool`, defaults to `True`):
+            Whether to add a new timestep embedding which adds information about the noise timestep used for latent
+            degradation to the timestep embedding. This is used in FlashVideo Stage 2 (but not Stage 1) checkpoints,
+            and not used by CogVideoX models.
+    """
+
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 30,
+        attention_head_dim: int = 64,
+        in_channels: int = 16,
+        out_channels: Optional[int] = 16,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        time_embed_dim: int = 512,
+        ofs_embed_dim: Optional[int] = None,
+        text_embed_dim: int = 4096,
+        num_layers: int = 30,
+        dropout: float = 0.0,
+        attention_bias: bool = True,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 49,
+        patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        activation_fn: str = "gelu-approximate",
+        timestep_activation_fn: str = "silu",
+        norm_elementwise_affine: bool = True,
+        norm_eps: float = 1e-5,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_rotary_positional_embeddings: bool = False,
+        use_learned_positional_embeddings: bool = False,
+        patch_bias: bool = True,
+        use_time_size_embedding: bool = True,
+        use_noise_timestep_embedding: bool = True,
+    ):
+        super().__init__()
+        inner_dim = num_attention_heads * attention_head_dim
+
+        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+            raise ValueError(
+                "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+                "issue at https://github.com/huggingface/diffusers/issues."
+            )
+
+        # 1. Patch embedding
+        self.patch_embed = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            patch_size_t=patch_size_t,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            text_embed_dim=text_embed_dim,
+            bias=patch_bias,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=sample_frames,
+            temporal_compression_ratio=temporal_compression_ratio,
+            max_text_seq_length=max_text_seq_length,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
+            use_learned_positional_embeddings=use_learned_positional_embeddings,
+        )
+        self.embedding_dropout = nn.Dropout(dropout)
+
+        # 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have)
+
+        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+        self.ofs_proj = None
+        self.ofs_embedding = None
+        if ofs_embed_dim:
+            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
+            self.ofs_embedding = TimestepEmbedding(
+                ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
+            )  # same as time embeddings, for ofs
+
+        # TODO: I believe the OFS embedding from CogVideoX1.5 is different from either the time size embedding or the
+        # noise timestep embedding from FlashVideo.
+        # NOTE: this is new to the FlashVideo DiT as compared to the base CogVideoX DiT.
+        # Handle time size and noise timestep embeddings. These have the same shape as the normal time embeddings.
+        self.time_size_proj = None
+        self.time_size_embedding = None
+        if use_time_size_embedding:
+            self.time_size_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+            self.time_size_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+        self.noise_timestep_proj = None
+        self.noise_timestep_embedding = None
+        if use_noise_timestep_embedding:
+            self.noise_timestep_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+            self.noise_timestep_embedding = TimestepEmbedding(inner_dim, flip_sin_to_cos, freq_shift)
+
+        # 3. Define spatio-temporal transformers blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                FlashVideoBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    time_embed_dim=time_embed_dim,
+                    dropout=dropout,
+                    activation_fn=activation_fn,
+                    attention_bias=attention_bias,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+        # 4. Output blocks
+        self.norm_out = AdaLayerNorm(
+            embedding_dim=time_embed_dim,
+            output_dim=2 * inner_dim,
+            norm_elementwise_affine=norm_elementwise_affine,
+            norm_eps=norm_eps,
+            chunk_dim=1,
+        )
+
+        if patch_size_t is None:
+            # For CogVideox 1.0
+            output_dim = patch_size * patch_size * out_channels
+        else:
+            # For CogVideoX 1.5
+            output_dim = patch_size * patch_size * patch_size_t * out_channels
+
+        self.proj_out = nn.Linear(inner_dim, output_dim)
+
+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        
+
+        This API is 🧪 experimental.
+
+        
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        
+
+        This API is 🧪 experimental.
+
+        
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: Union[int, float, torch.LongTensor],
+        timestep_cond: Optional[torch.Tensor] = None,
+        noise_timestep: Optional[Union[int, float, torch.LongTensor]] = None,
+        time_size_tensor: Optional[torch.LongTensor] = None,
+        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ):
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        batch_size, num_frames, channels, height, width = hidden_states.shape
+
+        # 1. Time embedding
+        timesteps = timestep
+        t_emb = self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=hidden_states.dtype)
+        emb = self.time_embedding(t_emb, timestep_cond)
+
+        if self.ofs_embedding is not None:
+            ofs_emb = self.ofs_proj(ofs)
+            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
+            ofs_emb = self.ofs_embedding(ofs_emb)
+            emb = emb + ofs_emb
+
+        # NOTE: this is new to the FlashVideo DiT as compared to the base CogVideoX DiT.
+        # Handle time step and noise timestep embeddings, if using
+        if self.time_size_embedding is not None:
+            # Explicitly prepare num_frames as a 1D tensor for the sinusoidal timestep embedding
+            if time_size_tensor is not None:
+                time_size_tensor = torch.full((batch_size,), num_frames)
+            time_size_emb = self.time_size_proj(time_size_tensor)
+            time_size_emb = time_size_emb.to(dtype=hidden_states.dtype)
+            time_size_emb = self.time_size_embedding(time_size_emb)
+            emb = emb + time_size_emb
+
+        if self.noise_timestep_embedding is not None:
+            # Explicitly prepare noise_timestep as a 1D tensor for the sinusoidal timestep embedding
+            if not isinstance(noise_timestep, torch.Tensor):
+                noise_timestep = torch.full((batch_size,), noise_timestep)
+            noise_timestep_emb = self.noise_timestep_proj(noise_timestep)
+            noise_timestep_emb = noise_timestep_emb.to(dtype=hidden_states.dtype)
+            noise_timestep_emb = self.noise_timestep_embedding(noise_timestep_emb)
+            emb = emb + noise_timestep_emb
+
+        # 2. Patch embedding
+        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        hidden_states = self.embedding_dropout(hidden_states)
+
+        text_seq_length = encoder_hidden_states.shape[1]
+        encoder_hidden_states = hidden_states[:, :text_seq_length]
+        hidden_states = hidden_states[:, text_seq_length:]
+
+        # 3. Transformer blocks
+        for i, block in enumerate(self.transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    emb,
+                    image_rotary_emb,
+                    attention_kwargs,
+                )
+            else:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=emb,
+                    image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
+                )
+
+        hidden_states = self.norm_final(hidden_states)
+
+        # 4. Final block
+        hidden_states = self.norm_out(hidden_states, temb=emb)
+        hidden_states = self.proj_out(hidden_states)
+
+        # 5. Unpatchify
+        p = self.config.patch_size
+        p_t = self.config.patch_size_t
+
+        if p_t is None:
+            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        else:
+            output = hidden_states.reshape(
+                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+            )
+            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index c8fbdf0c6c29..84c80de49454 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -231,6 +231,7 @@
         "EasyAnimateInpaintPipeline",
         "EasyAnimateControlPipeline",
     ]
+    _import_structure["flashvideo"] = ["FlashVideoPipeline", "FlashVideoVideoToVideoPipeline"]
     _import_structure["hidream_image"] = ["HiDreamImagePipeline"]
     _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
     _import_structure["hunyuan_video"] = [
@@ -608,6 +609,7 @@
             EasyAnimateInpaintPipeline,
             EasyAnimatePipeline,
         )
+        from .flashvideo import FlashVideoPipeline, FlashVideoVideoToVideoPipeline
         from .flux import (
             FluxControlImg2ImgPipeline,
             FluxControlInpaintPipeline,
diff --git a/src/diffusers/pipelines/flashvideo/__init__.py b/src/diffusers/pipelines/flashvideo/__init__.py
new file mode 100644
index 000000000000..a3ace6c8a5b1
--- /dev/null
+++ b/src/diffusers/pipelines/flashvideo/__init__.py
@@ -0,0 +1,50 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["pipeline_flashvideo"] = ["FlashVideoPipeline"]
+    _import_structure["pipeline_flashvideo_video2video"] = ["FlashVideoVideoToVideoPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *
+    else:
+        from .pipeline_flashvideo import FlashVideoPipeline
+        from .pipeline_flashvideo_video2video import FlashVideoVideoToVideoPipeline
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/flashvideo/pipeline_flashvideo.py b/src/diffusers/pipelines/flashvideo/pipeline_flashvideo.py
new file mode 100644
index 000000000000..f79a77ec169a
--- /dev/null
+++ b/src/diffusers/pipelines/flashvideo/pipeline_flashvideo.py
@@ -0,0 +1,791 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import CogVideoXLoraLoaderMixin
+from ...models import AutoencoderKLCogVideoX, FlashVideoTransformer3DModel
+from ...models.embeddings import get_3d_rotary_pos_embed
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from .pipeline_output import FlashVideoPipelineOutput
+
+
+# NOTE: copied from CogVideoXPipelineOutput currently with CogVideoX->FlashVideo
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```python
+        >>> import torch
+        >>> from diffusers import CogVideoXPipeline
+        >>> from diffusers.utils import export_to_video
+
+        >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
+        >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
+        >>> prompt = (
+        ...     "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+        ...     "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+        ...     "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+        ...     "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+        ...     "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+        ...     "atmosphere of this unique musical performance."
+        ... )
+        >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+        >>> export_to_video(video, "output.mp4", fps=8)
+        ```
+"""
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+    tw = tgt_width
+    th = tgt_height
+    h, w = src
+    r = h / w
+    if r > (th / tw):
+        resize_height = th
+        resize_width = int(round(th / h * w))
+    else:
+        resize_width = tw
+        resize_height = int(round(tw / w * h))
+
+    crop_top = int(round((th - resize_height) / 2.0))
+    crop_left = int(round((tw - resize_width) / 2.0))
+
+    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    r"""
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class FlashVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
+    r"""
+    Pipeline for text-to-video generation using FlashVideo, which itself is based on CogVideoX.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        text_encoder ([`T5EncoderModel`]):
+            Frozen text-encoder. CogVideoX uses
+            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+        tokenizer (`T5Tokenizer`):
+            Tokenizer of class
+            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        transformer ([`FlashVideoTransformer3DModel`]):
+            A text conditioned `FlashVideoTransformer3DModel` to denoise the encoded video latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+    """
+
+    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+    ]
+
+    def __init__(
+        self,
+        tokenizer: T5Tokenizer,
+        text_encoder: T5EncoderModel,
+        vae: AutoencoderKLCogVideoX,
+        transformer: FlashVideoTransformer3DModel,
+        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+    ):
+        super().__init__()
+
+        self.register_modules(
+            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+        )
+        self.vae_scale_factor_spatial = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+        )
+        self.vae_scale_factor_temporal = (
+            self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
+        )
+        self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
+
+        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_videos_per_prompt: int = 1,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because `max_sequence_length` is set to "
+                f" {max_sequence_length} tokens: {removed_text}"
+            )
+
+        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        do_classifier_free_guidance: bool = True,
+        num_videos_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+                Whether to use classifier free guidance or not.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            device: (`torch.device`, *optional*):
+                torch device
+            dtype: (`torch.dtype`, *optional*):
+                torch dtype
+        """
+        device = device or self._execution_device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+
+            negative_prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=negative_prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        return prompt_embeds, negative_prompt_embeds
+
+    def prepare_latents(
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+    ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        shape = (
+            batch_size,
+            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+            num_channels_latents,
+            height // self.vae_scale_factor_spatial,
+            width // self.vae_scale_factor_spatial,
+        )
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
+        latents = 1 / self.vae_scaling_factor_image * latents
+
+        frames = self.vae.decode(latents).sample
+        return frames
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        negative_prompt,
+        callback_on_step_end_tensor_inputs,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+    ):
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+    def fuse_qkv_projections(self) -> None:
+        r"""Enables fused QKV projections."""
+        self.fusing_transformer = True
+        self.transformer.fuse_qkv_projections()
+
+    def unfuse_qkv_projections(self) -> None:
+        r"""Disable QKV projection fusion if enabled."""
+        if not self.fusing_transformer:
+            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+        else:
+            self.transformer.unfuse_qkv_projections()
+            self.fusing_transformer = False
+
+    def _prepare_rotary_positional_embeddings(
+        self,
+        height: int,
+        width: int,
+        num_frames: int,
+        device: torch.device,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+        p = self.transformer.config.patch_size
+        p_t = self.transformer.config.patch_size_t
+
+        base_size_width = self.transformer.config.sample_width // p
+        base_size_height = self.transformer.config.sample_height // p
+
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+                device=device,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+                device=device,
+            )
+
+        return freqs_cos, freqs_sin
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        num_inference_steps: int = 50,
+        timesteps: Optional[List[int]] = None,
+        guidance_scale: float = 6,
+        use_dynamic_cfg: bool = False,
+        num_videos_per_prompt: int = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: str = "pil",
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 226,
+    ) -> Union[FlashVideoPipelineOutput, Tuple]:
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+                instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The width in pixels of the generated image. This is set to 720 by default for the best results.
+            num_frames (`int`, defaults to `48`):
+                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+                contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
+                num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
+                needs to be satisfied is that of divisibility mentioned above.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 7.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+                of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to `226`):
+                Maximum sequence length in encoded prompt. Must be consistent with
+                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
+            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+        num_frames = num_frames or self.transformer.config.sample_frames
+
+        num_videos_per_prompt = 1
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt,
+            callback_on_step_end_tensor_inputs,
+            prompt_embeds,
+            negative_prompt_embeds,
+        )
+        self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
+        self._interrupt = False
+
+        # 2. Default call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            negative_prompt,
+            do_classifier_free_guidance,
+            num_videos_per_prompt=num_videos_per_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            max_sequence_length=max_sequence_length,
+            device=device,
+        )
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+        # 4. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        self._num_timesteps = len(timesteps)
+
+        # 5. Prepare latents
+        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+        patch_size_t = self.transformer.config.patch_size_t
+        additional_frames = 0
+        if patch_size_t is not None and latent_frames % patch_size_t != 0:
+            additional_frames = patch_size_t - latent_frames % patch_size_t
+            num_frames += additional_frames * self.vae_scale_factor_temporal
+
+        latent_channels = self.transformer.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_videos_per_prompt,
+            latent_channels,
+            num_frames,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Create rotary embeds if required
+        image_rotary_emb = (
+            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+            if self.transformer.config.use_rotary_positional_embeddings
+            else None
+        )
+
+        # 8. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            # for DPM-solver++
+            old_pred_original_sample = None
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                self._current_timestep = t
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latent_model_input.shape[0])
+
+                # predict noise model_output
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep=timestep,
+                    image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
+                    return_dict=False,
+                )[0]
+                noise_pred = noise_pred.float()
+
+                # perform guidance
+                if use_dynamic_cfg:
+                    self._guidance_scale = 1 + guidance_scale * (
+                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+                    )
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                else:
+                    latents, old_pred_original_sample = self.scheduler.step(
+                        noise_pred,
+                        old_pred_original_sample,
+                        t,
+                        timesteps[i - 1] if i > 0 else None,
+                        latents,
+                        **extra_step_kwargs,
+                        return_dict=False,
+                    )
+                latents = latents.to(prompt_embeds.dtype)
+
+                # call the callback, if provided
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        self._current_timestep = None
+
+        if not output_type == "latent":
+            # Discard any padding frames that were added for CogVideoX 1.5
+            latents = latents[:, additional_frames:]
+            video = self.decode_latents(latents)
+            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+        else:
+            video = latents
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (video,)
+
+        return FlashVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/flashvideo/pipeline_flashvideo_video2video.py b/src/diffusers/pipelines/flashvideo/pipeline_flashvideo_video2video.py
new file mode 100644
index 000000000000..fcd88d2729e8
--- /dev/null
+++ b/src/diffusers/pipelines/flashvideo/pipeline_flashvideo_video2video.py
@@ -0,0 +1,916 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import T5EncoderModel, T5Tokenizer
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import CogVideoXLoraLoaderMixin
+from ...models import AutoencoderKLCogVideoX, FlashVideoTransformer3DModel
+from ...models.embeddings import get_3d_rotary_pos_embed
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from .pipeline_output import FlashVideoPipelineOutput
+
+
+# NOTE: copied from CogVideoXPipelineOutput currently with CogVideoX->FlashVideo
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```python
+        >>> import torch
+        >>> from diffusers import CogVideoXPipeline
+        >>> from diffusers.utils import export_to_video
+
+        >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
+        >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
+        >>> prompt = (
+        ...     "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+        ...     "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+        ...     "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+        ...     "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+        ...     "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+        ...     "atmosphere of this unique musical performance."
+        ... )
+        >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+        >>> export_to_video(video, "output.mp4", fps=8)
+        ```
+"""
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+    tw = tgt_width
+    th = tgt_height
+    h, w = src
+    r = h / w
+    if r > (th / tw):
+        resize_height = th
+        resize_width = int(round(th / h * w))
+    else:
+        resize_width = tw
+        resize_height = int(round(tw / w * h))
+
+    crop_top = int(round((th - resize_height) / 2.0))
+    crop_left = int(round((tw - resize_width) / 2.0))
+
+    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    r"""
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class FlashVideoVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
+    r"""
+    Pipeline for text-to-video generation using FlashVideo, which itself is based on CogVideoX.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        text_encoder ([`T5EncoderModel`]):
+            Frozen text-encoder. CogVideoX uses
+            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+        tokenizer (`T5Tokenizer`):
+            Tokenizer of class
+            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        transformer ([`FlashVideoTransformer3DModel`]):
+            A text conditioned `FlashVideoTransformer3DModel` to denoise the encoded video latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+    """
+
+    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+    ]
+
+    def __init__(
+        self,
+        tokenizer: T5Tokenizer,
+        text_encoder: T5EncoderModel,
+        vae: AutoencoderKLCogVideoX,
+        transformer: FlashVideoTransformer3DModel,
+        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+    ):
+        super().__init__()
+
+        self.register_modules(
+            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+        )
+        self.vae_scale_factor_spatial = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+        )
+        self.vae_scale_factor_temporal = (
+            self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
+        )
+        self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
+
+        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_videos_per_prompt: int = 1,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because `max_sequence_length` is set to "
+                f" {max_sequence_length} tokens: {removed_text}"
+            )
+
+        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        do_classifier_free_guidance: bool = True,
+        num_videos_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+                Whether to use classifier free guidance or not.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            device: (`torch.device`, *optional*):
+                torch device
+            dtype: (`torch.dtype`, *optional*):
+                torch dtype
+        """
+        device = device or self._execution_device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+
+            negative_prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=negative_prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        return prompt_embeds, negative_prompt_embeds
+
+    def prepare_latents(
+        self,
+        video: Optional[torch.Tensor] = None,
+        batch_size: int = 1,
+        num_channels_latents: int = 16,
+        height: int = 1080,  # Should be the target height in pixels
+        width: int = 1920,  # Should be the target width in pixels
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.Tensor] = None,
+        interpolation_mode: str = "bilinear",
+        sample_mode: str = "sample",
+    ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+        
+        num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
+
+        # NOTE: the shape here corresponds to the latent shape of the target size, which is in general different from
+        # the latent shape of the input `video`.
+        shape = (
+            batch_size,
+            num_frames,
+            num_channels_latents,
+            height // self.vae_scale_factor_spatial,
+            width // self.vae_scale_factor_spatial,
+        )
+
+        if latents is None:
+            # First upscale the conditioning video to the desired shape.
+            upscaled_video = self.upscale_latents(latents, height, width, interpolation_mode)
+
+            # Now encode the upscaled video using the 3D VAE encoder.
+            latents = self.encode_latents(upscaled_video, batch_size, generator, sample_mode)
+
+            # If timestep is not supplied, latent degradation for FlashVideo Stage 2 will be disabled. However, the
+            # authors found that latent degradation is important for the final video quality, so it is recommended to
+            # supply a timestep.
+            if timestep is not None:
+                latent_noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+                # NOTE: the original implementation allows the noise to be modified using shift/scale factors. This is
+                # currently not implemented, and the original code defaults to using shift/scale values that would
+                # keep the original noise level from the scheduler.
+                latents = self.scheduler.add_noise(latents, latent_noise, timestep)
+        else:
+            # NOTE: in line with similar pipelines like CogVideoXVideoToVideoPipeline, supplied latents are expected
+            # to be directly usable at the start of the denoising loop. In particular, the latents must have the
+            # correct shape and already be noised (assuming latent degradation is used.)
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
+        latents = 1 / self.vae_scaling_factor_image * latents
+
+        frames = self.vae.decode(latents).sample
+        return frames
+
+    def encode_latents(
+        self,
+        video: torch.Tensor,
+        batch_size: int = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        sample_mode: str = "sample",
+    ):
+        if isinstance(generator, list):
+            latents = [
+                retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i], sample_mode) for i in range(batch_size)
+            ]
+        else:
+            latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator, sample_mode) for vid in video]
+        latents = torch.cat(latents, dim=0).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
+        latents = latents * self.vae_scaling_factor_image
+        return latents
+
+    def upscale_latents(
+        self,
+        latents: torch.Tensor,
+        target_height: int = 1080,
+        target_width: int = 1920,
+        interpolation_mode: str = "bilinear",
+    ):
+        """
+        Upscales the low-resolution video from Stage 1 for to the target height/width use in Stage 2 high-resolution
+        video enhancement.
+        """
+        upscaled_latents = torch.nn.functional.interpolate(
+            latents,
+            size=(target_height, target_width),
+            mode=interpolation_mode,
+            align_corners=False,
+            antialias=True,
+        )
+        return upscaled_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(
+        self,
+        video,
+        prompt,
+        negative_prompt,
+        height,
+        width,
+        latents,
+        prompt_embeds,
+        negative_prompt_embeds,
+        callback_on_step_end_tensor_inputs,
+    ):
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+        if video is None and latents is None:
+            raise ValueError(
+                "At least one of `video` and `latents` must be supplied, but neither is available."
+            )
+
+    def fuse_qkv_projections(self) -> None:
+        r"""Enables fused QKV projections."""
+        self.fusing_transformer = True
+        self.transformer.fuse_qkv_projections()
+
+    def unfuse_qkv_projections(self) -> None:
+        r"""Disable QKV projection fusion if enabled."""
+        if not self.fusing_transformer:
+            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+        else:
+            self.transformer.unfuse_qkv_projections()
+            self.fusing_transformer = False
+
+    def _prepare_rotary_positional_embeddings(
+        self,
+        height: int,
+        width: int,
+        num_frames: int,
+        device: torch.device,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+        p = self.transformer.config.patch_size
+        p_t = self.transformer.config.patch_size_t
+
+        base_size_width = self.transformer.config.sample_width // p
+        base_size_height = self.transformer.config.sample_height // p
+
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+                device=device,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+                device=device,
+            )
+
+        return freqs_cos, freqs_sin
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        video: List[Image.Image] = None,
+        prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        num_inference_steps: int = 50,
+        timesteps: Optional[List[int]] = None,
+        noise_timestep: Optional[int] = 675,
+        guidance_scale: float = 6,
+        use_dynamic_cfg: bool = False,
+        num_videos_per_prompt: int = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: str = "pil",
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 226,
+    ) -> Union[FlashVideoPipelineOutput, Tuple]:
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            video (`List[PIL.Image.Image]`, *optional*):
+                The input video to condition the generation on. Must be a list of images/frames of the video. If not
+                supplied, `latents` must be supplied instead.
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+                instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+                The width in pixels of the generated image. This is set to 720 by default for the best results.
+            num_frames (`int`, defaults to `48`):
+                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+                contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
+                num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
+                needs to be satisfied is that of divisibility mentioned above.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            noise_timestep (`int`, *optional*, defaults to `675`)
+                The timestep corresponding to the noise level used for latent degradation. This controls how much
+                degradation is applied to the latents before the denoising loop (with higher values meaning more
+                degradation); the FlashVideo authors recommend a value in the range of 650-750. If `None`, latent
+                degradation will not be applied (for example, in the Stage 1 model). Note that we could express this
+                as a strength value (e.g. 0.675) like in []`CogVideoXVideoToVideoPipeline`], but unlike pipelines which
+                have a strength argument we always use the full timestep schedule for denoising.
+            noise_shift_scale (`float` or `Tuple[float, float]`, *optional*, defaults to 1.0):
+                Shift and scale values to modify the noise level for latent degradation. The default value of `1.0`
+                corresponds to using the default noise level of noise_timestep according to the pipeline's scheduler.
+            guidance_scale (`float`, *optional*, defaults to 7.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+                of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to `226`):
+                Maximum sequence length in encoded prompt. Must be consistent with
+                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
+            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+        num_frames = num_frames or self.transformer.config.sample_frames
+
+        num_videos_per_prompt = 1
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            video,
+            prompt,
+            negative_prompt,
+            height,
+            width,
+            latents,
+            prompt_embeds,
+            negative_prompt_embeds,
+            callback_on_step_end_tensor_inputs,
+        )
+        self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
+        self._interrupt = False
+
+        # 2. Default call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            negative_prompt,
+            do_classifier_free_guidance,
+            num_videos_per_prompt=num_videos_per_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            max_sequence_length=max_sequence_length,
+            device=device,
+        )
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        
+        # Also prepare inputs for other embeddings (time size, noise timestep) used by FlashVideo Stage 2 here.
+        # For efficiency purposes, since the time size and noise timestep stay the same throughout the denoising loop,
+        # prepare them ahead of time here (if using). These are used only in Stage 2 models.
+        time_size_tensor = None
+        noise_timesteps = None
+        if self.transformer.config.use_time_size_embedding:
+            time_size_tensor = torch.full((batch_size * num_videos_per_prompt,), latent_frames, device=latents.device)
+        if self.transformer.config.use_noise_timestep_embedding:
+            noise_timesteps = torch.full((batch_size * num_videos_per_prompt,), noise_timestep, device=latents.device)
+
+        # 4. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        self._num_timesteps = len(timesteps)
+
+        # 5. Prepare latents
+        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+        patch_size_t = self.transformer.config.patch_size_t
+        additional_frames = 0
+        if patch_size_t is not None and latent_frames % patch_size_t != 0:
+            additional_frames = patch_size_t - latent_frames % patch_size_t
+            num_frames += additional_frames * self.vae_scale_factor_temporal
+
+        if latents is None:
+            video = self.video_processor.preprocess_video(video, height=height, width=width)
+            video = video.to(device=device, dtype=prompt_embeds.dtype)
+
+        latent_channels = self.transformer.config.in_channels
+        latents = self.prepare_latents(
+            video=video,
+            batch_size=batch_size * num_videos_per_prompt,
+            num_channels_latents=latent_channels,
+            height=height,
+            width=width,
+            dtype=prompt_embeds.dtype,
+            device=device,
+            generator=generator,
+            latents=latents,
+            timestep=noise_timesteps,
+        )
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Create rotary embeds if required
+        image_rotary_emb = (
+            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+            if self.transformer.config.use_rotary_positional_embeddings
+            else None
+        )
+
+        # 8. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            # for DPM-solver++
+            old_pred_original_sample = None
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                self._current_timestep = t
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latent_model_input.shape[0])
+
+                # predict noise model_output
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep=timestep,
+                    noise_timestep=noise_timesteps,
+                    time_size_tensor=time_size_tensor,
+                    image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
+                    return_dict=False,
+                )[0]
+                noise_pred = noise_pred.float()
+
+                # perform guidance
+                if use_dynamic_cfg:
+                    self._guidance_scale = 1 + guidance_scale * (
+                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+                    )
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                else:
+                    latents, old_pred_original_sample = self.scheduler.step(
+                        noise_pred,
+                        old_pred_original_sample,
+                        t,
+                        timesteps[i - 1] if i > 0 else None,
+                        latents,
+                        **extra_step_kwargs,
+                        return_dict=False,
+                    )
+                latents = latents.to(prompt_embeds.dtype)
+
+                # call the callback, if provided
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        self._current_timestep = None
+
+        if not output_type == "latent":
+            # Discard any padding frames that were added for CogVideoX 1.5
+            latents = latents[:, additional_frames:]
+            video = self.decode_latents(latents)
+            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+        else:
+            video = latents
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (video,)
+
+        return FlashVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/flashvideo/pipeline_output.py b/src/diffusers/pipelines/flashvideo/pipeline_output.py
new file mode 100644
index 000000000000..85dc55e76c41
--- /dev/null
+++ b/src/diffusers/pipelines/flashvideo/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+# Copied from diffusers.pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput with CogVideoX->FlashVideo
+@dataclass
+class FlashVideoPipelineOutput(BaseOutput):
+    r"""
+    Output class for FlashVideo pipelines.
+
+    Args:
+        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+    """
+
+    frames: torch.Tensor
diff --git a/tests/pipelines/flashvideo/__init__.py b/tests/pipelines/flashvideo/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/flashvideo/test_flashvideo.py b/tests/pipelines/flashvideo/test_flashvideo.py
new file mode 100644
index 000000000000..4f140690d189
--- /dev/null
+++ b/tests/pipelines/flashvideo/test_flashvideo.py
@@ -0,0 +1,365 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    numpy_cosine_similarity_distance,
+    require_torch_accelerator,
+    slow,
+    torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+    PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+    to_np,
+)
+
+
+# NOTE: currently copied from test_cogvideox.py
+
+
+enable_full_determinism()
+
+
+class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
+    pipeline_class = CogVideoXPipeline
+    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "generator",
+            "latents",
+            "return_dict",
+            "callback_on_step_end",
+            "callback_on_step_end_tensor_inputs",
+        ]
+    )
+    test_xformers_attention = False
+    test_layerwise_casting = True
+    test_group_offloading = True
+
+    def get_dummy_components(self, num_layers: int = 1):
+        torch.manual_seed(0)
+        transformer = CogVideoXTransformer3DModel(
+            # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
+            # But, since we are using tiny-random-t5 here, we need the internal dim of CogVideoXTransformer3DModel
+            # to be 32. The internal dim is product of num_attention_heads and attention_head_dim
+            num_attention_heads=4,
+            attention_head_dim=8,
+            in_channels=4,
+            out_channels=4,
+            time_embed_dim=2,
+            text_embed_dim=32,  # Must match with tiny-random-t5
+            num_layers=num_layers,
+            sample_width=2,  # latent width: 2 -> final width: 16
+            sample_height=2,  # latent height: 2 -> final height: 16
+            sample_frames=9,  # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
+            patch_size=2,
+            temporal_compression_ratio=4,
+            max_text_seq_length=16,
+        )
+
+        torch.manual_seed(0)
+        vae = AutoencoderKLCogVideoX(
+            in_channels=3,
+            out_channels=3,
+            down_block_types=(
+                "CogVideoXDownBlock3D",
+                "CogVideoXDownBlock3D",
+                "CogVideoXDownBlock3D",
+                "CogVideoXDownBlock3D",
+            ),
+            up_block_types=(
+                "CogVideoXUpBlock3D",
+                "CogVideoXUpBlock3D",
+                "CogVideoXUpBlock3D",
+                "CogVideoXUpBlock3D",
+            ),
+            block_out_channels=(8, 8, 8, 8),
+            latent_channels=4,
+            layers_per_block=1,
+            norm_num_groups=2,
+            temporal_compression_ratio=4,
+        )
+
+        torch.manual_seed(0)
+        scheduler = DDIMScheduler()
+        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        components = {
+            "transformer": transformer,
+            "vae": vae,
+            "scheduler": scheduler,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "dance monkey",
+            "negative_prompt": "",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            # Cannot reduce because convolution kernel becomes bigger than sample
+            "height": 16,
+            "width": 16,
+            "num_frames": 8,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+        }
+        return inputs
+
+    def test_inference(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        video = pipe(**inputs).frames
+        generated_video = video[0]
+
+        self.assertEqual(generated_video.shape, (8, 3, 16, 16))
+        expected_video = torch.randn(8, 3, 16, 16)
+        max_diff = np.abs(generated_video - expected_video).max()
+        self.assertLessEqual(max_diff, 1e10)
+
+    def test_callback_inputs(self):
+        sig = inspect.signature(self.pipeline_class.__call__)
+        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+        has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+        if not (has_callback_tensor_inputs and has_callback_step_end):
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        self.assertTrue(
+            hasattr(pipe, "_callback_tensor_inputs"),
+            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+        )
+
+        def callback_inputs_subset(pipe, i, t, callback_kwargs):
+            # iterate over callback args
+            for tensor_name, tensor_value in callback_kwargs.items():
+                # check that we're only passing in allowed tensor inputs
+                assert tensor_name in pipe._callback_tensor_inputs
+
+            return callback_kwargs
+
+        def callback_inputs_all(pipe, i, t, callback_kwargs):
+            for tensor_name in pipe._callback_tensor_inputs:
+                assert tensor_name in callback_kwargs
+
+            # iterate over callback args
+            for tensor_name, tensor_value in callback_kwargs.items():
+                # check that we're only passing in allowed tensor inputs
+                assert tensor_name in pipe._callback_tensor_inputs
+
+            return callback_kwargs
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        # Test passing in a subset
+        inputs["callback_on_step_end"] = callback_inputs_subset
+        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+        output = pipe(**inputs)[0]
+
+        # Test passing in a everything
+        inputs["callback_on_step_end"] = callback_inputs_all
+        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+        output = pipe(**inputs)[0]
+
+        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+            is_last = i == (pipe.num_timesteps - 1)
+            if is_last:
+                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+            return callback_kwargs
+
+        inputs["callback_on_step_end"] = callback_inputs_change_tensor
+        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+        output = pipe(**inputs)[0]
+        assert output.abs().sum() < 1e10
+
+    def test_inference_batch_single_identical(self):
+        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+
+    def test_attention_slicing_forward_pass(
+        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+    ):
+        if not self.test_attention_slicing:
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        output_without_slicing = pipe(**inputs)[0]
+
+        pipe.enable_attention_slicing(slice_size=1)
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_slicing1 = pipe(**inputs)[0]
+
+        pipe.enable_attention_slicing(slice_size=2)
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_slicing2 = pipe(**inputs)[0]
+
+        if test_max_difference:
+            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+            self.assertLess(
+                max(max_diff1, max_diff2),
+                expected_max_diff,
+                "Attention slicing should not affect the inference results",
+            )
+
+    def test_vae_tiling(self, expected_diff_max: float = 0.2):
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe.to("cpu")
+        pipe.set_progress_bar_config(disable=None)
+
+        # Without tiling
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["height"] = inputs["width"] = 128
+        output_without_tiling = pipe(**inputs)[0]
+
+        # With tiling
+        pipe.vae.enable_tiling(
+            tile_sample_min_height=96,
+            tile_sample_min_width=96,
+            tile_overlap_factor_height=1 / 12,
+            tile_overlap_factor_width=1 / 12,
+        )
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["height"] = inputs["width"] = 128
+        output_with_tiling = pipe(**inputs)[0]
+
+        self.assertLess(
+            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+            expected_diff_max,
+            "VAE tiling should not affect the inference results",
+        )
+
+    def test_fused_qkv_projections(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        frames = pipe(**inputs).frames  # [B, F, C, H, W]
+        original_image_slice = frames[0, -2:, -1, -3:, -3:]
+
+        pipe.fuse_qkv_projections()
+        assert check_qkv_fusion_processors_exist(
+            pipe.transformer
+        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        assert check_qkv_fusion_matches_attn_procs_length(
+            pipe.transformer, pipe.transformer.original_attn_processors
+        ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+        inputs = self.get_dummy_inputs(device)
+        frames = pipe(**inputs).frames
+        image_slice_fused = frames[0, -2:, -1, -3:, -3:]
+
+        pipe.transformer.unfuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        frames = pipe(**inputs).frames
+        image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
+
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."
+
+
+@slow
+@require_torch_accelerator
+class CogVideoXPipelineIntegrationTests(unittest.TestCase):
+    prompt = "A painting of a squirrel eating a burger."
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_cogvideox(self):
+        generator = torch.Generator("cpu").manual_seed(0)
+
+        pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
+        pipe.enable_model_cpu_offload(device=torch_device)
+        prompt = self.prompt
+
+        videos = pipe(
+            prompt=prompt,
+            height=480,
+            width=720,
+            num_frames=16,
+            generator=generator,
+            num_inference_steps=2,
+            output_type="pt",
+        ).frames
+
+        video = videos[0]
+        expected_video = torch.randn(1, 16, 480, 720, 3).numpy()
+
+        max_diff = numpy_cosine_similarity_distance(video, expected_video)
+        assert max_diff < 1e-3, f"Max diff is too high. got {video}"