@@ -33,7 +33,8 @@ def __init__(self,
3333 world_size : int = 1 ,
3434 cfg_rate : float = 0.0 ,
3535 num_latent_t : int = 2 ,
36- seed : int = 0 ):
36+ seed : int = 0 ,
37+ validation : bool = False ):
3738 super ().__init__ ()
3839 self .path = str (path )
3940 self .batch_size = batch_size
@@ -47,6 +48,12 @@ def __init__(self,
4748 self .cfg_rate = cfg_rate
4849 self .num_latent_t = num_latent_t
4950 self .local_indices = None
51+ self .validation = validation
52+
53+ # Negative prompt caching
54+ self .neg_metadata = None
55+ self .cached_neg_prompt : Dict [str , Any ] | None = None
56+
5057 self .plan_output_dir = os .path .join (
5158 self .path ,
5259 f"data_plan_{ self .world_size } _{ self .sp_world_size } _{ self .dp_world_size } .json"
@@ -75,6 +82,12 @@ def __init__(self,
7582 for row_idx in range (num_rows ):
7683 metadatas .append ((file_path , row_idx ))
7784
85+ # the negative prompt is always the first row in the first
86+ # parquet file
87+ if validation :
88+ self .neg_metadata = metadatas [0 ]
89+ metadatas = metadatas [1 :]
90+
7891 # Generate the plan that distribute rows among workers
7992 random .seed (seed )
8093 random .shuffle (metadatas )
@@ -93,9 +106,88 @@ def __init__(self,
93106 for global_rank in group_ranks_list [sp_group_idx ]:
94107 plan [global_rank ].append (metadata )
95108
109+ if validation :
110+ assert self .neg_metadata is not None
111+ plan ["negative_prompt" ] = [self .neg_metadata ]
96112 with open (self .plan_output_dir , "w" ) as f :
97113 json .dump (plan , f )
114+ else :
115+ pass
116+
98117 dist .barrier ()
118+ if validation :
119+ with open (self .plan_output_dir ) as f :
120+ plan = json .load (f )
121+ self .neg_metadata = plan ["negative_prompt" ][0 ]
122+
123+ def _load_and_cache_negative_prompt (self ) -> None :
124+ """Load and cache the negative prompt. Only rank 0 in each SP group should call this."""
125+ if not self .validation or self .neg_metadata is None :
126+ return
127+
128+ if self .cached_neg_prompt is not None :
129+ return
130+
131+ # Only rank 0 in each SP group should read the negative prompt
132+ try :
133+ file_path , row_idx = self .neg_metadata
134+ parquet_file = pq .ParquetFile (file_path )
135+
136+ # Since negative prompt is always the first row (row_idx = 0),
137+ # it's always in the first row group
138+ row_group_index = 0
139+ local_index = row_idx # This will be 0 for the negative prompt
140+
141+ row_group = parquet_file .read_row_group (row_group_index ).to_pydict ()
142+ row_dict = {k : v [local_index ] for k , v in row_group .items ()}
143+ del row_group
144+
145+ # Process the negative prompt row
146+ self .cached_neg_prompt = self ._process_row (row_dict )
147+
148+ except Exception as e :
149+ logger .error ("Failed to load negative prompt: %s" , e )
150+ self .cached_neg_prompt = None
151+
152+ def get_validation_negative_prompt (
153+ self
154+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor , Dict [str , Any ]]:
155+ """
156+ Get the negative prompt for validation.
157+ This method ensures the negative prompt is loaded and cached properly.
158+ Returns the processed negative prompt data (latents, embeddings, masks, info).
159+ """
160+ if not self .validation :
161+ raise ValueError (
162+ "get_validation_negative_prompt() can only be called in validation mode"
163+ )
164+
165+ # Load and cache if needed (only rank 0 in SP group will actually load)
166+ if self .cached_neg_prompt is None :
167+ self ._load_and_cache_negative_prompt ()
168+
169+ if self .cached_neg_prompt is None :
170+ raise RuntimeError (
171+ f"Rank { self .rank } (SP rank { self .local_rank } ): Could not retrieve negative prompt data"
172+ )
173+
174+ # Extract the components
175+ lat , emb , mask , info = (self .cached_neg_prompt ["latents" ],
176+ self .cached_neg_prompt ["embeddings" ],
177+ self .cached_neg_prompt ["masks" ],
178+ self .cached_neg_prompt ["info" ])
179+
180+ # Apply the same processing as in __getitem__
181+ if lat .numel () == 0 : # Validation parquet
182+ return lat , emb , mask , info
183+ else :
184+ lat = lat [:, - self .num_latent_t :]
185+ if self .sp_world_size > 1 :
186+ lat = rearrange (lat ,
187+ "t (n s) h w -> t n s h w" ,
188+ n = self .sp_world_size ).contiguous ()
189+ lat = lat [:, self .local_rank , :, :, :]
190+ return lat , emb , mask , info
99191
100192 def __len__ (self ):
101193 if self .local_indices is None :
0 commit comments