Commit 0991003

[bugfix] [misc] fix denoising stage init; rename distributed env function; fix logging. (#481)
Co-authored-by: Will Lin <[email protected]>
Parent: 8f8ce6d

29 files changed (+194, −183 lines)

fastvideo/v1/attention/layer.py

Lines changed: 4 additions & 4 deletions
@@ -9,8 +9,8 @@
     get_attn_backend)
 from fastvideo.v1.distributed.communication_op import (
     sequence_model_parallel_all_gather, sequence_model_parallel_all_to_all_4D)
-from fastvideo.v1.distributed.parallel_state import (
-    get_sequence_model_parallel_rank, get_sequence_model_parallel_world_size)
+from fastvideo.v1.distributed.parallel_state import (get_sp_parallel_rank,
+                                                      get_sp_world_size)
 from fastvideo.v1.forward_context import ForwardContext, get_forward_context
 from fastvideo.v1.platforms import _Backend
 from fastvideo.v1.utils import get_compute_dtype
@@ -86,8 +86,8 @@ def forward(
         assert q.dim() == 4 and k.dim() == 4 and v.dim(
         ) == 4, "Expected 4D tensors"
         batch_size, seq_len, num_heads, head_dim = q.shape
-        local_rank = get_sequence_model_parallel_rank()
-        world_size = get_sequence_model_parallel_world_size()
+        local_rank = get_sp_parallel_rank()
+        world_size = get_sp_world_size()
 
         forward_context: ForwardContext = get_forward_context()
         ctx_attn_metadata = forward_context.attn_metadata
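For context, the renamed accessors get_sp_parallel_rank() and get_sp_world_size() report a process's position within its sequence-parallel (SP) group, which the attention layer uses to shard the sequence dimension. Below is a minimal, self-contained sketch of that rank-local slicing; the hard-coded rank and world size are stand-ins, since the real accessors require an initialized distributed environment.

import torch

# Stand-ins for get_sp_parallel_rank() / get_sp_world_size(); the real
# accessors only work inside an initialized distributed environment.
sp_rank, sp_world_size = 1, 2

# Toy tensor in the [batch, seq, heads, dim] layout the layer asserts on.
q = torch.randn(1, 8, 4, 16)

# Each SP rank keeps an equal slice of the sequence dimension.
assert q.shape[1] % sp_world_size == 0
q_local = q.chunk(sp_world_size, dim=1)[sp_rank]
print(q_local.shape)  # torch.Size([1, 4, 4, 16])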

fastvideo/v1/dataset/parquet_datasets.py

Lines changed: 26 additions & 29 deletions
@@ -15,9 +15,9 @@
 from torch.utils.data import Dataset
 from torchdata.stateful_dataloader import StatefulDataLoader
 
-from fastvideo.v1.distributed import (get_dp_group,
-                                      get_sequence_model_parallel_rank,
-                                      get_sp_group)
+from fastvideo.v1.distributed import (get_sp_group, get_sp_parallel_rank,
+                                      get_sp_world_size, get_world_rank,
+                                      get_world_size)
 from fastvideo.v1.logger import init_logger
 
 logger = init_logger(__name__)
@@ -28,23 +28,19 @@ class ParquetVideoTextDataset(Dataset):
 
     def __init__(self,
                  path: str,
-                 batch_size: int = 1024,
-                 rank: int = 0,
-                 world_size: int = 1,
+                 batch_size,
                  cfg_rate: float = 0.0,
                  num_latent_t: int = 2,
                  seed: int = 0,
                  validation: bool = False):
         super().__init__()
         self.path = str(path)
         self.batch_size = batch_size
-        self.rank = rank
-        self.local_rank = get_sequence_model_parallel_rank()
+        self.global_rank = get_world_rank()
+        self.rank_in_sp_group = get_sp_parallel_rank()
         self.sp_group = get_sp_group()
-        self.dp_group = get_dp_group()
-        self.dp_world_size = self.dp_group.world_size
-        self.sp_world_size = self.sp_group.world_size
-        self.world_size = int(os.getenv("WORLD_SIZE", 1))
+        self.sp_world_size = get_sp_world_size()
+        self.world_size = get_world_size()
         self.cfg_rate = cfg_rate
         self.num_latent_t = num_latent_t
         self.local_indices = None
@@ -56,22 +52,26 @@ def __init__(self,
 
         self.plan_output_dir = os.path.join(
             self.path,
-            f"data_plan_{self.world_size}_{self.sp_world_size}_{self.dp_world_size}.json"
+            f"data_plan_world_size_{self.world_size}_sp_size_{self.sp_world_size}.json"
         )
 
-        ranks = get_sp_group().ranks
+        # group_ranks: a list of lists
+        # len(group_ranks) = self.world_size
+        # len(group_ranks[i]) = self.sp_world_size
+        # group_ranks[i] represents the ranks of the SP group for the i-th GPU
+        # For example, if self.world_size = 4, self.sp_world_size = 2, then
+        # group_ranks = [[0, 1], [0, 1], [2, 3], [2, 3]]
+        sp_group_ranks = get_sp_group().ranks
         group_ranks: List[List] = [[] for _ in range(self.world_size)]
-        torch.distributed.all_gather_object(group_ranks, ranks)
+        dist.all_gather_object(group_ranks, sp_group_ranks)
 
-        if rank == 0:
+        if self.global_rank == 0:
             # If a plan already exists, then skip creating a new plan
             # This will be useful when resume training
             if os.path.exists(self.plan_output_dir):
-                print(f"Using existing plan from {self.plan_output_dir}")
+                logger.info("Using existing plan from %s", self.plan_output_dir)
             else:
-                print(f"Creating new plan for {self.plan_output_dir}")
-                # Find all parquet files recursively, and record num_rows for each file
-                print(f"Scanning for parquet files in {self.path}")
+                logger.info("Creating new plan for %s", self.plan_output_dir)
                 metadatas = []
                 for root, _, files in os.walk(self.path):
                     for file in sorted(files):
@@ -94,7 +94,7 @@ def __init__(self,
 
                 # Get all sp groups
                 # e.g. if num_gpus = 4, sp_size = 2
-                # group_ranks = [(0, 1), (2, 3)]
+                # group_ranks = [(0, 1), (0, 1), (2, 3), (2, 3)]
                 # We will assign the same batches of data to ranks in the same sp group, and we'll assign different batches to ranks in different sp groups
                 # e.g. plan = {0: [row 1, row 4], 1: [row 1, row 4], 2: [row 2, row 3], 3: [row 2, row 3]}
                 group_ranks_list: List[Any] = list(
@@ -113,7 +113,6 @@ def __init__(self,
                     json.dump(plan, f)
         else:
             pass
-
         dist.barrier()
         if validation:
             with open(self.plan_output_dir) as f:
@@ -168,7 +167,7 @@ def get_validation_negative_prompt(
 
         if self.cached_neg_prompt is None:
             raise RuntimeError(
-                f"Rank {self.rank} (SP rank {self.local_rank}): Could not retrieve negative prompt data"
+                f"Rank {self.global_rank} (SP rank {self.rank_in_sp_group}): Could not retrieve negative prompt data"
             )
 
         # Extract the components
@@ -186,15 +185,15 @@ def get_validation_negative_prompt(
             lat = rearrange(lat,
                             "t (n s) h w -> t n s h w",
                             n=self.sp_world_size).contiguous()
-            lat = lat[:, self.local_rank, :, :, :]
+            lat = lat[:, self.rank_in_sp_group, :, :, :]
         return lat, emb, mask, info
 
     def __len__(self):
         if self.local_indices is None:
             try:
                 with open(self.plan_output_dir) as f:
                     plan = json.load(f)
-                    self.local_indices = plan[str(self.rank)]
+                    self.local_indices = plan[str(self.global_rank)]
             except Exception as err:
                 raise Exception(
                     "The data plan hasn't been created yet") from err
@@ -206,7 +205,7 @@ def __getitem__(self, idx):
             try:
                 with open(self.plan_output_dir) as f:
                     plan = json.load(f)
-                    self.local_indices = plan[self.rank]
+                    self.local_indices = plan[self.global_rank]
             except Exception as err:
                 raise Exception(
                     "The data plan hasn't been created yet") from err
@@ -240,7 +239,7 @@ def __getitem__(self, idx):
             lat = rearrange(lat,
                             "t (n s) h w -> t n s h w",
                             n=self.sp_world_size).contiguous()
-            lat = lat[:, self.local_rank, :, :, :]
+            lat = lat[:, self.rank_in_sp_group, :, :, :]
         return lat, emb, mask, info
 
     def _process_row(self, row) -> Dict[str, Any]:
@@ -356,8 -355,6 @@ def _process_row(self, row) -> Dict[str, Any]:
     dataset = ParquetVideoTextDataset(
         args.path,
         batch_size=args.batch_size,
-        rank=rank,
-        world_size=world_size,
     )
 
     # Create DataLoader with proper settings
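The new comments spell out what dist.all_gather_object produces (one SP-group rank list per global rank) and what the resulting plan keyed by global rank looks like. Below is a self-contained sketch of that assignment scheme; the deduplication and round-robin split are illustrative only, not the repository's exact plan-building code.

# What dist.all_gather_object yields in the commented example:
# world_size = 4, sp_world_size = 2, one entry per global rank.
group_ranks = [[0, 1], [0, 1], [2, 3], [2, 3]]

# Deduplicate to one entry per SP group, e.g. [(0, 1), (2, 3)].
sp_groups = list(dict.fromkeys(tuple(r) for r in group_ranks))

rows = ["row 1", "row 2", "row 3", "row 4"]

# Every rank inside an SP group gets the same rows; different SP groups
# get disjoint rows (round-robin over the groups).
plan = {rank: [] for rank in range(len(group_ranks))}
for i, row in enumerate(rows):
    for rank in sp_groups[i % len(sp_groups)]:
        plan[rank].append(row)

print(plan)
# {0: ['row 1', 'row 3'], 1: ['row 1', 'row 3'],
#  2: ['row 2', 'row 4'], 3: ['row 2', 'row 4']}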

fastvideo/v1/distributed/__init__.py

Lines changed: 23 additions & 13 deletions
@@ -2,27 +2,37 @@
 
 from fastvideo.v1.distributed.communication_op import *
 from fastvideo.v1.distributed.parallel_state import (
-    cleanup_dist_env_and_memory, get_data_parallel_rank,
-    get_data_parallel_world_size, get_dp_group,
-    get_sequence_model_parallel_rank, get_sequence_model_parallel_world_size,
-    get_sp_group, get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size, get_world_group,
-    init_distributed_environment, initialize_model_parallel,
+    cleanup_dist_env_and_memory, get_dp_group, get_dp_rank, get_dp_world_size,
+    get_sp_group, get_sp_parallel_rank, get_sp_world_size, get_tp_group,
+    get_tp_rank, get_tp_world_size, get_world_group, get_world_rank,
+    get_world_size, init_distributed_environment, initialize_model_parallel,
     model_parallel_is_initialized)
 from fastvideo.v1.distributed.utils import *
 
 __all__ = [
+    # Initialization
     "init_distributed_environment",
     "initialize_model_parallel",
-    "get_data_parallel_world_size",
-    "get_data_parallel_rank",
-    "get_sequence_model_parallel_rank",
-    "get_sequence_model_parallel_world_size",
-    "get_tensor_model_parallel_rank",
-    "get_tensor_model_parallel_world_size",
     "cleanup_dist_env_and_memory",
+    "model_parallel_is_initialized",
+
+    # World group
     "get_world_group",
+    "get_world_rank",
+    "get_world_size",
+
+    # Data parallel group
     "get_dp_group",
+    "get_dp_rank",
+    "get_dp_world_size",
+
+    # Sequence parallel group
     "get_sp_group",
-    "model_parallel_is_initialized",
+    "get_sp_parallel_rank",
+    "get_sp_world_size",
+
+    # Tensor parallel group
+    "get_tp_group",
+    "get_tp_rank",
+    "get_tp_world_size",
 ]
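A brief sketch of how the regrouped public helpers might be queried from the package root. It assumes the distributed environment and model-parallel groups have already been set up by the training entrypoint (via init_distributed_environment / initialize_model_parallel, whose arguments are not shown in this diff).

from fastvideo.v1.distributed import (get_dp_rank, get_dp_world_size,
                                      get_sp_parallel_rank, get_sp_world_size,
                                      get_tp_rank, get_tp_world_size,
                                      get_world_rank, get_world_size)


def log_parallel_topology() -> None:
    # Assumes init_distributed_environment() / initialize_model_parallel()
    # have already been called; otherwise these accessors will fail.
    print(f"world rank {get_world_rank()} / {get_world_size()}")
    print(f"dp    rank {get_dp_rank()} / {get_dp_world_size()}")
    print(f"sp    rank {get_sp_parallel_rank()} / {get_sp_world_size()}")
    print(f"tp    rank {get_tp_rank()} / {get_tp_world_size()}")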

fastvideo/v1/distributed/parallel_state.py

Lines changed: 19 additions & 13 deletions
@@ -735,9 +735,6 @@ def get_tp_group() -> GroupCoordinator:
     return _TP
 
 
-# kept for backward compatibility
-get_tensor_model_parallel_group = get_tp_group
-
 _ENABLE_CUSTOM_ALL_REDUCE = True
 
 
@@ -878,22 +875,32 @@ def initialize_model_parallel(
                               group_name="dp")
 
 
-def get_sequence_model_parallel_world_size() -> int:
+def get_sp_world_size() -> int:
     """Return world size for the sequence model parallel group."""
     return get_sp_group().world_size
 
 
-def get_sequence_model_parallel_rank() -> int:
+def get_sp_parallel_rank() -> int:
     """Return my rank for the sequence model parallel group."""
     return get_sp_group().rank_in_group
 
 
-def get_data_parallel_world_size() -> int:
+def get_world_size() -> int:
+    """Return world size for the world group."""
+    return get_world_group().world_size
+
+
+def get_world_rank() -> int:
+    """Return my rank for the world group."""
+    return get_world_group().rank
+
+
+def get_dp_world_size() -> int:
     """Return world size for the data parallel group."""
     return get_dp_group().world_size
 
 
-def get_data_parallel_rank() -> int:
+def get_dp_rank() -> int:
     """Return my rank for the data parallel group."""
     return get_dp_group().rank_in_group
 
@@ -916,10 +923,9 @@ def ensure_model_parallel_initialized(
             data_parallel_size, backend)
         return
 
-    assert (
-        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
-    ), ("tensor parallel group already initialized, but of unexpected size: "
-        f"{get_tensor_model_parallel_world_size()=} vs. "
+    assert (get_tp_world_size() == tensor_model_parallel_size), (
+        "tensor parallel group already initialized, but of unexpected size: "
+        f"{get_tp_world_size()=} vs. "
         f"{tensor_model_parallel_size=}")
 
     if sequence_model_parallel_size > 1:
@@ -963,12 +969,12 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
         _TP = old_tp_group
 
 
-def get_tensor_model_parallel_world_size() -> int:
+def get_tp_world_size() -> int:
     """Return world size for the tensor model parallel group."""
     return get_tp_group().world_size
 
 
-def get_tensor_model_parallel_rank() -> int:
+def get_tp_rank() -> int:
     """Return my rank for the tensor model parallel group."""
     return get_tp_group().rank_in_group
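To summarize the rename, here is a hypothetical compatibility shim mapping the old accessor names to the new ones. It is not part of this commit, which removes the old spellings outright; any such shim would have to live in downstream code.

# Hypothetical migration aliases; this commit deletes the old names rather
# than keeping them, so these assignments are illustration only.
from fastvideo.v1.distributed.parallel_state import (get_dp_rank,
                                                     get_dp_world_size,
                                                     get_sp_parallel_rank,
                                                     get_sp_world_size,
                                                     get_tp_group,
                                                     get_tp_rank,
                                                     get_tp_world_size)

get_sequence_model_parallel_rank = get_sp_parallel_rank
get_sequence_model_parallel_world_size = get_sp_world_size
get_data_parallel_rank = get_dp_rank
get_data_parallel_world_size = get_dp_world_size
get_tensor_model_parallel_rank = get_tp_rank
get_tensor_model_parallel_world_size = get_tp_world_size
get_tensor_model_parallel_group = get_tp_group  # alias also removed by this commit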

fastvideo/v1/fastvideo_args.py

Lines changed: 1 addition & 1 deletion
@@ -551,7 +551,7 @@ class TrainingArgs(FastVideoArgs):
     gradient_accumulation_steps: int = 0
     learning_rate: float = 0.0
     scale_lr: bool = False
-    lr_scheduler: str = ""
+    lr_scheduler: str = "constant"
     lr_warmup_steps: int = 0
     max_grad_norm: float = 0.0
     gradient_checkpointing: bool = False
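The default lr_scheduler is now the string "constant" instead of empty. How TrainingArgs is consumed is not shown in this diff; as a rough illustration only, a factory keyed on that string could look like the following (build_lr_scheduler is a hypothetical helper, not FastVideo code).

import torch


def build_lr_scheduler(name: str, optimizer: torch.optim.Optimizer):
    # Hypothetical stand-in for whatever consumes TrainingArgs.lr_scheduler;
    # only the new default, "constant", is handled here.
    if name == "constant":
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
    raise ValueError(f"unsupported lr_scheduler: {name!r}")


model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
sched = build_lr_scheduler("constant", opt)
print(sched.get_last_lr())  # [0.0001]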
