
Commit 99cbe15

remove dist load related
1 parent 2fa4ecf commit 99cbe15

8 files changed: 50 additions & 171 deletions


fastvideo/v1/distributed/parallel_state.py

Lines changed: 7 additions & 33 deletions
@@ -36,7 +36,6 @@
 
 import torch
 import torch.distributed
-import torch.distributed as dist
 from torch.distributed import Backend, ProcessGroup, ReduceOp
 
 import fastvideo.v1.envs as envs
@@ -693,19 +692,13 @@ def destroy(self) -> None:
 
 
 _WORLD: Optional[GroupCoordinator] = None
-_NODE: Optional[GroupCoordinator] = None
 
 
 def get_world_group() -> GroupCoordinator:
     assert _WORLD is not None, ("world group is not initialized")
     return _WORLD
 
 
-def get_node_group() -> GroupCoordinator:
-    assert _NODE is not None, ("node group is not initialized")
-    return _NODE
-
-
 def init_world_group(ranks: List[int], local_rank: int,
                      backend: str) -> GroupCoordinator:
     return GroupCoordinator(
@@ -717,18 +710,6 @@ def init_world_group(ranks: List[int], local_rank: int,
     )
 
 
-def init_node_group(local_rank: int, backend: str):
-    cpu_group = get_world_group().cpu_group
-    node_ranks = same_node_ranks(cpu_group)
-    node_size = len(node_ranks)
-    all_node_ranks = [
-        list(range(i * node_size, (i + 1) * node_size))
-        for i in range(dist.get_world_size() // node_size)
-    ]
-    global _NODE
-    _NODE = init_model_parallel_group(all_node_ranks, local_rank, backend)
-
-
 def init_model_parallel_group(
     group_ranks: List[List[int]],
     local_rank: int,
@@ -801,8 +782,6 @@ def init_distributed_environment(
     else:
         assert _WORLD.world_size == torch.distributed.get_world_size(), (
             "world group already initialized with a different world size")
-    # Init a group for each node
-    init_node_group(local_rank, backend)
 
 
 _SP: Optional[GroupCoordinator] = None
@@ -925,7 +904,7 @@ def get_dp_rank() -> int:
     return get_dp_group().rank_in_group
 
 
-def get_local_torch_device() -> torch.device:
+def get_torch_device() -> torch.device:
     """Return the torch device for the current rank."""
     return torch.device(f"cuda:{envs.LOCAL_RANK}")
 
@@ -1042,22 +1021,17 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
             "torch._C._host_emptyCache() only available in Pytorch >=2.5")
 
 
-def same_node_ranks(pg: Union[ProcessGroup, StatelessProcessGroup],
-                    source_rank: int = 0) -> List[int]:
+def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
+                        source_rank: int = 0) -> List[bool]:
     """
-    This is a collective operation that returns ranks that are in the same node
+    This is a collective operation that returns if each rank is in the same node
     as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).
-    Args:
-        pg: the global process group to test
-        source_rank: the rank to test against
-    Returns:
-        A list of ranks that are in the same node as the source rank.
     """
     if isinstance(pg, ProcessGroup):
         assert torch.distributed.get_backend(
             pg) != torch.distributed.Backend.NCCL, (
-                "same_node_ranks should be tested with a non-NCCL group.")
+                "in_the_same_node_as should be tested with a non-NCCL group.")
     # local rank inside the group
     rank = torch.distributed.get_rank(group=pg)
     world_size = torch.distributed.get_world_size(group=pg)
@@ -1129,7 +1103,7 @@ def same_node_ranks(pg: Union[ProcessGroup, StatelessProcessGroup],
             rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
             aggregated_data += rank_data
 
-    return [i for i, x in enumerate(aggregated_data.tolist()) if x == 1]
+    return [x == 1 for x in aggregated_data.tolist()]
 
 
 def initialize_tensor_parallel_group(
@@ -1258,4 +1232,4 @@ def initialize_sequence_parallel_group(
         backend,
         group_name=group_name)
 
-    return sp_group
+    return sp_group
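Note: besides deleting the node-group machinery, this file changes the helper's contract. same_node_ranks returned the rank indices sharing the source rank's node (List[int]); in_the_same_node_as returns one boolean per rank (List[bool]). A minimal sketch of how a caller could recover the old shape; ranks_in_same_node is a hypothetical adapter, not part of this commit:

    from typing import List

    def ranks_in_same_node(flags: List[bool]) -> List[int]:
        # Hypothetical adapter: rebuild the old same_node_ranks-style
        # result (rank indices) from the new boolean-per-rank list.
        return [rank for rank, same in enumerate(flags) if same]

    # in_the_same_node_as would yield e.g. [True, True, False, False] for a
    # 4-rank group where only ranks 0 and 1 share a node with source_rank 0.
    assert ranks_in_same_node([True, True, False, False]) == [0, 1]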

fastvideo/v1/layers/layernorm.py

Lines changed: 1 addition & 7 deletions
@@ -6,7 +6,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.distributed.tensor import DTensor
 
 from fastvideo.v1.layers.custom_op import CustomOp
 
@@ -77,12 +76,7 @@ def forward_native(
         x = x * torch.rsqrt(variance + self.variance_epsilon)
         x = x.to(orig_dtype)
         if self.has_weight:
-            # TODO(wenxuan): When using CPU offload, FSDP has a bug that doesn't unwrap DTensor in final_layer_norm.
-            # Report this
-            if isinstance(self.weight, DTensor):
-                x = x * self.weight.to_local().to(x.device)
-            else:
-                x = x * self.weight
+            x = x * self.weight
         if residual is None:
             return x
         else:
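Note: with the DTensor branch removed, forward_native assumes self.weight is a plain parameter. A self-contained sketch of the simplified reference path, reconstructed from the context lines above (the real FastVideo RMSNorm also handles a residual input and a weight-less mode):

    import torch
    import torch.nn as nn

    class MiniRMSNorm(nn.Module):
        """Sketch of the post-commit native RMSNorm path."""

        def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
            super().__init__()
            self.variance_epsilon = eps
            self.weight = nn.Parameter(torch.ones(hidden_size))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            orig_dtype = x.dtype
            x = x.float()
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.variance_epsilon)
            x = x.to(orig_dtype)
            # Plain tensor multiply; the isinstance(weight, DTensor)
            # workaround is gone.
            return x * self.weight

    out = MiniRMSNorm(8)(torch.randn(2, 8))  # usage: output shape (2, 8)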

fastvideo/v1/models/dits/wanvideo.py

Lines changed: 4 additions & 4 deletions
@@ -318,9 +318,9 @@ def forward(
         value, _ = self.to_v(norm_hidden_states)
 
         if self.norm_q is not None:
-            query = self.norm_q.forward_native(query)
+            query = self.norm_q(query)
         if self.norm_k is not None:
-            key = self.norm_k.forward_native(key)
+            key = self.norm_k(key)
 
         query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
         key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
@@ -465,9 +465,9 @@ def forward(
         gate_compress, _ = self.to_gate_compress(norm_hidden_states)
 
         if self.norm_q is not None:
-            query = self.norm_q.forward_native(query)
+            query = self.norm_q(query)
         if self.norm_k is not None:
-            key = self.norm_k.forward_native(key)
+            key = self.norm_k(key)
 
         query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
         key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
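Note: here (and in t5.py below) direct calls to norm.forward_native(...) become plain module calls, so the layer can dispatch to an optimized implementation instead of being pinned to the reference path. A sketch of the usual CustomOp-style dispatch, assuming FastVideo's CustomOp follows the common forward_native/forward_cuda convention; the class and dispatch logic here are illustrative, not the library's exact API:

    import torch
    import torch.nn as nn

    class DispatchingNorm(nn.Module):
        # Illustrative CustomOp-style layer: calling the module routes to
        # the best implementation; .forward_native() pins the reference.

        def forward_native(self, x: torch.Tensor) -> torch.Tensor:
            # Pure-PyTorch reference implementation.
            return x / x.norm(dim=-1, keepdim=True).clamp_min(1e-6)

        def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
            # Stand-in for a fused kernel; real code would call one here.
            return self.forward_native(x)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            impl = self.forward_cuda if x.is_cuda else self.forward_native
            return impl(x)

    norm = DispatchingNorm()
    y = norm(torch.randn(2, 4))                     # dispatched, as in this diff
    y_ref = norm.forward_native(torch.randn(2, 4))  # the old pinned call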

fastvideo/v1/models/encoders/t5.py

Lines changed: 3 additions & 3 deletions
@@ -124,7 +124,7 @@ def __init__(self,
         self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
 
     def forward(self, hidden_states) -> torch.Tensor:
-        forwarded_states = self.layer_norm.forward_native(hidden_states)
+        forwarded_states = self.layer_norm(hidden_states)
         forwarded_states = self.DenseReluDense(forwarded_states)
         hidden_states = hidden_states + forwarded_states
         return hidden_states
@@ -362,7 +362,7 @@ def forward(
         attention_mask: torch.Tensor,
         attn_metadata: Optional[AttentionMetadata] = None,
     ) -> torch.Tensor:
-        normed_hidden_states = self.layer_norm.forward_native(hidden_states)
+        normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.SelfAttention(
             hidden_states=normed_hidden_states,
             attention_mask=attention_mask,
@@ -391,7 +391,7 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: Optional[AttentionMetadata] = None,
     ) -> torch.Tensor:
-        normed_hidden_states = self.layer_norm.forward_native(hidden_states)
+        normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.EncDecAttention(
             hidden_states=normed_hidden_states,
             attn_metadata=attn_metadata,

fastvideo/v1/models/loader/component_loader.py

Lines changed: 14 additions & 41 deletions
@@ -10,7 +10,6 @@
 from typing import Any, Generator, Iterable, List, Optional, Tuple, cast
 
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from safetensors.torch import load_file as safetensors_load_file
 from transformers import AutoImageProcessor, AutoTokenizer
@@ -21,9 +20,7 @@
 from fastvideo.v1.fastvideo_args import FastVideoArgs
 from fastvideo.v1.logger import init_logger
 from fastvideo.v1.models.hf_transformer_utils import get_diffusers_config
-from fastvideo.v1.models.loader.fsdp_load import (init_device_mesh,
-                                                  maybe_load_fsdp_model,
-                                                  shard_model)
+from fastvideo.v1.models.loader.fsdp_load import maybe_load_fsdp_model
 from fastvideo.v1.models.loader.utils import set_default_torch_dtype
 from fastvideo.v1.models.loader.weight_utils import (
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
@@ -166,19 +163,16 @@ def _prepare_weights(
         return hf_folder, hf_weights_files, use_safetensors
 
     def _get_weights_iterator(
-            self,
-            source: "Source",
-            to_cpu: bool = True
+            self, source: "Source"
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.fall_back_to_pt,
             source.allow_patterns_overrides)
         if use_safetensors:
-            weights_iterator = safetensors_weights_iterator(
-                hf_weights_files, to_cpu)
+            weights_iterator = safetensors_weights_iterator(hf_weights_files)
         else:
-            weights_iterator = pt_weights_iterator(hf_weights_files, to_cpu)
+            weights_iterator = pt_weights_iterator(hf_weights_files)
 
         if self.counter_before_loading_weights == 0.0:
             self.counter_before_loading_weights = time.perf_counter()
@@ -187,11 +181,10 @@ def _get_weights_iterator(
                    for (name, tensor) in weights_iterator)
 
     def _get_all_weights(
-            self,
-            model_config: Any,
-            model: nn.Module,
-            model_path: str,
-            to_cpu: bool = True
+            self,
+            model_config: Any,
+            model: nn.Module,
+            model_path: str,
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         primary_weights = TextEncoderLoader.Source(
             model_path,
@@ -200,14 +193,14 @@ def _get_all_weights(
             allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
                                              None),
         )
-        yield from self._get_weights_iterator(primary_weights, to_cpu)
+        yield from self._get_weights_iterator(primary_weights)
 
         secondary_weights = cast(
             Iterable[TextEncoderLoader.Source],
             getattr(model, "secondary_weights", ()),
         )
         for source in secondary_weights:
-            yield from self._get_weights_iterator(source, to_cpu)
+            yield from self._get_weights_iterator(source)
 
     def load(self, model_path: str, architecture: str,
              fastvideo_args: FastVideoArgs):
@@ -243,19 +236,13 @@ def load(self, model_path: str, architecture: str,
         target_device = get_local_torch_device()
         # TODO(will): add support for other dtypes
         return self.load_model(model_path, encoder_config, target_device,
-                               fastvideo_args, encoder_precision)
+                               encoder_precision)
 
     def load_model(self,
                    model_path: str,
                    model_config: EncoderConfig,
                    target_device: torch.device,
-                   fastvideo_args: FastVideoArgs,
                    dtype: str = "fp16"):
-        use_cpu_offload = fastvideo_args.text_encoder_offload and len(
-            getattr(model_config, "_fsdp_shard_conditions", [])) > 0
-
-        if fastvideo_args.text_encoder_offload:
-            target_device = torch.device("cpu")
         with set_default_torch_dtype(PRECISION_TO_TYPE[dtype]):
             with target_device:
                 architectures = getattr(model_config, "architectures", [])
@@ -264,26 +251,12 @@ def load_model(self,
 
         weights_to_load = {name for name, _ in model.named_parameters()}
         loaded_weights = model.load_weights(
-            self._get_all_weights(model_config, model, model_path,
-                                  use_cpu_offload))
+            self._get_all_weights(model_config, model, model_path))
         self.counter_after_loading_weights = time.perf_counter()
         logger.info(
             "Loading weights took %.2f seconds",
             self.counter_after_loading_weights -
             self.counter_before_loading_weights)
-
-        if use_cpu_offload:
-            mesh = init_device_mesh(
-                "cuda",
-                mesh_shape=(1, dist.get_world_size()),
-                mesh_dim_names=("offload", "replicate"),
-            )
-            shard_model(model,
-                        cpu_offload=True,
-                        reshard_after_forward=True,
-                        mesh=mesh["offload"],
-                        fsdp_shard_conditions=model._fsdp_shard_conditions,
-                        pin_cpu_memory=fastvideo_args.pin_cpu_memory)
         # We only enable strict check for non-quantized models
         # that have loaded weights tracking currently.
         # if loaded_weights is not None:
@@ -320,7 +293,7 @@ def load(self, model_path: str, architecture: str,
         target_device = get_local_torch_device()
         # TODO(will): add support for other dtypes
         return self.load_model(
-            model_path, encoder_config, target_device, fastvideo_args,
+            model_path, encoder_config, target_device,
             fastvideo_args.pipeline_config.image_encoder_precision)
 
 
@@ -567,4 +540,4 @@ def load_module(module_name: str, component_model_path: str,
                              transformers_or_diffusers)
 
     # Load the module
-    return loader.load(component_model_path, architecture, fastvideo_args)
+    return loader.load(component_model_path, architecture, fastvideo_args)
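Note: the common thread in this file is that the to_cpu flag no longer flows from load_model through _get_all_weights into the weight iterators, and the post-load shard_model() CPU-offload pass is gone. A toy, runnable sketch of the simplified iterator contract; the real safetensors_weights_iterator/pt_weights_iterator read checkpoint files from disk, while this stand-in uses an in-memory dict:

    from typing import Dict, Generator, Tuple

    import torch

    def toy_weights_iterator(
        shards: Dict[str, torch.Tensor],
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        # No to_cpu parameter: tensors are yielded as loaded, with no
        # conditional hop to CPU for a later FSDP offload pass.
        for name, tensor in shards.items():
            yield name, tensor

    state = {"weight": torch.ones(2, 2), "bias": torch.zeros(2)}
    for name, tensor in toy_weights_iterator(state):
        print(name, tuple(tensor.shape))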
