Commit 8631c1b

[Training] Support Multi-Node training with FSDP + SP (#459)
1 parent 5357e12 commit 8631c1b

File tree

13 files changed (+338, -120 lines)


fastvideo/v1/attention/layer.py

Lines changed: 3 additions & 2 deletions
@@ -13,6 +13,7 @@
     get_sequence_model_parallel_rank, get_sequence_model_parallel_world_size)
 from fastvideo.v1.forward_context import ForwardContext, get_forward_context
 from fastvideo.v1.platforms import _Backend
+from fastvideo.v1.utils import get_compute_dtype


 class DistributedAttention(nn.Module):
@@ -38,7 +39,7 @@ def __init__(self,
         if num_kv_heads is None:
            num_kv_heads = num_heads

-        dtype = torch.get_default_dtype()
+        dtype = get_compute_dtype()
         attn_backend = get_attn_backend(
             head_size,
             dtype,
@@ -155,7 +156,7 @@ def __init__(self,
         if num_kv_heads is None:
             num_kv_heads = num_heads

-        dtype = torch.get_default_dtype()
+        dtype = get_compute_dtype()
         attn_backend = get_attn_backend(
             head_size,
             dtype,
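The switch from `torch.get_default_dtype()` to `get_compute_dtype()` matters for FSDP training, where the default dtype can stay at fp32 for master weights while the attention backend should be chosen for the bf16/fp16 compute dtype. A minimal sketch of that idea, assuming a hypothetical module-level registry (the real `get_compute_dtype()` lives in `fastvideo/v1/utils` and its implementation is not shown in this diff):

```python
import torch

# Hypothetical module-level setting for illustration only; the real helper
# in fastvideo.v1.utils may be implemented differently.
_COMPUTE_DTYPE = None


def set_compute_dtype(dtype: torch.dtype) -> None:
    """Register the dtype that attention kernels should be selected for."""
    global _COMPUTE_DTYPE
    _COMPUTE_DTYPE = dtype


def get_compute_dtype() -> torch.dtype:
    """Fall back to the framework default when no compute dtype was registered."""
    return _COMPUTE_DTYPE if _COMPUTE_DTYPE is not None else torch.get_default_dtype()


# Example: master weights in fp32, attention backend chosen for bf16.
set_compute_dtype(torch.bfloat16)
assert get_compute_dtype() == torch.bfloat16
```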

fastvideo/v1/dataset/parquet_datasets.py

Lines changed: 45 additions & 39 deletions
@@ -15,7 +15,8 @@
 from torch.utils.data import Dataset
 from torchdata.stateful_dataloader import StatefulDataLoader

-from fastvideo.v1.distributed import (get_sequence_model_parallel_rank,
+from fastvideo.v1.distributed import (get_dp_group,
+                                      get_sequence_model_parallel_rank,
                                       get_sp_group)
 from fastvideo.v1.logger import init_logger

@@ -38,13 +39,18 @@ def __init__(self,
         self.batch_size = batch_size
         self.rank = rank
         self.local_rank = get_sequence_model_parallel_rank()
-        self.sp_world_size = world_size
+        self.sp_group = get_sp_group()
+        self.dp_group = get_dp_group()
+        self.dp_world_size = self.dp_group.world_size
+        self.sp_world_size = self.sp_group.world_size
         self.world_size = int(os.getenv("WORLD_SIZE", 1))
         self.cfg_rate = cfg_rate
         self.num_latent_t = num_latent_t
         self.local_indices = None
         self.plan_output_dir = os.path.join(
-            self.path, f"data_plan_{self.world_size}_{self.sp_world_size}.json")
+            self.path,
+            f"data_plan_{self.world_size}_{self.sp_world_size}_{self.dp_world_size}.json"
+        )

         ranks = get_sp_group().ranks
         group_ranks: List[List] = [[] for _ in range(self.world_size)]
@@ -55,40 +61,40 @@ def __init__(self,
         # This will be useful when resume training
         if os.path.exists(self.plan_output_dir):
             print(f"Using existing plan from {self.plan_output_dir}")
-            dist.barrier()
-            return
-
-        # Find all parquet files recursively, and record num_rows for each file
-        print(f"Scanning for parquet files in {self.path}")
-        metadatas = []
-        for root, _, files in os.walk(self.path):
-            for file in sorted(files):
-                if file.endswith('.parquet'):
-                    file_path = os.path.join(root, file)
-                    num_rows = pq.ParquetFile(file_path).metadata.num_rows
-                    for row_idx in range(num_rows):
-                        metadatas.append((file_path, row_idx))
-
-        # Generate the plan that distribute rows among workers
-        random.seed(seed)
-        random.shuffle(metadatas)
-
-        # Get all sp groups
-        # e.g. if num_gpus = 4, sp_size = 2
-        # group_ranks = [(0, 1), (2, 3)]
-        # We will assign the same batches of data to ranks in the same sp group, and we'll assign different batches to ranks in different sp groups
-        # e.g. plan = {0: [row 1, row 4], 1: [row 1, row 4], 2: [row 2, row 3], 3: [row 2, row 3]}
-        group_ranks_list: List[Any] = list(
-            set(tuple(r) for r in group_ranks))
-        num_sp_groups = len(group_ranks_list)
-        plan = defaultdict(list)
-        for idx, metadata in enumerate(metadatas):
-            sp_group_idx = idx % num_sp_groups
-            for global_rank in group_ranks_list[sp_group_idx]:
-                plan[global_rank].append(metadata)
-
-        with open(self.plan_output_dir, "w") as f:
-            json.dump(plan, f)
+        else:
+            print(f"Creating new plan for {self.plan_output_dir}")
+            # Find all parquet files recursively, and record num_rows for each file
+            print(f"Scanning for parquet files in {self.path}")
+            metadatas = []
+            for root, _, files in os.walk(self.path):
+                for file in sorted(files):
+                    if file.endswith('.parquet'):
+                        file_path = os.path.join(root, file)
+                        num_rows = pq.ParquetFile(
+                            file_path).metadata.num_rows
+                        for row_idx in range(num_rows):
+                            metadatas.append((file_path, row_idx))
+
+            # Generate the plan that distribute rows among workers
+            random.seed(seed)
+            random.shuffle(metadatas)
+
+            # Get all sp groups
+            # e.g. if num_gpus = 4, sp_size = 2
+            # group_ranks = [(0, 1), (2, 3)]
+            # We will assign the same batches of data to ranks in the same sp group, and we'll assign different batches to ranks in different sp groups
+            # e.g. plan = {0: [row 1, row 4], 1: [row 1, row 4], 2: [row 2, row 3], 3: [row 2, row 3]}
+            group_ranks_list: List[Any] = list(
+                set(tuple(r) for r in group_ranks))
+            num_sp_groups = len(group_ranks_list)
+            plan = defaultdict(list)
+            for idx, metadata in enumerate(metadatas):
+                sp_group_idx = idx % num_sp_groups
+                for global_rank in group_ranks_list[sp_group_idx]:
+                    plan[global_rank].append(metadata)
+
+            with open(self.plan_output_dir, "w") as f:
+                json.dump(plan, f)
         dist.barrier()

     def __len__(self):
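As a quick illustration of the planning step above (a self-contained toy example, not repo code), assume 4 global ranks and sp_size = 2: every rank inside one sequence-parallel group receives the same rows, while different SP groups receive different rows in round-robin order.

```python
from collections import defaultdict

# Assumed toy setup: 4 global ranks, sp_size = 2 -> two SP groups.
group_ranks_list = [(0, 1), (2, 3)]
metadatas = [("shard-0.parquet", row) for row in range(6)]  # (file, row) records

plan = defaultdict(list)
for idx, metadata in enumerate(metadatas):
    sp_group_idx = idx % len(group_ranks_list)      # round-robin over SP groups
    for global_rank in group_ranks_list[sp_group_idx]:
        plan[global_rank].append(metadata)          # same rows within one SP group

# Ranks 0/1 get rows 0, 2, 4; ranks 2/3 get rows 1, 3, 5.
print({rank: [row for _, row in rows] for rank, rows in plan.items()})
```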
@@ -121,9 +127,9 @@ def __getitem__(self, idx):
         cumulative = 0
         for i in range(parquet_file.num_row_groups):
             num_rows = parquet_file.metadata.row_group(i).num_rows
-            if cumulative + num_rows > idx:
+            if cumulative + num_rows > row_idx:
                 row_group_index = i
-                local_index = idx - cumulative
+                local_index = row_idx - cumulative
                 break
             cumulative += num_rows
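The hunk above fixes the row-group lookup to use the row index stored in the plan (`row_idx`) rather than the dataset index (`idx`). A pure-Python sketch of that lookup, with made-up row-group sizes:

```python
# Assumed example: three row groups with 100, 100 and 50 rows.
row_group_sizes = [100, 100, 50]
row_idx = 180  # global row index within this parquet file

cumulative = 0
for i, num_rows in enumerate(row_group_sizes):
    if cumulative + num_rows > row_idx:
        row_group_index = i                 # row group holding row_idx
        local_index = row_idx - cumulative  # offset within that row group
        break
    cumulative += num_rows

print(row_group_index, local_index)  # 1 80
```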

fastvideo/v1/distributed/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -2,8 +2,10 @@

 from fastvideo.v1.distributed.communication_op import *
 from fastvideo.v1.distributed.parallel_state import (
-    cleanup_dist_env_and_memory, get_sequence_model_parallel_rank,
-    get_sequence_model_parallel_world_size, get_tensor_model_parallel_rank,
+    cleanup_dist_env_and_memory, get_data_parallel_rank,
+    get_data_parallel_world_size, get_dp_group,
+    get_sequence_model_parallel_rank, get_sequence_model_parallel_world_size,
+    get_sp_group, get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size, get_world_group,
     init_distributed_environment, initialize_model_parallel,
     model_parallel_is_initialized)
@@ -12,11 +14,15 @@
 __all__ = [
     "init_distributed_environment",
     "initialize_model_parallel",
+    "get_data_parallel_world_size",
+    "get_data_parallel_rank",
     "get_sequence_model_parallel_rank",
     "get_sequence_model_parallel_world_size",
     "get_tensor_model_parallel_rank",
     "get_tensor_model_parallel_world_size",
     "cleanup_dist_env_and_memory",
     "get_world_group",
+    "get_dp_group",
+    "get_sp_group",
     "model_parallel_is_initialized",
 ]
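With these re-exports in place, training code can pull the data-parallel helpers straight from the package. A hedged usage sketch, assuming the distributed environment and model-parallel groups have already been initialized on every rank:

```python
from fastvideo.v1.distributed import (get_data_parallel_rank,
                                      get_data_parallel_world_size,
                                      get_dp_group, get_sp_group)

# Only valid after init_distributed_environment() and
# initialize_model_parallel(...) have been called on this rank.
dp_rank = get_data_parallel_rank()
dp_world_size = get_data_parallel_world_size()
sp_world_size = get_sp_group().world_size
print(f"dp rank {dp_rank}/{dp_world_size}, sp world size {sp_world_size}")
```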

fastvideo/v1/distributed/parallel_state.py

Lines changed: 45 additions & 3 deletions
@@ -704,7 +704,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         group_ranks=[ranks],
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_device_communicator=False,
+        use_device_communicator=True,
         group_name="world",
     )

@@ -794,9 +794,18 @@ def get_sp_group() -> GroupCoordinator:
     return _SP


+_DP: Optional[GroupCoordinator] = None
+
+
+def get_dp_group() -> GroupCoordinator:
+    assert _DP is not None, ("data parallel group is not initialized")
+    return _DP
+
+
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     sequence_model_parallel_size: int = 1,
+    data_parallel_size: int = 1,
     backend: Optional[str] = None,
 ) -> None:
     """
@@ -852,6 +861,22 @@
         backend,
         group_name="sp")

+    # Build the data parallel groups.
+    num_data_parallel_groups: int = (world_size // data_parallel_size)
+    global _DP
+    assert _DP is None, ("data parallel group is already initialized")
+    group_ranks = []
+
+    for i in range(num_data_parallel_groups):
+        ranks = list(range(i * data_parallel_size,
+                           (i + 1) * data_parallel_size))
+        group_ranks.append(ranks)
+
+    _DP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    group_name="dp")
+

 def get_sequence_model_parallel_world_size() -> int:
     """Return world size for the sequence model parallel group."""
@@ -863,9 +888,20 @@ def get_sequence_model_parallel_rank() -> int:
     return get_sp_group().rank_in_group


+def get_data_parallel_world_size() -> int:
+    """Return world size for the data parallel group."""
+    return get_dp_group().world_size
+
+
+def get_data_parallel_rank() -> int:
+    """Return my rank for the data parallel group."""
+    return get_dp_group().rank_in_group
+
+
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     sequence_model_parallel_size: int,
+    data_parallel_size: int,
     backend: Optional[str] = None,
 ) -> None:
     """Helper to initialize model parallel groups if they are not initialized,
@@ -876,7 +912,8 @@
         get_world_group().device_group)
     if not model_parallel_is_initialized():
         initialize_model_parallel(tensor_model_parallel_size,
-                                  sequence_model_parallel_size, backend)
+                                  sequence_model_parallel_size,
+                                  data_parallel_size, backend)
         return

     assert (
@@ -895,7 +932,7 @@

 def model_parallel_is_initialized() -> bool:
     """Check if tensor, sequence parallel groups are initialized."""
-    return _TP is not None and _SP is not None
+    return _TP is not None and _SP is not None and _DP is not None


 _TP_STATE_PATCHED = False
@@ -948,6 +985,11 @@ def destroy_model_parallel() -> None:
         _SP.destroy()
     _SP = None

+    global _DP
+    if _DP:
+        _DP.destroy()
+    _DP = None
+

 def destroy_distributed_environment() -> None:
     global _WORLD
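For intuition about the rank layout built in `initialize_model_parallel` above: consecutive global ranks are grouped into data-parallel groups of size `data_parallel_size`. A stand-alone sketch of just that arithmetic (pure Python, assumed sizes, not repo code):

```python
# Assumed sizes for illustration only.
world_size = 8
data_parallel_size = 2

num_data_parallel_groups = world_size // data_parallel_size
group_ranks = [
    list(range(i * data_parallel_size, (i + 1) * data_parallel_size))
    for i in range(num_data_parallel_groups)
]
print(group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```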

fastvideo/v1/fastvideo_args.py

Lines changed: 32 additions & 1 deletion
@@ -44,6 +44,8 @@ class FastVideoArgs:
     num_gpus: int = 1
     tp_size: Optional[int] = None
     sp_size: Optional[int] = None
+    dp_size: int = 1
+    dp_shards: Optional[int] = None
     dist_timeout: Optional[int] = None  # timeout for torch.distributed

     # Video generation parameters
@@ -70,7 +72,7 @@ class FastVideoArgs:
     # Text encoder configuration
     DEFAULT_TEXT_ENCODER_PRECISIONS = (
         "fp16",
-        "fp16",
+        # "fp16",
     )
     text_encoder_precisions: Tuple[str, ...] = field(
         default_factory=lambda: FastVideoArgs.DEFAULT_TEXT_ENCODER_PRECISIONS)
@@ -179,6 +181,20 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            default=FastVideoArgs.sp_size,
            help="The sequence parallelism size.",
        )
+        parser.add_argument(
+            "--data-parallel-size",
+            "--dp-size",
+            type=int,
+            default=FastVideoArgs.dp_size,
+            help="The data parallelism size.",
+        )
+        parser.add_argument(
+            "--data-parallel-shards",
+            "--dp-shards",
+            type=int,
+            default=FastVideoArgs.dp_shards,
+            help="The data parallelism shards.",
+        )
         parser.add_argument(
             "--dist-timeout",
             type=int,
@@ -332,6 +348,10 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
                 kwargs[attr] = args.tensor_parallel_size
             elif attr == 'sp_size' and hasattr(args, 'sequence_parallel_size'):
                 kwargs[attr] = args.sequence_parallel_size
+            elif attr == 'dp_size' and hasattr(args, 'data_parallel_size'):
+                kwargs[attr] = args.data_parallel_size
+            elif attr == 'dp_shards' and hasattr(args, 'data_parallel_shards'):
+                kwargs[attr] = args.data_parallel_shards
             elif attr == 'flow_shift' and hasattr(args, 'shift'):
                 kwargs[attr] = args.shift
             # Use getattr with default value from the dataclass for potentially missing attributes
@@ -343,10 +363,17 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":

     def check_fastvideo_args(self) -> None:
         """Validate inference arguments for consistency"""
+        if not self.inference_mode:
+            assert self.dp_size is not None, "dp_size must be set for training"
+            assert self.dp_shards is not None, "dp_shards must be set for training"
+            assert self.sp_size is not None, "sp_size must be set for training"
+
         if self.tp_size is None:
             self.tp_size = self.num_gpus
         if self.sp_size is None:
             self.sp_size = self.num_gpus
+        if self.dp_shards is None:
+            self.dp_shards = self.num_gpus

         if self.num_gpus < max(self.tp_size, self.sp_size):
             self.num_gpus = max(self.tp_size, self.sp_size)
@@ -535,6 +562,10 @@ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
                 kwargs[attr] = args.sequence_parallel_size
             elif attr == 'flow_shift' and hasattr(args, 'shift'):
                 kwargs[attr] = args.shift
+            elif attr == 'dp_size' and hasattr(args, 'data_parallel_size'):
+                kwargs[attr] = args.data_parallel_size
+            elif attr == 'dp_shards' and hasattr(args, 'data_parallel_shards'):
+                kwargs[attr] = args.data_parallel_shards
             # Use getattr with default value from the dataclass for potentially missing attributes
             else:
                 default_value = getattr(cls, attr, None)
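To see how the new flags are meant to be passed, here is a minimal argparse sketch mirroring the arguments added above. The real code registers them through `FastVideoArgs.add_cli_args` on the repo's `FlexibleArgumentParser`; the parser below is only a stand-in:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--data-parallel-size", "--dp-size", type=int, default=1,
                    help="The data parallelism size.")
parser.add_argument("--data-parallel-shards", "--dp-shards", type=int, default=None,
                    help="The data parallelism shards.")
parser.add_argument("--sequence-parallel-size", "--sp-size", type=int, default=None,
                    help="The sequence parallelism size.")

# Example invocation (values are illustrative, not a recommended configuration).
args = parser.parse_args(["--dp-size", "2", "--dp-shards", "4", "--sp-size", "2"])
print(args.data_parallel_size, args.data_parallel_shards, args.sequence_parallel_size)
```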

fastvideo/v1/models/loader/component_loader.py

Lines changed: 11 additions & 2 deletions
@@ -15,7 +15,7 @@
 from transformers import AutoImageProcessor, AutoTokenizer
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

-from fastvideo.v1.fastvideo_args import FastVideoArgs
+from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs
 from fastvideo.v1.logger import init_logger
 from fastvideo.v1.models.hf_transformer_utils import get_diffusers_config
 from fastvideo.v1.models.loader.fsdp_load import load_fsdp_model
@@ -391,11 +391,18 @@ def load(self, model_path: str, architecture: str,
                     len(safetensors_list), model_path)

         # initialize_sequence_parallel_group(fastvideo_args.sp_size)
-        default_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
+        if fastvideo_args.training_mode:
+            assert isinstance(
+                fastvideo_args, TrainingArgs
+            ), "fastvideo_args must be a TrainingArgs object when training_mode is True"
+            default_dtype = PRECISION_TO_TYPE[fastvideo_args.master_weight_type]
+        else:
+            default_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]

         # Load the model using FSDP loader
         logger.info("Loading model from %s, default_dtype: %s", cls_name,
                     default_dtype)
+        assert fastvideo_args.dp_shards is not None
         model = load_fsdp_model(
             model_cls=model_cls,
             init_params={
@@ -404,6 +411,8 @@
             },
             weight_dir_list=safetensors_list,
             device=fastvideo_args.device,
+            data_parallel_size=fastvideo_args.dp_size,
+            data_parallel_shards=fastvideo_args.dp_shards,
             cpu_offload=fastvideo_args.use_cpu_offload,
             default_dtype=default_dtype,
             # TODO(will): make these configurable
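The dtype selection above picks the FSDP master-weight precision during training and the inference precision otherwise. A small self-contained sketch of that decision; the `PRECISION_TO_TYPE` mapping and function name here are stand-ins, not the repo's exact definitions:

```python
import torch

# Assumed mapping for illustration; the repo defines its own PRECISION_TO_TYPE.
PRECISION_TO_TYPE = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


def pick_default_dtype(training_mode: bool, master_weight_type: str,
                       precision: str) -> torch.dtype:
    # Training: keep FSDP master weights in the (usually higher) master precision.
    # Inference: load directly in the requested compute precision.
    key = master_weight_type if training_mode else precision
    return PRECISION_TO_TYPE[key]


print(pick_default_dtype(True, "fp32", "bf16"))   # torch.float32
print(pick_default_dtype(False, "fp32", "bf16"))  # torch.bfloat16
```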
