
Commit 5f938b5

[Revert] "[Feature] Load weights from distributed" (#571)
1 parent: 74da2a7 · commit: 5f938b5

38 files changed: +220 −410 lines changed

.github/workflows/pr-test.yml

Lines changed: 3 additions & 3 deletions
@@ -122,17 +122,17 @@ jobs:
       # Actual tests
       encoder-test:
         - 'fastvideo/v1/models/encoders/**'
-        - 'fastvideo/v1/models/loader/**'
+        - 'fastvideo/v1/models/loaders/**'
         - 'fastvideo/v1/tests/encoders/**'
         - *common-paths
       vae-test:
         - 'fastvideo/v1/models/vaes/**'
-        - 'fastvideo/v1/models/loader/**'
+        - 'fastvideo/v1/models/loaders/**'
         - 'fastvideo/v1/tests/vaes/**'
         - *common-paths
       transformer-test:
         - 'fastvideo/v1/models/dits/**'
-        - 'fastvideo/v1/models/loader/**'
+        - 'fastvideo/v1/models/loaders/**'
         - 'fastvideo/v1/tests/transformers/**'
         - 'fastvideo/v1/layers/**'
         - 'fastvideo/v1/attention/**'

examples/inference/basic/basic.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ def main():
     # attempt to identify the optimal arguments.
     generator = VideoGenerator.from_pretrained(
         "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
-        # FastVideo will automatically handle distributed setup
+        # if num_gpus > 1, FastVideo will automatically handle distributed setup
         num_gpus=2,
         use_fsdp_inference=True,
         use_cpu_offload=False

fastvideo/v1/configs/models/base.py

Lines changed: 2 additions & 4 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field, fields
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict

 from fastvideo.v1.logger import init_logger

@@ -12,9 +12,7 @@
 # 3. Any field in ArchConfig is fixed upon initialization, and should be hidden away from users
 @dataclass
 class ArchConfig:
-    stacked_params_mapping: List[Tuple[str, str, str]] = field(
-        default_factory=list
-    )  # mapping from huggingface weight names to custom names
+    pass


 @dataclass
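
The stacked_params_mapping field removed above held (param_name, shard_name, shard_id) tuples that map huggingface weight names onto fused parameter names. As a rough sketch of how a loader typically consumes such a mapping (remap_weight_name is a hypothetical helper for illustration, not FastVideo's actual loader API):

from typing import List, Optional, Tuple

# Example mapping in the removed format: route separate q/k/v projection
# weights into a single fused qkv_proj parameter.
STACKED_PARAMS_MAPPING: List[Tuple[str, str, str]] = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def remap_weight_name(ckpt_name: str) -> Tuple[str, Optional[str]]:
    # Rewrite a checkpoint weight name to its fused parameter name and
    # report which shard of the fused tensor the weight should fill.
    for param_name, shard_name, shard_id in STACKED_PARAMS_MAPPING:
        if shard_name in ckpt_name:
            return ckpt_name.replace(shard_name, param_name), shard_id
    return ckpt_name, None  # not a stacked parameter; load as-is

Under this sketch, remap_weight_name("layers.0.self_attn.q_proj.weight") returns ("layers.0.self_attn.qkv_proj.weight", "q").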

fastvideo/v1/configs/models/dits/stepvideo.py

Lines changed: 5 additions & 3 deletions
@@ -5,11 +5,13 @@
 from fastvideo.v1.configs.models.dits.base import DiTArchConfig, DiTConfig


+def is_blocks(n: str, m) -> bool:
+    return "blocks" in n and str.isdigit(n.split(".")[-1])
+
+
 @dataclass
 class StepVideoArchConfig(DiTArchConfig):
-    _fsdp_shard_conditions: list = field(
-        default_factory=lambda:
-        [lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit()])
+    _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])

     _param_names_mapping: dict = field(
         default_factory=lambda: {
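
The _fsdp_shard_conditions entries are predicates of the form (name, module) -> bool; the named is_blocks function that replaces the inline lambda above is one such predicate. A minimal sketch of how these conditions might be applied, assuming the loader walks named_modules() and shards each match (select_fsdp_units is illustrative, not FastVideo's actual code):

import torch.nn as nn

def is_blocks(n: str, m) -> bool:
    # Matches names like "transformer_blocks.0" or "blocks.12".
    return "blocks" in n and str.isdigit(n.split(".")[-1])

def select_fsdp_units(model: nn.Module, conditions) -> list:
    # Collect the submodule names that satisfy any shard condition;
    # a real loader would wrap each match in an FSDP unit instead of
    # just listing it.
    return [
        name for name, module in model.named_modules()
        if any(cond(name, module) for cond in conditions)
    ]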

fastvideo/v1/configs/models/encoders/base.py

Lines changed: 1 addition & 4 deletions
@@ -32,11 +32,8 @@ class TextEncoderArchConfig(EncoderArchConfig):
     output_past: bool = True
     scalable_attention: bool = True
     tie_word_embeddings: bool = False
-    stacked_params_mapping: List[Tuple[str, str, str]] = field(
-        default_factory=list
-    )  # mapping from huggingface weight names to custom names
+
     tokenizer_kwargs: Dict[str, Any] = field(default_factory=dict)
-    _fsdp_shard_conditions: list = field(default_factory=lambda: [])

     def __post_init__(self) -> None:
         self.tokenizer_kwargs = {

fastvideo/v1/configs/models/encoders/clip.py

Lines changed: 1 addition & 25 deletions
@@ -1,21 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple
+from typing import Optional

 from fastvideo.v1.configs.models.encoders.base import (ImageEncoderArchConfig,
                                                        ImageEncoderConfig,
                                                        TextEncoderArchConfig,
                                                        TextEncoderConfig)


-def _is_transformer_layer(n: str, m) -> bool:
-    return "layers" in n and str.isdigit(n.split(".")[-1])
-
-
-def _is_embeddings(n: str, m) -> bool:
-    return n.endswith("embeddings")
-
-
 @dataclass
 class CLIPTextArchConfig(TextEncoderArchConfig):
     vocab_size: int = 49408
@@ -35,15 +27,6 @@ class CLIPTextArchConfig(TextEncoderArchConfig):
     bos_token_id: int = 49406
     eos_token_id: int = 49407
     text_len: int = 77
-    stacked_params_mapping: List[Tuple[str, str,
-                                       str]] = field(default_factory=lambda: [
-                                           # (param_name, shard_name, shard_id)
-                                           ("qkv_proj", "q_proj", "q"),
-                                           ("qkv_proj", "k_proj", "k"),
-                                           ("qkv_proj", "v_proj", "v"),
-                                       ])
-    _fsdp_shard_conditions: list = field(
-        default_factory=lambda: [_is_transformer_layer, _is_embeddings])


 @dataclass
@@ -62,13 +45,6 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig):
     attention_dropout: float = 0.0
     initializer_range: float = 0.02
     initializer_factor: float = 1.0
-    stacked_params_mapping: List[Tuple[str, str,
-                                       str]] = field(default_factory=lambda: [
-                                           # (param_name, shard_name, shard_id)
-                                           ("qkv_proj", "q_proj", "q"),
-                                           ("qkv_proj", "k_proj", "k"),
-                                           ("qkv_proj", "v_proj", "v"),
-                                       ])


 @dataclass

fastvideo/v1/configs/models/encoders/llama.py

Lines changed: 1 addition & 25 deletions
@@ -1,23 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple
+from typing import Optional

 from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
                                                        TextEncoderConfig)


-def _is_transformer_layer(n: str, m) -> bool:
-    return "layers" in n and str.isdigit(n.split(".")[-1])
-
-
-def _is_embeddings(n: str, m) -> bool:
-    return n.endswith("embed_tokens")
-
-
-def _is_final_norm(n: str, m) -> bool:
-    return n.endswith("norm")
-
-
 @dataclass
 class LlamaArchConfig(TextEncoderArchConfig):
     vocab_size: int = 32000
@@ -44,18 +32,6 @@ class LlamaArchConfig(TextEncoderArchConfig):
     head_dim: Optional[int] = None
     hidden_state_skip_layer: int = 2
     text_len: int = 256
-    stacked_params_mapping: List[Tuple[str, str, str]] = field(
-        default_factory=lambda: [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),  # type: ignore
-            (".gate_up_proj", ".up_proj", 1),  # type: ignore
-        ])
-    _fsdp_shard_conditions: list = field(
-        default_factory=lambda:
-        [_is_transformer_layer, _is_embeddings, _is_final_norm])


 @dataclass
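
Alongside the string-tagged q/k/v entries, the removed Llama mapping used integer shard ids for the merged MLP projection: 0 and 1 select which contiguous slice of the fused gate_up_proj weight a checkpoint tensor fills. A small, hypothetical illustration with plain tensors (the sizes are Llama-7B-like and purely for demonstration):

import torch

hidden, intermediate = 4096, 11008
# Fused parameter holding gate_proj and up_proj stacked along dim 0.
gate_up_proj = torch.empty(2 * intermediate, hidden)

def load_merged_shard(fused: torch.Tensor, ckpt_weight: torch.Tensor,
                      shard_id: int) -> None:
    # Copy a checkpoint tensor into its half of the fused weight.
    size = fused.shape[0] // 2
    fused[shard_id * size:(shard_id + 1) * size].copy_(ckpt_weight)

load_merged_shard(gate_up_proj, torch.randn(intermediate, hidden), 0)  # gate_proj
load_merged_shard(gate_up_proj, torch.randn(intermediate, hidden), 1)  # up_proj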

fastvideo/v1/configs/models/encoders/t5.py

Lines changed: 1 addition & 23 deletions
@@ -1,23 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple
+from typing import Optional

 from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
                                                        TextEncoderConfig)


-def _is_transformer_layer(n: str, m) -> bool:
-    return "block" in n and str.isdigit(n.split(".")[-1])
-
-
-def _is_embeddings(n: str, m) -> bool:
-    return n.endswith("shared")
-
-
-def _is_final_layernorm(n: str, m) -> bool:
-    return n.endswith("final_layer_norm")
-
-
 @dataclass
 class T5ArchConfig(TextEncoderArchConfig):
     vocab_size: int = 32128
@@ -41,16 +29,6 @@ class T5ArchConfig(TextEncoderArchConfig):
     eos_token_id: int = 1
     classifier_dropout: float = 0.0
     text_len: int = 512
-    stacked_params_mapping: List[Tuple[str, str,
-                                       str]] = field(default_factory=lambda: [
-                                           # (param_name, shard_name, shard_id)
-                                           (".qkv_proj", ".q", "q"),
-                                           (".qkv_proj", ".k", "k"),
-                                           (".qkv_proj", ".v", "v"),
-                                       ])
-    _fsdp_shard_conditions: list = field(
-        default_factory=lambda:
-        [_is_transformer_layer, _is_embeddings, _is_final_layernorm])

     # Referenced from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py
     def __post_init__(self):

fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_iterable_style.py

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@
     build_parquet_iterable_style_dataloader)
 from fastvideo.v1.distributed import get_world_rank
 from fastvideo.v1.distributed.parallel_state import (
-    cleanup_dist_env_and_memory, get_local_torch_device,
+    cleanup_dist_env_and_memory, get_torch_device,
     maybe_init_distributed_environment_and_model_parallel)
 from fastvideo.v1.logger import init_logger

@@ -148,8 +148,8 @@ def main() -> None:
             break

         # Move data to device
-        latents = latents.to(get_local_torch_device())
-        embeddings = embeddings.to(get_local_torch_device())
+        latents = latents.to(get_torch_device())
+        embeddings = embeddings.to(get_torch_device())

         # Calculate actual batch size
         batch_size = latents.size(0)

fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_map_style.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
     build_parquet_map_style_dataloader)
 from fastvideo.v1.distributed import get_world_rank
 from fastvideo.v1.distributed.parallel_state import (
-    cleanup_dist_env_and_memory, get_local_torch_device,
+    cleanup_dist_env_and_memory, get_torch_device,
     maybe_init_distributed_environment_and_model_parallel)
 from fastvideo.v1.logger import init_logger

@@ -165,8 +165,8 @@ def main() -> None:
             break

         # Move data to device
-        latents = latents.to(get_local_torch_device())
-        embeddings = embeddings.to(get_local_torch_device())
+        latents = latents.to(get_torch_device())
+        embeddings = embeddings.to(get_torch_device())

         # Calculate actual batch size
         batch_size = latents.size(0)
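
In both benchmark scripts the only change is the device helper rename (get_local_torch_device to get_torch_device); the measurement loop around it is untouched. For context, a condensed sketch of that loop (benchmark_loader is illustrative; per the imports above, the real scripts also initialize and clean up the distributed environment):

import time
import torch

def benchmark_loader(dataloader, device: torch.device,
                     max_batches: int = 100) -> float:
    # Mirror the benchmark pattern: move each batch to the device
    # and report throughput in samples per second.
    start, samples = time.perf_counter(), 0
    for i, (latents, embeddings) in enumerate(dataloader):
        if i >= max_batches:
            break
        latents = latents.to(device)
        embeddings = embeddings.to(device)
        samples += latents.size(0)
    return samples / (time.perf_counter() - start)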
