 from torch.utils.data import Dataset
 from torchdata.stateful_dataloader import StatefulDataLoader

-from fastvideo.v1.distributed import (get_dp_group,
-                                      get_sequence_model_parallel_rank,
-                                      get_sp_group)
+from fastvideo.v1.distributed import (get_sp_group, get_sp_parallel_rank,
+                                      get_sp_world_size, get_world_rank,
+                                      get_world_size)
 from fastvideo.v1.logger import init_logger

 logger = init_logger(__name__)
@@ -28,23 +28,19 @@ class ParquetVideoTextDataset(Dataset):

     def __init__(self,
                  path: str,
-                 batch_size: int = 1024,
-                 rank: int = 0,
-                 world_size: int = 1,
+                 batch_size,
                  cfg_rate: float = 0.0,
                  num_latent_t: int = 2,
                  seed: int = 0,
                  validation: bool = False):
         super().__init__()
         self.path = str(path)
         self.batch_size = batch_size
-        self.rank = rank
-        self.local_rank = get_sequence_model_parallel_rank()
+        self.global_rank = get_world_rank()
+        self.rank_in_sp_group = get_sp_parallel_rank()
         self.sp_group = get_sp_group()
-        self.dp_group = get_dp_group()
-        self.dp_world_size = self.dp_group.world_size
-        self.sp_world_size = self.sp_group.world_size
-        self.world_size = int(os.getenv("WORLD_SIZE", 1))
+        self.sp_world_size = get_sp_world_size()
+        self.world_size = get_world_size()
         self.cfg_rate = cfg_rate
         self.num_latent_t = num_latent_t
         self.local_indices = None
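With this change the constructor derives its rank bookkeeping from the distributed environment instead of taking `rank`/`world_size` arguments. A minimal sketch of how these attributes relate, assuming SP groups are formed from consecutive global ranks (the layout implied by the `group_ranks` example later in this diff); the arithmetic below is illustrative, not a fastvideo API:

```python
# Illustrative only: global rank vs. SP group index vs. slot within the group,
# assuming world_size = 4 and sp_world_size = 2 as in the diff's comments.
world_size, sp_world_size = 4, 2

for global_rank in range(world_size):
    sp_group_index = global_rank // sp_world_size   # which SP group
    rank_in_sp_group = global_rank % sp_world_size  # slot within that group
    print(f"rank {global_rank}: SP group {sp_group_index}, slot {rank_in_sp_group}")
# rank 0: SP group 0, slot 0
# rank 1: SP group 0, slot 1
# rank 2: SP group 1, slot 0
# rank 3: SP group 1, slot 1
```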
@@ -56,22 +52,26 @@ def __init__(self,

         self.plan_output_dir = os.path.join(
             self.path,
-            f"data_plan_{self.world_size}_{self.sp_world_size}_{self.dp_world_size}.json"
+            f"data_plan_world_size_{self.world_size}_sp_size_{self.sp_world_size}.json"
         )

-        ranks = get_sp_group().ranks
+        # group_ranks: a list of lists
+        # len(group_ranks) = self.world_size
+        # len(group_ranks[i]) = self.sp_world_size
+        # group_ranks[i] represents the ranks of the SP group for the i-th GPU
+        # For example, if self.world_size = 4, self.sp_world_size = 2, then
+        # group_ranks = [[0, 1], [0, 1], [2, 3], [2, 3]]
+        sp_group_ranks = get_sp_group().ranks
         group_ranks: List[List] = [[] for _ in range(self.world_size)]
-        torch.distributed.all_gather_object(group_ranks, ranks)
+        dist.all_gather_object(group_ranks, sp_group_ranks)

-        if rank == 0:
+        if self.global_rank == 0:
             # If a plan already exists, then skip creating a new plan
             # This will be useful when resuming training
6971 if os .path .exists (self .plan_output_dir ):
70- print ( f "Using existing plan from { self .plan_output_dir } " )
72+ logger . info ( "Using existing plan from %s" , self .plan_output_dir )
7173 else :
72- print (f"Creating new plan for { self .plan_output_dir } " )
73- # Find all parquet files recursively, and record num_rows for each file
74- print (f"Scanning for parquet files in { self .path } " )
74+ logger .info ("Creating new plan for %s" , self .plan_output_dir )
7575 metadatas = []
7676 for root , _ , files in os .walk (self .path ):
7777 for file in sorted (files ):
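For reference, the semantics of the `all_gather_object` call in the hunk above: every rank contributes the rank list of its own SP group, and afterwards each rank holds all contributions in global-rank order. A simulated sketch for `world_size=4`, `sp_world_size=2` (no process group required to follow the logic):

```python
# Each global rank contributes the rank list of its own SP group.
contributions = {0: [0, 1], 1: [0, 1], 2: [2, 3], 3: [2, 3]}

# dist.all_gather_object(group_ranks, sp_group_ranks) gathers one object per
# rank into group_ranks, ordered by global rank; simulated result:
group_ranks = [contributions[r] for r in range(4)]
assert group_ranks == [[0, 1], [0, 1], [2, 3], [2, 3]]
```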
@@ -94,7 +94,7 @@ def __init__(self,

             # Get all sp groups
             # e.g. if num_gpus = 4, sp_size = 2
-            # group_ranks = [(0, 1), (2, 3)]
+            # group_ranks = [(0, 1), (0, 1), (2, 3), (2, 3)]
             # We will assign the same batches of data to ranks in the same sp group, and we'll assign different batches to ranks in different sp groups
             # e.g. plan = {0: [row 1, row 4], 1: [row 1, row 4], 2: [row 2, row 3], 3: [row 2, row 3]}
             group_ranks_list: List[Any] = list(
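Given the gathered `group_ranks`, the plan described in the comment can be built by deduplicating the SP groups and dealing row indices across them. A hedged sketch of one such construction; `build_plan` is illustrative, and the PR's actual assignment (e.g. seeded shuffling) may differ:

```python
from itertools import cycle

def build_plan(group_ranks, num_rows):
    # Collapse duplicates while preserving order:
    # [[0, 1], [0, 1], [2, 3], [2, 3]] -> [(0, 1), (2, 3)]
    unique_groups = list(dict.fromkeys(tuple(g) for g in group_ranks))
    plan = {rank: [] for group in unique_groups for rank in group}
    # Deal rows across SP groups round-robin; every rank in a group gets the
    # same rows, different groups get disjoint rows.
    for row, group in zip(range(num_rows), cycle(unique_groups)):
        for rank in group:
            plan[rank].append(row)
    return plan

print(build_plan([[0, 1], [0, 1], [2, 3], [2, 3]], 4))
# {0: [0, 2], 1: [0, 2], 2: [1, 3], 3: [1, 3]}
```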
@@ -113,7 +113,6 @@ def __init__(self,
                 json.dump(plan, f)
         else:
             pass
-
         dist.barrier()
         if validation:
             with open(self.plan_output_dir) as f:
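The coordination pattern here is worth spelling out: rank 0 materializes the plan file, every rank waits at the barrier, then each rank reads back its own entry. A condensed sketch of the pattern (the helper and its signature are illustrative, not the PR's code):

```python
import json
import os

import torch.distributed as dist

def load_local_indices(plan_path: str, global_rank: int, build_plan) -> list:
    """Illustrative: rank 0 writes the plan once, then every rank reads it."""
    if global_rank == 0 and not os.path.exists(plan_path):
        with open(plan_path, "w") as f:
            json.dump(build_plan(), f)  # maps rank -> list of row indices
    dist.barrier()  # no rank may read before the file exists
    with open(plan_path) as f:
        return json.load(f)[str(global_rank)]  # JSON keys come back as strings
```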
@@ -168,7 +167,7 @@ def get_validation_negative_prompt(

         if self.cached_neg_prompt is None:
             raise RuntimeError(
-                f"Rank {self.rank} (SP rank {self.local_rank}): Could not retrieve negative prompt data"
+                f"Rank {self.global_rank} (SP rank {self.rank_in_sp_group}): Could not retrieve negative prompt data"
             )

         # Extract the components
@@ -186,15 +185,15 @@ def get_validation_negative_prompt(
             lat = rearrange(lat,
                             "t (n s) h w -> t n s h w",
                             n=self.sp_world_size).contiguous()
-            lat = lat[:, self.local_rank, :, :, :]
+            lat = lat[:, self.rank_in_sp_group, :, :, :]
         return lat, emb, mask, info

     def __len__(self):
         if self.local_indices is None:
             try:
                 with open(self.plan_output_dir) as f:
                     plan = json.load(f)
-                self.local_indices = plan[str(self.rank)]
+                self.local_indices = plan[str(self.global_rank)]
             except Exception as err:
                 raise Exception(
                     "The data plan hasn't been created yet") from err
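The `rearrange` call above shards the second latent dimension across the SP group so that each rank keeps only its own slice. A self-contained sketch with made-up shapes:

```python
import torch
from einops import rearrange

sp_world_size, rank_in_sp_group = 2, 0  # example values
lat = torch.randn(4, 8, 16, 16)  # (t, seq, h, w); seq must divide evenly

# Split seq into sp_world_size shards of size seq // sp_world_size ...
lat = rearrange(lat, "t (n s) h w -> t n s h w", n=sp_world_size)  # (4, 2, 4, 16, 16)
# ... then keep only this rank's shard.
lat = lat[:, rank_in_sp_group, :, :, :]  # (4, 4, 16, 16)
```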
@@ -206,7 +205,7 @@ def __getitem__(self, idx):
             try:
                 with open(self.plan_output_dir) as f:
                     plan = json.load(f)
-                self.local_indices = plan[self.rank]
+                self.local_indices = plan[str(self.global_rank)]
             except Exception as err:
                 raise Exception(
                     "The data plan hasn't been created yet") from err
@@ -240,7 +239,7 @@ def __getitem__(self, idx):
             lat = rearrange(lat,
                             "t (n s) h w -> t n s h w",
                             n=self.sp_world_size).contiguous()
-            lat = lat[:, self.local_rank, :, :, :]
+            lat = lat[:, self.rank_in_sp_group, :, :, :]
         return lat, emb, mask, info

     def _process_row(self, row) -> Dict[str, Any]:
@@ -356,8 +355,6 @@ def _process_row(self, row) -> Dict[str, Any]:
     dataset = ParquetVideoTextDataset(
         args.path,
         batch_size=args.batch_size,
-        rank=rank,
-        world_size=world_size,
     )

     # Create DataLoader with proper settings
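The dataset is then handed to the `StatefulDataLoader` imported at the top of this file. A hedged sketch of the idea (argument values are placeholders, not the settings this script uses): `StatefulDataLoader` is a drop-in replacement for `torch.utils.data.DataLoader` that adds `state_dict()`/`load_state_dict()`, so dataloading can resume mid-epoch when training is checkpointed.

```python
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(dataset, batch_size=1, num_workers=4)  # placeholder args

for step, batch in enumerate(loader):
    if step == 10:
        checkpoint = loader.state_dict()  # capture dataloader position
        break

# After a restart: rebuild the loader and restore; iteration resumes
# from step 10 instead of from the beginning of the epoch.
resumed = StatefulDataLoader(dataset, batch_size=1, num_workers=4)
resumed.load_state_dict(checkpoint)
```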