6 changes: 3 additions & 3 deletions .github/workflows/pr-test.yml
@@ -122,17 +122,17 @@ jobs:
       # Actual tests
       encoder-test:
         - 'fastvideo/v1/models/encoders/**'
-        - 'fastvideo/v1/models/loaders/**'
+        - 'fastvideo/v1/models/loader/**'
         - 'fastvideo/v1/tests/encoders/**'
         - *common-paths
       vae-test:
         - 'fastvideo/v1/models/vaes/**'
-        - 'fastvideo/v1/models/loaders/**'
+        - 'fastvideo/v1/models/loader/**'
         - 'fastvideo/v1/tests/vaes/**'
         - *common-paths
       transformer-test:
         - 'fastvideo/v1/models/dits/**'
-        - 'fastvideo/v1/models/loaders/**'
+        - 'fastvideo/v1/models/loader/**'
         - 'fastvideo/v1/tests/transformers/**'
         - 'fastvideo/v1/layers/**'
         - 'fastvideo/v1/attention/**'
2 changes: 1 addition & 1 deletion examples/inference/basic/basic.py
@@ -10,7 +10,7 @@ def main():
     # attempt to identify the optimal arguments.
     generator = VideoGenerator.from_pretrained(
         "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
-        # if num_gpus > 1, FastVideo will automatically handle distributed setup
+        # FastVideo will automatically handle distributed setup
         num_gpus=2,
         use_fsdp_inference=True,
         use_cpu_offload=False
6 changes: 4 additions & 2 deletions fastvideo/v1/configs/models/base.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field, fields
-from typing import Any, Dict
+from typing import Any, Dict, List, Tuple
 
 from fastvideo.v1.logger import init_logger
 
@@ -12,7 +12,9 @@
 # 3. Any field in ArchConfig is fixed upon initialization, and should be hidden away from users
 @dataclass
 class ArchConfig:
-    pass
+    stacked_params_mapping: List[Tuple[str, str, str]] = field(
+        default_factory=list
+    )  # mapping from huggingface weight names to custom names
 
 
 @dataclass
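The stacked_params_mapping triples introduced above are (param_name, shard_name, shard_id): a checkpoint tensor whose name contains shard_name is redirected into the shard_id slot of the fused param_name parameter. Below is a minimal sketch of a loader loop that consumes such a mapping, in the style of vLLM-like loaders; load_weights and the per-parameter weight_loader hook are illustrative assumptions, not FastVideo's actual API.

# Sketch only: assumes fused parameters expose a `weight_loader` hook,
# as in vLLM-style loaders. Not FastVideo's actual loader code.
from typing import Iterable, List, Tuple

import torch

def load_weights(model: torch.nn.Module,
                 hf_weights: Iterable[Tuple[str, torch.Tensor]],
                 stacked_params_mapping: List[Tuple[str, str, str]]) -> None:
    params = dict(model.named_parameters())
    for name, loaded_weight in hf_weights:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            # e.g. "...self_attn.q_proj.weight" -> "...self_attn.qkv_proj.weight"
            param = params[name.replace(shard_name, param_name)]
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # No stacking rule matched: plain one-to-one copy.
            params[name].data.copy_(loaded_weight)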
8 changes: 3 additions & 5 deletions fastvideo/v1/configs/models/dits/stepvideo.py
@@ -5,13 +5,11 @@
 from fastvideo.v1.configs.models.dits.base import DiTArchConfig, DiTConfig
 
 
-def is_blocks(n: str, m) -> bool:
-    return "blocks" in n and str.isdigit(n.split(".")[-1])
-
-
 @dataclass
 class StepVideoArchConfig(DiTArchConfig):
-    _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda:
+        [lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit()])
 
     _param_names_mapping: dict = field(
         default_factory=lambda: {
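_fsdp_shard_conditions is a list of (name, module) predicates; a submodule whose qualified name satisfies any predicate gets wrapped in its own FSDP shard. A small sketch of how such predicates are typically evaluated during wrapping follows; the traversal is illustrative, not FastVideo's exact FSDP code.

from typing import Callable, List

import torch.nn as nn

ShardCondition = Callable[[str, nn.Module], bool]

def modules_to_shard(model: nn.Module,
                     conditions: List[ShardCondition]) -> List[str]:
    """Collect qualified names of submodules matching any shard condition."""
    return [
        name for name, module in model.named_modules()
        if any(cond(name, module) for cond in conditions)
    ]

# For StepVideo, a name like "transformer_blocks.3" matches because
# "transformer_blocks" is in the name and the last dotted component is a digit.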
5 changes: 4 additions & 1 deletion fastvideo/v1/configs/models/encoders/base.py
@@ -32,8 +32,11 @@ class TextEncoderArchConfig(EncoderArchConfig):
     output_past: bool = True
     scalable_attention: bool = True
     tie_word_embeddings: bool = False
-
+    stacked_params_mapping: List[Tuple[str, str, str]] = field(
+        default_factory=list
+    )  # mapping from huggingface weight names to custom names
     tokenizer_kwargs: Dict[str, Any] = field(default_factory=dict)
+    _fsdp_shard_conditions: list = field(default_factory=lambda: [])
 
     def __post_init__(self) -> None:
         self.tokenizer_kwargs = {
26 changes: 25 additions & 1 deletion fastvideo/v1/configs/models/encoders/clip.py
@@ -1,13 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional, Tuple
 
 from fastvideo.v1.configs.models.encoders.base import (ImageEncoderArchConfig,
                                                        ImageEncoderConfig,
                                                        TextEncoderArchConfig,
                                                        TextEncoderConfig)
 
 
+def _is_transformer_layer(n: str, m) -> bool:
+    return "layers" in n and str.isdigit(n.split(".")[-1])
+
+
+def _is_embeddings(n: str, m) -> bool:
+    return n.endswith("embeddings")
+
+
 @dataclass
 class CLIPTextArchConfig(TextEncoderArchConfig):
     vocab_size: int = 49408
@@ -27,6 +35,15 @@ class CLIPTextArchConfig(TextEncoderArchConfig):
     bos_token_id: int = 49406
     eos_token_id: int = 49407
     text_len: int = 77
+    stacked_params_mapping: List[Tuple[str, str,
+                                       str]] = field(default_factory=lambda: [
+                                           # (param_name, shard_name, shard_id)
+                                           ("qkv_proj", "q_proj", "q"),
+                                           ("qkv_proj", "k_proj", "k"),
+                                           ("qkv_proj", "v_proj", "v"),
+                                       ])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda: [_is_transformer_layer, _is_embeddings])
 
 
 @dataclass
@@ -45,6 +62,13 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig):
     attention_dropout: float = 0.0
     initializer_range: float = 0.02
     initializer_factor: float = 1.0
+    stacked_params_mapping: List[Tuple[str, str,
+                                       str]] = field(default_factory=lambda: [
+                                           # (param_name, shard_name, shard_id)
+                                           ("qkv_proj", "q_proj", "q"),
+                                           ("qkv_proj", "k_proj", "k"),
+                                           ("qkv_proj", "v_proj", "v"),
+                                       ])
 
 
 @dataclass
26 changes: 25 additions & 1 deletion fastvideo/v1/configs/models/encoders/llama.py
@@ -1,11 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional, Tuple
 
 from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
                                                        TextEncoderConfig)
 
 
+def _is_transformer_layer(n: str, m) -> bool:
+    return "layers" in n and str.isdigit(n.split(".")[-1])
+
+
+def _is_embeddings(n: str, m) -> bool:
+    return n.endswith("embed_tokens")
+
+
+def _is_final_norm(n: str, m) -> bool:
+    return n.endswith("norm")
+
+
 @dataclass
 class LlamaArchConfig(TextEncoderArchConfig):
     vocab_size: int = 32000
@@ -32,6 +44,18 @@ class LlamaArchConfig(TextEncoderArchConfig):
     head_dim: Optional[int] = None
     hidden_state_skip_layer: int = 2
     text_len: int = 256
+    stacked_params_mapping: List[Tuple[str, str, str]] = field(
+        default_factory=lambda: [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),  # type: ignore
+            (".gate_up_proj", ".up_proj", 1),  # type: ignore
+        ])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda:
+        [_is_transformer_layer, _is_embeddings, _is_final_norm])
 
 
 @dataclass
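Note the mixed shard ids in the Llama mapping above: the strings "q"/"k"/"v" name regions of the fused qkv_proj, while the integers 0 and 1 index the gate and up halves of the fused gate_up_proj (hence the # type: ignore comments on the declared Tuple[str, str, str] element type). A sketch of what the integer ids mean; copy_gate_up_shard is a hypothetical helper, not a FastVideo API.

import torch

def copy_gate_up_shard(fused: torch.Tensor, loaded: torch.Tensor,
                       shard_id: int) -> None:
    """Copy gate_proj (shard_id 0) or up_proj (shard_id 1) rows into a
    fused [2 * intermediate_size, hidden_size] gate_up_proj weight."""
    intermediate_size = fused.shape[0] // 2
    start = shard_id * intermediate_size
    fused[start:start + intermediate_size].copy_(loaded)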
24 changes: 23 additions & 1 deletion fastvideo/v1/configs/models/encoders/t5.py
@@ -1,11 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional, Tuple
 
 from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
                                                        TextEncoderConfig)
 
 
+def _is_transformer_layer(n: str, m) -> bool:
+    return "block" in n and str.isdigit(n.split(".")[-1])
+
+
+def _is_embeddings(n: str, m) -> bool:
+    return n.endswith("shared")
+
+
+def _is_final_layernorm(n: str, m) -> bool:
+    return n.endswith("final_layer_norm")
+
+
 @dataclass
 class T5ArchConfig(TextEncoderArchConfig):
     vocab_size: int = 32128
@@ -29,6 +41,16 @@ class T5ArchConfig(TextEncoderArchConfig):
     eos_token_id: int = 1
     classifier_dropout: float = 0.0
     text_len: int = 512
+    stacked_params_mapping: List[Tuple[str, str,
+                                       str]] = field(default_factory=lambda: [
+                                           # (param_name, shard_name, shard_id)
+                                           (".qkv_proj", ".q", "q"),
+                                           (".qkv_proj", ".k", "k"),
+                                           (".qkv_proj", ".v", "v"),
+                                       ])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda:
+        [_is_transformer_layer, _is_embeddings, _is_final_layernorm])
 
     # Referenced from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py
     def __post_init__(self):
(file name not shown)
@@ -11,7 +11,7 @@
     build_parquet_iterable_style_dataloader)
 from fastvideo.v1.distributed import get_world_rank
 from fastvideo.v1.distributed.parallel_state import (
-    cleanup_dist_env_and_memory, get_torch_device,
+    cleanup_dist_env_and_memory, get_local_torch_device,
    maybe_init_distributed_environment_and_model_parallel)
 from fastvideo.v1.logger import init_logger
 
@@ -148,8 +148,8 @@ def main() -> None:
             break
 
         # Move data to device
-        latents = latents.to(get_torch_device())
-        embeddings = embeddings.to(get_torch_device())
+        latents = latents.to(get_local_torch_device())
+        embeddings = embeddings.to(get_local_torch_device())
 
         # Calculate actual batch size
         batch_size = latents.size(0)
(file name not shown)
@@ -13,7 +13,7 @@
     build_parquet_map_style_dataloader)
 from fastvideo.v1.distributed import get_world_rank
 from fastvideo.v1.distributed.parallel_state import (
-    cleanup_dist_env_and_memory, get_torch_device,
+    cleanup_dist_env_and_memory, get_local_torch_device,
    maybe_init_distributed_environment_and_model_parallel)
 from fastvideo.v1.logger import init_logger
 
@@ -165,8 +165,8 @@ def main() -> None:
             break
 
         # Move data to device
-        latents = latents.to(get_torch_device())
-        embeddings = embeddings.to(get_torch_device())
+        latents = latents.to(get_local_torch_device())
+        embeddings = embeddings.to(get_local_torch_device())
 
         # Calculate actual batch size
         batch_size = latents.size(0)
10 changes: 5 additions & 5 deletions fastvideo/v1/distributed/__init__.py
@@ -3,10 +3,10 @@
 from fastvideo.v1.distributed.communication_op import *
 from fastvideo.v1.distributed.parallel_state import (
     cleanup_dist_env_and_memory, get_dp_group, get_dp_rank, get_dp_world_size,
-    get_sp_group, get_sp_parallel_rank, get_sp_world_size, get_torch_device,
-    get_tp_group, get_tp_rank, get_tp_world_size, get_world_group,
-    get_world_rank, get_world_size, init_distributed_environment,
-    initialize_model_parallel,
+    get_local_torch_device, get_sp_group, get_sp_parallel_rank,
+    get_sp_world_size, get_tp_group, get_tp_rank, get_tp_world_size,
+    get_world_group, get_world_rank, get_world_size,
+    init_distributed_environment, initialize_model_parallel,
     maybe_init_distributed_environment_and_model_parallel,
     model_parallel_is_initialized)
 from fastvideo.v1.distributed.utils import *
@@ -40,5 +40,5 @@
     "get_tp_world_size",
 
     # Get torch device
-    "get_torch_device",
+    "get_local_torch_device",
 ]
4 changes: 2 additions & 2 deletions fastvideo/v1/distributed/parallel_state.py
@@ -904,7 +904,7 @@ def get_dp_rank() -> int:
     return get_dp_group().rank_in_group
 
 
-def get_torch_device() -> torch.device:
+def get_local_torch_device() -> torch.device:
     """Return the torch device for the current rank."""
     return torch.device(f"cuda:{envs.LOCAL_RANK}")
 
@@ -1232,4 +1232,4 @@ def initialize_sequence_parallel_group(
                                            backend,
                                            group_name=group_name)
 
-    return sp_group
\ No newline at end of file
+    return sp_group
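The rename from get_torch_device to get_local_torch_device makes explicit that the device is derived from the process's local rank (LOCAL_RANK), not its global rank. A minimal usage sketch, assuming the distributed environment is already initialized (e.g. under torchrun):

import torch

from fastvideo.v1.distributed import get_local_torch_device

def to_local_device(batch: torch.Tensor) -> torch.Tensor:
    # Resolves to cuda:{LOCAL_RANK}, so each process on a node uses its own GPU.
    return batch.to(get_local_torch_device())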
20 changes: 17 additions & 3 deletions fastvideo/v1/fastvideo_args.py
@@ -58,8 +58,10 @@ class FastVideoArgs:
 
     output_type: str = "pil"
 
-    use_cpu_offload: bool = True
+    use_cpu_offload: bool = True  # For DiT
     use_fsdp_inference: bool = True
+    text_encoder_offload: bool = True
+    pin_cpu_memory: bool = True
 
     # STA (Sliding Tile Attention) parameters
     mask_strategy_file_path: Optional[str] = None
@@ -208,15 +210,27 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "--use-cpu-offload",
         action=StoreBoolean,
         help=
-        "Use CPU offload for model inference. Enable if run out of memory with FSDP.",
+        "Use CPU offload for DiT inference. Enable if run out of memory with FSDP.",
     )
     parser.add_argument(
         "--use-fsdp-inference",
         action=StoreBoolean,
         help=
         "Use FSDP for inference by sharding the model weights. Latency is very low due to prefetch--enable if run out of memory.",
     )
-
+    parser.add_argument(
+        "--text-encoder-cpu-offload",
+        action=StoreBoolean,
+        help=
+        "Use CPU offload for text encoder. Enable if run out of memory.",
+    )
+    parser.add_argument(
+        "--pin-cpu-memory",
+        action=StoreBoolean,
+        help=
+        "Pin memory for CPU offload. Only added as a temp workaround if it throws \"CUDA error: invalid argument\". "
+        "Should be enabled in almost all cases",
+    )
    parser.add_argument(
         "--disable-autocast",
         action=StoreBoolean,
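The new knobs split offloading policy by component: the DiT (use_cpu_offload), the text encoder (the text-encoder offload flag), and pinned host memory for faster host-to-device copies (pin_cpu_memory). A sketch of combining them programmatically; it assumes the remaining FastVideoArgs fields keep their defaults (any required fields, such as a model path, would need to be supplied too), and the values are illustrative:

from fastvideo.v1.fastvideo_args import FastVideoArgs

args = FastVideoArgs(
    use_cpu_offload=False,      # keep the DiT resident on GPU
    use_fsdp_inference=True,    # shard DiT weights across ranks instead
    text_encoder_offload=True,  # park the text encoder on CPU between uses
    pin_cpu_memory=True,        # pinned host memory speeds up H2D copies
)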
5 changes: 1 addition & 4 deletions fastvideo/v1/models/dits/stepvideo.py
@@ -455,10 +455,7 @@ def forward(self,
 
 class StepVideoModel(BaseDiT):
     # (Optional) Keep the same attribute for compatibility with splitting, etc.
-    _fsdp_shard_conditions = [
-        lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit(),
-        # lambda n, m: "pos_embed" in n  # If needed for the patch embedding.
-    ]
+    _fsdp_shard_conditions = StepVideoConfig()._fsdp_shard_conditions
     _param_names_mapping = StepVideoConfig()._param_names_mapping
     _reverse_param_names_mapping = StepVideoConfig(
     )._reverse_param_names_mapping
8 changes: 4 additions & 4 deletions fastvideo/v1/models/dits/wanvideo.py
@@ -318,9 +318,9 @@ def forward(
         value, _ = self.to_v(norm_hidden_states)
 
         if self.norm_q is not None:
-            query = self.norm_q.forward_native(query)
+            query = self.norm_q(query)
         if self.norm_k is not None:
-            key = self.norm_k.forward_native(key)
+            key = self.norm_k(key)
 
         query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
         key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
@@ -465,9 +465,9 @@ def forward(
         gate_compress, _ = self.to_gate_compress(norm_hidden_states)
 
         if self.norm_q is not None:
-            query = self.norm_q.forward_native(query)
+            query = self.norm_q(query)
         if self.norm_k is not None:
-            key = self.norm_k.forward_native(key)
+            key = self.norm_k(key)
 
         query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
         key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
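Calling self.norm_q(query) routes through the layer's __call__ rather than pinning the pure-PyTorch path. In custom-op style layers (a pattern common in vLLM-derived codebases), forward_native is the reference implementation while forward can dispatch to a fused kernel; the sketch below illustrates that pattern as an assumption about the design intent, not FastVideo's actual RMSNorm implementation.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference path.
        input_dtype = x.dtype
        variance = x.float().pow(2).mean(-1, keepdim=True)
        x = x.float() * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dispatch point: a fused CUDA kernel could be invoked here instead,
        # so callers should go through __call__ rather than forward_native.
        return self.forward_native(x)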