 from torch.utils.data import Dataset
 from torchdata.stateful_dataloader import StatefulDataLoader

-from fastvideo.v1.distributed import (get_sequence_model_parallel_rank,
+from fastvideo.v1.distributed import (get_dp_group,
+                                      get_sequence_model_parallel_rank,
                                       get_sp_group)
 from fastvideo.v1.logger import init_logger

@@ -38,13 +39,18 @@ def __init__(self,
         self.batch_size = batch_size
         self.rank = rank
         self.local_rank = get_sequence_model_parallel_rank()
-        self.sp_world_size = world_size
+        self.sp_group = get_sp_group()
+        self.dp_group = get_dp_group()
+        self.dp_world_size = self.dp_group.world_size
+        self.sp_world_size = self.sp_group.world_size
         self.world_size = int(os.getenv("WORLD_SIZE", 1))
         self.cfg_rate = cfg_rate
         self.num_latent_t = num_latent_t
         self.local_indices = None
         self.plan_output_dir = os.path.join(
-            self.path, f"data_plan_{self.world_size}_{self.sp_world_size}.json")
+            self.path,
+            f"data_plan_{self.world_size}_{self.sp_world_size}_{self.dp_world_size}.json"
+        )

         ranks = get_sp_group().ranks
         group_ranks: List[List] = [[] for _ in range(self.world_size)]
@@ -55,40 +61,40 @@ def __init__(self,
         # This will be useful when resume training
         if os.path.exists(self.plan_output_dir):
             print(f"Using existing plan from {self.plan_output_dir}")
-            dist.barrier()
-            return
-
-        # Find all parquet files recursively, and record num_rows for each file
-        print(f"Scanning for parquet files in {self.path}")
-        metadatas = []
-        for root, _, files in os.walk(self.path):
-            for file in sorted(files):
-                if file.endswith('.parquet'):
-                    file_path = os.path.join(root, file)
-                    num_rows = pq.ParquetFile(file_path).metadata.num_rows
-                    for row_idx in range(num_rows):
-                        metadatas.append((file_path, row_idx))
-
-        # Generate the plan that distribute rows among workers
-        random.seed(seed)
-        random.shuffle(metadatas)
-
-        # Get all sp groups
-        # e.g. if num_gpus = 4, sp_size = 2
-        # group_ranks = [(0, 1), (2, 3)]
-        # We will assign the same batches of data to ranks in the same sp group, and we'll assign different batches to ranks in different sp groups
-        # e.g. plan = {0: [row 1, row 4], 1: [row 1, row 4], 2: [row 2, row 3], 3: [row 2, row 3]}
-        group_ranks_list: List[Any] = list(
-            set(tuple(r) for r in group_ranks))
-        num_sp_groups = len(group_ranks_list)
-        plan = defaultdict(list)
-        for idx, metadata in enumerate(metadatas):
-            sp_group_idx = idx % num_sp_groups
-            for global_rank in group_ranks_list[sp_group_idx]:
-                plan[global_rank].append(metadata)
-
-        with open(self.plan_output_dir, "w") as f:
-            json.dump(plan, f)
+        else:
+            print(f"Creating new plan for {self.plan_output_dir}")
+            # Find all parquet files recursively, and record num_rows for each file
+            print(f"Scanning for parquet files in {self.path}")
+            metadatas = []
+            for root, _, files in os.walk(self.path):
+                for file in sorted(files):
+                    if file.endswith('.parquet'):
+                        file_path = os.path.join(root, file)
+                        num_rows = pq.ParquetFile(
+                            file_path).metadata.num_rows
+                        for row_idx in range(num_rows):
+                            metadatas.append((file_path, row_idx))
+
+            # Generate the plan that distribute rows among workers
+            random.seed(seed)
+            random.shuffle(metadatas)
+
+            # Get all sp groups
+            # e.g. if num_gpus = 4, sp_size = 2
+            # group_ranks = [(0, 1), (2, 3)]
+            # We will assign the same batches of data to ranks in the same sp group, and we'll assign different batches to ranks in different sp groups
+            # e.g. plan = {0: [row 1, row 4], 1: [row 1, row 4], 2: [row 2, row 3], 3: [row 2, row 3]}
+            group_ranks_list: List[Any] = list(
+                set(tuple(r) for r in group_ranks))
+            num_sp_groups = len(group_ranks_list)
+            plan = defaultdict(list)
+            for idx, metadata in enumerate(metadatas):
+                sp_group_idx = idx % num_sp_groups
+                for global_rank in group_ranks_list[sp_group_idx]:
+                    plan[global_rank].append(metadata)
+
+            with open(self.plan_output_dir, "w") as f:
+                json.dump(plan, f)
         dist.barrier()

     def __len__(self):
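For reference, a minimal standalone sketch of the plan-generation step in the hunk above. The group_ranks_list and metadatas values below are made up (4 global ranks, sp_size = 2), and the shuffle is skipped so the result is deterministic; only the round-robin assignment mirrors the new code.

from collections import defaultdict
import json

# Hypothetical setup: 4 global ranks split into two sequence-parallel groups.
group_ranks_list = [(0, 1), (2, 3)]
# Each entry is (parquet_file_path, row_index); the files are fictional.
metadatas = [("a.parquet", 0), ("a.parquet", 1),
             ("b.parquet", 0), ("b.parquet", 1)]

num_sp_groups = len(group_ranks_list)
plan = defaultdict(list)
for idx, metadata in enumerate(metadatas):
    # Rows go round-robin across sp groups; every rank inside a group
    # receives the same rows, so sequence-parallel ranks stay in sync.
    sp_group_idx = idx % num_sp_groups
    for global_rank in group_ranks_list[sp_group_idx]:
        plan[global_rank].append(metadata)

print(json.dumps(plan))
# Ranks 0 and 1 share [("a.parquet", 0), ("b.parquet", 0)];
# ranks 2 and 3 share [("a.parquet", 1), ("b.parquet", 1)].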
@@ -121,9 +127,9 @@ def __getitem__(self, idx):
         cumulative = 0
         for i in range(parquet_file.num_row_groups):
             num_rows = parquet_file.metadata.row_group(i).num_rows
-            if cumulative + num_rows > idx:
+            if cumulative + num_rows > row_idx:
                 row_group_index = i
-                local_index = idx - cumulative
+                local_index = row_idx - cumulative
                 break
             cumulative += num_rows

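The idx -> row_idx change in the last hunk matters because __getitem__ first maps the dataset index idx to a (file_path, row_idx) entry taken from the plan, so the row-group scan should use the file-local row index rather than the dataset-wide one. A small sketch of that scan; the helper name locate_row_group and the row-group sizes are illustrative, not part of the diff.

def locate_row_group(row_group_sizes, row_idx):
    # Walk the row groups, accumulating their sizes, until the group
    # containing row_idx is found; return its index and the offset of
    # the row inside that group.
    cumulative = 0
    for i, num_rows in enumerate(row_group_sizes):
        if cumulative + num_rows > row_idx:
            return i, row_idx - cumulative
        cumulative += num_rows
    raise IndexError(f"row {row_idx} out of range")

# Row groups of 100, 100 and 50 rows: global row 120 falls in
# row group 1 at local offset 20.
assert locate_row_group([100, 100, 50], 120) == (1, 20)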