6 changes: 3 additions & 3 deletions .github/workflows/pr-test.yml
@@ -122,17 +122,17 @@ jobs:
# Actual tests
encoder-test:
  - 'fastvideo/v1/models/encoders/**'
-  - 'fastvideo/v1/models/loaders/**'
+  - 'fastvideo/v1/models/loader/**'
  - 'fastvideo/v1/tests/encoders/**'
  - *common-paths
vae-test:
  - 'fastvideo/v1/models/vaes/**'
-  - 'fastvideo/v1/models/loaders/**'
+  - 'fastvideo/v1/models/loader/**'
  - 'fastvideo/v1/tests/vaes/**'
  - *common-paths
transformer-test:
  - 'fastvideo/v1/models/dits/**'
-  - 'fastvideo/v1/models/loaders/**'
+  - 'fastvideo/v1/models/loader/**'
  - 'fastvideo/v1/tests/transformers/**'
  - 'fastvideo/v1/layers/**'
  - 'fastvideo/v1/attention/**'
2 changes: 1 addition & 1 deletion examples/inference/basic/basic.py
@@ -10,7 +10,7 @@ def main():
    # attempt to identify the optimal arguments.
    generator = VideoGenerator.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
-        # if num_gpus > 1, FastVideo will automatically handle distributed setup
+        # FastVideo will automatically handle distributed setup
        num_gpus=2,
        use_fsdp_inference=True,
        use_cpu_offload=False
6 changes: 4 additions & 2 deletions fastvideo/v1/configs/models/base.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field, fields
-from typing import Any, Dict
+from typing import Any, Dict, List, Tuple

from fastvideo.v1.logger import init_logger

@@ -12,7 +12,9 @@
# 3. Any field in ArchConfig is fixed upon initialization, and should be hidden away from users
@dataclass
class ArchConfig:
-    pass
+    stacked_params_mapping: List[Tuple[str, str, str]] = field(
+        default_factory=list
+    )  # mapping from huggingface weight names to custom names


@dataclass
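
The loader-side contract for stacked_params_mapping is not shown in this diff. Below is a minimal sketch, assuming a vLLM-style scheme in which each (param_name, shard_name, shard_id) triple routes a checkpoint weight whose name contains shard_name into the fused parameter named param_name; the load_weight helper and the weight_loader hook are illustrative, not FastVideo's actual loader API.

import torch

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def load_weight(params: dict, ckpt_name: str, loaded: torch.Tensor) -> None:
    # Route e.g. "layers.0.self_attn.q_proj.weight" into the fused
    # "layers.0.self_attn.qkv_proj.weight", tagged as the "q" shard.
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in ckpt_name:
            param = params[ckpt_name.replace(shard_name, param_name)]
            param.weight_loader(param, loaded, shard_id)  # hypothetical hook
            return
    params[ckpt_name].data.copy_(loaded)  # unfused weights map 1:1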
8 changes: 3 additions & 5 deletions fastvideo/v1/configs/models/dits/stepvideo.py
@@ -5,13 +5,11 @@
from fastvideo.v1.configs.models.dits.base import DiTArchConfig, DiTConfig


-def is_blocks(n: str, m) -> bool:
-    return "blocks" in n and str.isdigit(n.split(".")[-1])
-
-
@dataclass
class StepVideoArchConfig(DiTArchConfig):
-    _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda:
+        [lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit()])

    _param_names_mapping: dict = field(
        default_factory=lambda: {
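
For readers unfamiliar with these predicates: each entry in _fsdp_shard_conditions takes a dotted module name plus the module itself and decides whether that submodule becomes its own FSDP shard unit. A minimal sketch of how such conditions can be applied (the iter_fsdp_units helper is illustrative, not FastVideo's actual wrapping code):

import torch.nn as nn

def iter_fsdp_units(model: nn.Module, conditions) -> list:
    # Collect every submodule accepted by at least one (name, module) predicate.
    return [
        module for name, module in model.named_modules()
        if any(cond(name, module) for cond in conditions)
    ]

# With the lambda above, "transformer_blocks.7" matches (per-block sharding),
# while the container "transformer_blocks" itself has no trailing index and
# does not.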
5 changes: 4 additions & 1 deletion fastvideo/v1/configs/models/encoders/base.py
@@ -32,8 +32,11 @@ class TextEncoderArchConfig(EncoderArchConfig):
    output_past: bool = True
    scalable_attention: bool = True
    tie_word_embeddings: bool = False
-
+    stacked_params_mapping: List[Tuple[str, str, str]] = field(
+        default_factory=list
+    )  # mapping from huggingface weight names to custom names
    tokenizer_kwargs: Dict[str, Any] = field(default_factory=dict)
+    _fsdp_shard_conditions: list = field(default_factory=lambda: [])

    def __post_init__(self) -> None:
        self.tokenizer_kwargs = {
26 changes: 25 additions & 1 deletion fastvideo/v1/configs/models/encoders/clip.py
@@ -1,13 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional, Tuple

from fastvideo.v1.configs.models.encoders.base import (ImageEncoderArchConfig,
                                                       ImageEncoderConfig,
                                                       TextEncoderArchConfig,
                                                       TextEncoderConfig)


+def _is_transformer_layer(n: str, m) -> bool:
+    return "layers" in n and str.isdigit(n.split(".")[-1])
+
+
+def _is_embeddings(n: str, m) -> bool:
+    return n.endswith("embeddings")
+
+
@dataclass
class CLIPTextArchConfig(TextEncoderArchConfig):
    vocab_size: int = 49408
@@ -27,6 +35,15 @@ class CLIPTextArchConfig(TextEncoderArchConfig):
    bos_token_id: int = 49406
    eos_token_id: int = 49407
    text_len: int = 77
+    stacked_params_mapping: List[Tuple[str, str,
+                                       str]] = field(default_factory=lambda: [
+                                           # (param_name, shard_name, shard_id)
+                                           ("qkv_proj", "q_proj", "q"),
+                                           ("qkv_proj", "k_proj", "k"),
+                                           ("qkv_proj", "v_proj", "v"),
+                                       ])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda: [_is_transformer_layer, _is_embeddings])


@dataclass
@@ -45,6 +62,13 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig):
    attention_dropout: float = 0.0
    initializer_range: float = 0.02
    initializer_factor: float = 1.0
+    stacked_params_mapping: List[Tuple[str, str,
+                                       str]] = field(default_factory=lambda: [
+                                           # (param_name, shard_name, shard_id)
+                                           ("qkv_proj", "q_proj", "q"),
+                                           ("qkv_proj", "k_proj", "k"),
+                                           ("qkv_proj", "v_proj", "v"),
+                                       ])


@dataclass
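
The q/k/v triples above imply that the checkpoint's three separate projection matrices land in one fused qkv_proj parameter. A sketch of the assumed layout, with the shard_id selecting a row block of the concatenated weight (sizes are illustrative):

import torch

hidden = 768  # CLIP text width from the config above
q_w = torch.randn(hidden, hidden)
k_w = torch.randn(hidden, hidden)
v_w = torch.randn(hidden, hidden)

# Fused weight stacks the three projections along the output dimension.
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)  # shape (3 * hidden, hidden)
row_offset = {"q": 0, "k": hidden, "v": 2 * hidden}

def fill_shard(qkv: torch.Tensor, shard_id: str, w: torch.Tensor) -> None:
    # Copy one checkpoint matrix into its row block of the fused parameter.
    qkv[row_offset[shard_id]:row_offset[shard_id] + w.shape[0]].copy_(w)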
26 changes: 25 additions & 1 deletion fastvideo/v1/configs/models/encoders/llama.py
@@ -1,11 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional, Tuple

from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
                                                       TextEncoderConfig)


+def _is_transformer_layer(n: str, m) -> bool:
+    return "layers" in n and str.isdigit(n.split(".")[-1])
+
+
+def _is_embeddings(n: str, m) -> bool:
+    return n.endswith("embed_tokens")
+
+
+def _is_final_norm(n: str, m) -> bool:
+    return n.endswith("norm")
+
+
@dataclass
class LlamaArchConfig(TextEncoderArchConfig):
    vocab_size: int = 32000
@@ -32,6 +44,18 @@ class LlamaArchConfig(TextEncoderArchConfig):
    head_dim: Optional[int] = None
    hidden_state_skip_layer: int = 2
    text_len: int = 256
+    stacked_params_mapping: List[Tuple[str, str, str]] = field(
+        default_factory=lambda: [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),  # type: ignore
+            (".gate_up_proj", ".up_proj", 1),  # type: ignore
+        ])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda:
+        [_is_transformer_layer, _is_embeddings, _is_final_norm])


@dataclass
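
Unlike the string ids used for q/k/v, the gate/up entries carry integer shard ids (hence the type: ignore against the str-typed tuple). Presumably the integer indexes one of the equal-sized blocks of the merged gate_up_proj weight; a sketch under that assumption, with Llama-style sizes for illustration:

import torch

hidden, intermediate = 4096, 11008  # illustrative Llama dimensions
gate_up_w = torch.empty(2 * intermediate, hidden)  # merged gate|up projection

def fill_merged(fused: torch.Tensor, shard_id: int, w: torch.Tensor) -> None:
    # shard 0 (gate_proj) fills rows [0, intermediate);
    # shard 1 (up_proj)   fills rows [intermediate, 2 * intermediate).
    block = fused.shape[0] // 2
    fused[shard_id * block:(shard_id + 1) * block].copy_(w)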
24 changes: 23 additions & 1 deletion fastvideo/v1/configs/models/encoders/t5.py
@@ -1,11 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional, Tuple

from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
                                                       TextEncoderConfig)


+def _is_transformer_layer(n: str, m) -> bool:
+    return "block" in n and str.isdigit(n.split(".")[-1])
+
+
+def _is_embeddings(n: str, m) -> bool:
+    return n.endswith("shared")
+
+
+def _is_final_layernorm(n: str, m) -> bool:
+    return n.endswith("final_layer_norm")
+
+
@dataclass
class T5ArchConfig(TextEncoderArchConfig):
    vocab_size: int = 32128
@@ -29,6 +41,16 @@ class T5ArchConfig(TextEncoderArchConfig):
    eos_token_id: int = 1
    classifier_dropout: float = 0.0
    text_len: int = 512
+    stacked_params_mapping: List[Tuple[str, str,
+                                       str]] = field(default_factory=lambda: [
+                                           # (param_name, shard_name, shard_id)
+                                           (".qkv_proj", ".q", "q"),
+                                           (".qkv_proj", ".k", "k"),
+                                           (".qkv_proj", ".v", "v"),
+                                       ])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda:
+        [_is_transformer_layer, _is_embeddings, _is_final_layernorm])

    # Referenced from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py
    def __post_init__(self):
@@ -11,7 +11,7 @@
    build_parquet_iterable_style_dataloader)
from fastvideo.v1.distributed import get_world_rank
from fastvideo.v1.distributed.parallel_state import (
-    cleanup_dist_env_and_memory, get_torch_device,
+    cleanup_dist_env_and_memory, get_local_torch_device,
    maybe_init_distributed_environment_and_model_parallel)
from fastvideo.v1.logger import init_logger

@@ -148,8 +148,8 @@ def main() -> None:
            break

        # Move data to device
-        latents = latents.to(get_torch_device())
-        embeddings = embeddings.to(get_torch_device())
+        latents = latents.to(get_local_torch_device())
+        embeddings = embeddings.to(get_local_torch_device())

        # Calculate actual batch size
        batch_size = latents.size(0)
@@ -13,7 +13,7 @@
    build_parquet_map_style_dataloader)
from fastvideo.v1.distributed import get_world_rank
from fastvideo.v1.distributed.parallel_state import (
-    cleanup_dist_env_and_memory, get_torch_device,
+    cleanup_dist_env_and_memory, get_local_torch_device,
    maybe_init_distributed_environment_and_model_parallel)
from fastvideo.v1.logger import init_logger

@@ -165,8 +165,8 @@ def main() -> None:
            break

        # Move data to device
-        latents = latents.to(get_torch_device())
-        embeddings = embeddings.to(get_torch_device())
+        latents = latents.to(get_local_torch_device())
+        embeddings = embeddings.to(get_local_torch_device())

        # Calculate actual batch size
        batch_size = latents.size(0)
10 changes: 5 additions & 5 deletions fastvideo/v1/distributed/__init__.py
@@ -3,10 +3,10 @@
from fastvideo.v1.distributed.communication_op import *
from fastvideo.v1.distributed.parallel_state import (
    cleanup_dist_env_and_memory, get_dp_group, get_dp_rank, get_dp_world_size,
-    get_sp_group, get_sp_parallel_rank, get_sp_world_size, get_torch_device,
-    get_tp_group, get_tp_rank, get_tp_world_size, get_world_group,
-    get_world_rank, get_world_size, init_distributed_environment,
-    initialize_model_parallel,
+    get_local_torch_device, get_sp_group, get_sp_parallel_rank,
+    get_sp_world_size, get_tp_group, get_tp_rank, get_tp_world_size,
+    get_world_group, get_world_rank, get_world_size,
+    init_distributed_environment, initialize_model_parallel,
    maybe_init_distributed_environment_and_model_parallel,
    model_parallel_is_initialized)
from fastvideo.v1.distributed.utils import *
@@ -40,5 +40,5 @@
"get_tp_world_size",

# Get torch device
"get_torch_device",
"get_local_torch_device",
]
4 changes: 2 additions & 2 deletions fastvideo/v1/distributed/parallel_state.py
@@ -904,7 +904,7 @@ def get_dp_rank() -> int:
    return get_dp_group().rank_in_group


-def get_torch_device() -> torch.device:
+def get_local_torch_device() -> torch.device:
    """Return the torch device for the current rank."""
    return torch.device(f"cuda:{envs.LOCAL_RANK}")

@@ -1232,4 +1232,4 @@ def initialize_sequence_parallel_group(
        backend,
        group_name=group_name)

-    return sp_group
\ No newline at end of file
+    return sp_group
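
The rename above makes the semantics explicit: the device index comes from LOCAL_RANK, the per-node rank, not the global rank. In a multi-node job global ranks exceed the per-node GPU count, so binding by local rank is the only correct choice. The equivalent selection using standard launcher environment variables:

import os
import torch

# E.g. on node 1 of a 2x8-GPU job, global ranks 8..15 still map to cuda:0..7.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
device = torch.device(f"cuda:{local_rank}")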
20 changes: 17 additions & 3 deletions fastvideo/v1/fastvideo_args.py
@@ -58,8 +58,10 @@ class FastVideoArgs:

    output_type: str = "pil"

-    use_cpu_offload: bool = True
+    use_cpu_offload: bool = True  # For DiT
    use_fsdp_inference: bool = True
+    text_encoder_offload: bool = True
+    pin_cpu_memory: bool = True

    # STA (Sliding Tile Attention) parameters
    mask_strategy_file_path: Optional[str] = None
@@ -208,15 +210,27 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--use-cpu-offload",
action=StoreBoolean,
help=
"Use CPU offload for model inference. Enable if run out of memory with FSDP.",
"Use CPU offload for DiT inference. Enable if run out of memory with FSDP.",
)
parser.add_argument(
"--use-fsdp-inference",
action=StoreBoolean,
help=
"Use FSDP for inference by sharding the model weights. Latency is very low due to prefetch--enable if run out of memory.",
)

parser.add_argument(
"--text-encoder-cpu-offload",
action=StoreBoolean,
help=
"Use CPU offload for text encoder. Enable if run out of memory.",
)
parser.add_argument(
"--pin-cpu-memory",
action=StoreBoolean,
help=
"Pin memory for CPU offload. Only added as a temp workaround if it throws \"CUDA error: invalid argument\". "
"Should be enabled in almost all cases",
)
parser.add_argument(
"--disable-autocast",
action=StoreBoolean,
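
A sketch of how the new knobs compose, expressed through the dataclass fields from the first hunk rather than the CLI (this assumes the remaining FastVideoArgs fields have defaults; the exact StoreBoolean flag syntax is not shown in this diff):

from fastvideo.v1.fastvideo_args import FastVideoArgs

args = FastVideoArgs(
    use_cpu_offload=False,      # keep the DiT on GPU
    use_fsdp_inference=True,    # shard weights; prefetch keeps latency low
    text_encoder_offload=True,  # push the text encoder to CPU between uses
    pin_cpu_memory=True,        # pinned staging buffers for offload copies
)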
5 changes: 1 addition & 4 deletions fastvideo/v1/models/dits/stepvideo.py
@@ -455,10 +455,7 @@ def forward(self,

class StepVideoModel(BaseDiT):
    # (Optional) Keep the same attribute for compatibility with splitting, etc.
-    _fsdp_shard_conditions = [
-        lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit(),
-        # lambda n, m: "pos_embed" in n  # If needed for the patch embedding.
-    ]
+    _fsdp_shard_conditions = StepVideoConfig()._fsdp_shard_conditions
    _param_names_mapping = StepVideoConfig()._param_names_mapping
    _reverse_param_names_mapping = StepVideoConfig(
    )._reverse_param_names_mapping
8 changes: 4 additions & 4 deletions fastvideo/v1/models/dits/wanvideo.py
@@ -318,9 +318,9 @@ def forward(
        value, _ = self.to_v(norm_hidden_states)

        if self.norm_q is not None:
-            query = self.norm_q.forward_native(query)
+            query = self.norm_q(query)
        if self.norm_k is not None:
-            key = self.norm_k.forward_native(key)
+            key = self.norm_k(key)

        query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
        key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
@@ -465,9 +465,9 @@ def forward(
        gate_compress, _ = self.to_gate_compress(norm_hidden_states)

        if self.norm_q is not None:
-            query = self.norm_q.forward_native(query)
+            query = self.norm_q(query)
        if self.norm_k is not None:
-            key = self.norm_k.forward_native(key)
+            key = self.norm_k(key)

        query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
        key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
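
The norm call change above (norm_q.forward_native(query) to norm_q(query)) routes through the module's __call__ instead of pinning the reference path. A sketch of the dispatch pattern this enables, using an illustrative RMSNorm rather than FastVideo's actual layer:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Reference implementation: always available, never fused.
        var = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps) * self.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dispatch point: a subclass or backend can substitute a fused
        # kernel here, which norm(x) picks up automatically and
        # norm.forward_native(x) would bypass.
        return self.forward_native(x)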