
Commit 8f8ce6d

[bugfix] [training] Add negative prompt to preprocessing and validation (#479)
Co-authored-by: Will Lin <[email protected]>
Parent: e3d0cbe

9 files changed: +127 -13 lines changed

fastvideo/v1/dataset/parquet_datasets.py

Lines changed: 93 additions & 1 deletion

```diff
@@ -33,7 +33,8 @@ def __init__(self,
                  world_size: int = 1,
                  cfg_rate: float = 0.0,
                  num_latent_t: int = 2,
-                 seed: int = 0):
+                 seed: int = 0,
+                 validation: bool = False):
         super().__init__()
         self.path = str(path)
         self.batch_size = batch_size
@@ -47,6 +48,12 @@ def __init__(self,
         self.cfg_rate = cfg_rate
         self.num_latent_t = num_latent_t
         self.local_indices = None
+        self.validation = validation
+
+        # Negative prompt caching
+        self.neg_metadata = None
+        self.cached_neg_prompt: Dict[str, Any] | None = None
+
         self.plan_output_dir = os.path.join(
             self.path,
             f"data_plan_{self.world_size}_{self.sp_world_size}_{self.dp_world_size}.json"
@@ -75,6 +82,12 @@ def __init__(self,
             for row_idx in range(num_rows):
                 metadatas.append((file_path, row_idx))

+        # The negative prompt is always the first row in the first
+        # parquet file.
+        if validation:
+            self.neg_metadata = metadatas[0]
+            metadatas = metadatas[1:]
+
         # Generate the plan that distributes rows among workers
         random.seed(seed)
         random.shuffle(metadatas)
@@ -93,9 +106,88 @@ def __init__(self,
                 for global_rank in group_ranks_list[sp_group_idx]:
                     plan[global_rank].append(metadata)

+            if validation:
+                assert self.neg_metadata is not None
+                plan["negative_prompt"] = [self.neg_metadata]
             with open(self.plan_output_dir, "w") as f:
                 json.dump(plan, f)
+        else:
+            pass
+
         dist.barrier()
+        if validation:
+            with open(self.plan_output_dir) as f:
+                plan = json.load(f)
+            self.neg_metadata = plan["negative_prompt"][0]
+
+    def _load_and_cache_negative_prompt(self) -> None:
+        """Load and cache the negative prompt. Only rank 0 in each SP group should call this."""
+        if not self.validation or self.neg_metadata is None:
+            return
+
+        if self.cached_neg_prompt is not None:
+            return
+
+        # Only rank 0 in each SP group should read the negative prompt
+        try:
+            file_path, row_idx = self.neg_metadata
+            parquet_file = pq.ParquetFile(file_path)
+
+            # Since the negative prompt is always the first row (row_idx = 0),
+            # it is always in the first row group.
+            row_group_index = 0
+            local_index = row_idx  # This will be 0 for the negative prompt
+
+            row_group = parquet_file.read_row_group(row_group_index).to_pydict()
+            row_dict = {k: v[local_index] for k, v in row_group.items()}
+            del row_group
+
+            # Process the negative prompt row
+            self.cached_neg_prompt = self._process_row(row_dict)
+
+        except Exception as e:
+            logger.error("Failed to load negative prompt: %s", e)
+            self.cached_neg_prompt = None
+
+    def get_validation_negative_prompt(
+        self
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
+        """
+        Get the negative prompt for validation.
+
+        This method ensures the negative prompt is loaded and cached properly,
+        then returns the processed data (latents, embeddings, masks, info).
+        """
+        if not self.validation:
+            raise ValueError(
+                "get_validation_negative_prompt() can only be called in validation mode"
+            )
+
+        # Load and cache if needed (only rank 0 in each SP group will actually load)
+        if self.cached_neg_prompt is None:
+            self._load_and_cache_negative_prompt()
+
+        if self.cached_neg_prompt is None:
+            raise RuntimeError(
+                f"Rank {self.rank} (SP rank {self.local_rank}): could not retrieve negative prompt data"
+            )
+
+        # Extract the components
+        lat, emb, mask, info = (self.cached_neg_prompt["latents"],
+                                self.cached_neg_prompt["embeddings"],
+                                self.cached_neg_prompt["masks"],
+                                self.cached_neg_prompt["info"])

+        # Apply the same processing as in __getitem__
+        if lat.numel() == 0:  # Text-only validation parquet: no latents to slice
+            return lat, emb, mask, info
+        else:
+            lat = lat[:, -self.num_latent_t:]
+            if self.sp_world_size > 1:
+                lat = rearrange(lat,
+                                "t (n s) h w -> t n s h w",
+                                n=self.sp_world_size).contiguous()
+                lat = lat[:, self.local_rank, :, :, :]
+            return lat, emb, mask, info

     def __len__(self):
         if self.local_indices is None:
```
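For reference, a minimal sketch of how the new `validation` flag and `get_validation_negative_prompt()` fit together. The dataset class name and the parquet path below are assumptions (the constructor keywords are taken from the `training_pipeline.py` hunk later in this commit):

```python
# Hedged sketch -- class name and path are hypothetical.
from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset  # assumed name

dataset = ParquetVideoTextDataset(
    "data/crush-smol/latents/validation",  # hypothetical parquet directory
    batch_size=1,
    rank=0,
    world_size=1,
    cfg_rate=0.0,
    num_latent_t=2,
    seed=0,
    validation=True)  # row 0 of the first parquet file is the negative prompt

# Returns (latents, embeddings, masks, info); for text-only validation
# parquet files the latents tensor is empty (numel() == 0).
lat, neg_embeds, neg_mask, info = dataset.get_validation_negative_prompt()
```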

fastvideo/v1/pipelines/composed_pipeline_base.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -161,6 +161,7 @@ def from_pretrained(cls,
         for key, value in config_args.items():
             setattr(fastvideo_args, key, value)

+        fastvideo_args.num_gpus = int(os.environ.get("WORLD_SIZE", 1))
         fastvideo_args.use_cpu_offload = False
         # make sure we are in training mode
         fastvideo_args.inference_mode = False
```
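`WORLD_SIZE` is the standard environment variable that `torchrun` exports to every worker process, so this line sizes the run without a dedicated CLI flag. A minimal sketch of the fallback behavior:

```python
import os

# Under `torchrun --nproc_per_node=4 train.py` each worker sees WORLD_SIZE=4;
# launched directly with `python train.py`, the variable is unset and the
# default of 1 applies.
num_gpus = int(os.environ.get("WORLD_SIZE", 1))
```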

fastvideo/v1/pipelines/pipeline_batch_info.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -7,7 +7,8 @@
 in a functional manner, reducing the need for explicit parameter passing.
 """

-from dataclasses import dataclass, field
+import pprint
+from dataclasses import asdict, dataclass, field
 from typing import Any, Dict, List, Optional, Union

 import torch
@@ -126,4 +127,8 @@ def __post_init__(self):
         # Set do_classifier_free_guidance based on guidance scale and negative prompt
         if self.guidance_scale > 1.0:
             self.do_classifier_free_guidance = True
+        if self.negative_prompt_embeds is None:
             self.negative_prompt_embeds = []
+
+    def __str__(self):
+        return pprint.pformat(asdict(self), indent=2, width=120)
```
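The new `__str__` makes batches readable in logs: `asdict()` recursively converts the dataclass (and any nested dataclasses) to plain dicts, and `pprint.pformat` wraps the result at 120 columns. A self-contained sketch with a stand-in dataclass (the real `ForwardBatch` has many more fields):

```python
import pprint
from dataclasses import asdict, dataclass, field
from typing import Any, Dict

@dataclass
class MiniBatch:  # stand-in; not the real ForwardBatch
    guidance_scale: float = 1.0
    do_classifier_free_guidance: bool = False
    extra: Dict[str, Any] = field(default_factory=dict)

    def __str__(self):
        # Same formatting as the diff above: dict view, 2-space indent,
        # 120-column wrapping.
        return pprint.pformat(asdict(self), indent=2, width=120)

print(MiniBatch(guidance_scale=5.0))
# {'do_classifier_free_guidance': False, 'extra': {}, 'guidance_scale': 5.0}
```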

fastvideo/v1/pipelines/preprocess_pipeline_base.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -12,6 +12,7 @@
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm

+from fastvideo.v1.configs.sample import SamplingParam
 from fastvideo.v1.dataset import getdataset
 from fastvideo.v1.fastvideo_args import FastVideoArgs
 from fastvideo.v1.logger import init_logger
@@ -300,7 +301,10 @@ def preprocess_validation_text(self, fastvideo_args: FastVideoArgs, args):

         # Prepare batch data for Parquet dataset
         batch_data = []
-
+        sampling_param = SamplingParam.from_pretrained(
+            fastvideo_args.model_path)
+        if sampling_param.negative_prompt:
+            prompts = [sampling_param.negative_prompt] + prompts
         # Add progress bar for validation text preprocessing
         pbar = tqdm(enumerate(prompts),
                     desc="Processing validation prompts",
```
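Prepending the negative prompt here is what establishes the invariant the dataset loader relies on: after preprocessing writes the prompts in order, the negative prompt sits in row 0 of the first parquet file, exactly where `parquet_datasets.py` reads it back in validation mode. A sketch of the ordering (prompt strings are illustrative):

```python
# Illustrative values only -- the real negative prompt comes from
# SamplingParam.from_pretrained(model_path).
negative_prompt = "low quality, blurry, distorted"
prompts = ["a cat surfing a wave", "a red car in the rain"]

prompts = [negative_prompt] + prompts
# Row 0 is now the negative prompt; validation rows follow. This matches
# the `metadatas[0]` lookup in parquet_datasets.py above.
assert prompts[0] == negative_prompt
```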

fastvideo/v1/pipelines/stages/conditioning.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -36,6 +36,7 @@ def forward(
         Returns:
             The batch with applied conditioning.
         """
+        # TODO!!
         if not batch.do_classifier_free_guidance:
             return batch
         else:
```

fastvideo/v1/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -30,6 +30,7 @@ class WanPipeline(LoRAPipeline, ComposedPipelineBase):
     ]

     def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
+        # We use the FlowUniPCMultistepScheduler from the official Wan2.1 repo, not the one in diffusers.
         self.modules["scheduler"] = FlowUniPCMultistepScheduler(
             shift=fastvideo_args.flow_shift)
```

fastvideo/v1/training/training_pipeline.py

Lines changed: 16 additions & 5 deletions

```diff
@@ -178,7 +178,11 @@ def _log_validation(self, transformer, training_args, global_step) -> None:
             rank=self.rank,
             world_size=self.world_size,
             cfg_rate=training_args.cfg,
-            num_latent_t=training_args.num_latent_t)
+            num_latent_t=training_args.num_latent_t,
+            validation=True)
+        if sampling_param.negative_prompt:
+            _, negative_prompt_embeds, negative_prompt_attention_mask, _ = validation_dataset.get_validation_negative_prompt(
+            )

         validation_dataloader = StatefulDataLoader(
             validation_dataset,
@@ -194,6 +198,7 @@ def _log_validation(self, transformer, training_args, global_step) -> None:

         # Add the transformer to the validation pipeline
         self.validation_pipeline.add_module("transformer", transformer)
+        # TODO(Peiyuan): this logic should live inside add_module
         self.validation_pipeline.latent_preparation_stage.transformer = transformer  # type: ignore[attr-defined]
         self.validation_pipeline.denoising_stage.transformer = transformer  # type: ignore[attr-defined]

@@ -221,24 +226,30 @@ def _log_validation(self, transformer, training_args, global_step) -> None:
             data_type="video",
             latents=None,
             seed=validation_seed,  # Use deterministic seed
+            generator=torch.Generator(
+                device="cpu").manual_seed(validation_seed),
             prompt_embeds=[prompt_embeds],
             prompt_attention_mask=[prompt_attention_mask],
+            negative_prompt_embeds=[negative_prompt_embeds],
+            negative_attention_mask=[negative_prompt_attention_mask],
             # make sure we use the same height, width, and num_frames as the training pipeline
             height=training_args.num_height,
             width=training_args.num_width,
             num_frames=num_frames,
+            # TODO(will): validation_sampling_steps and
+            # validation_guidance_scale are actually passed in as a list of
+            # values, like "10,20,30". The validation should be run for each
+            # combination of values.
             # num_inference_steps=fastvideo_args.validation_sampling_steps,
             num_inference_steps=sampling_param.num_inference_steps,
             # guidance_scale=fastvideo_args.validation_guidance_scale,
-            guidance_scale=1,
+            guidance_scale=sampling_param.guidance_scale,
             n_tokens=n_tokens,
-            do_classifier_free_guidance=False,
             eta=0.0,
         )

         # Run validation inference
-        with torch.inference_mode(), torch.autocast("cuda",
-                                                    dtype=torch.bfloat16):
+        with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
             output_batch = self.validation_pipeline.forward(
                 batch, training_args)
             samples = output_batch.output
```
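With the hard-coded `guidance_scale=1` and `do_classifier_free_guidance=False` removed, CFG is now driven entirely by the model's `SamplingParam`: per the `pipeline_batch_info.py` change above, any `guidance_scale > 1.0` switches classifier-free guidance on in `__post_init__`, which is why the negative prompt embeddings fetched earlier must be passed into the batch. A self-contained sketch of that trigger:

```python
from dataclasses import dataclass
from typing import Any, List, Optional

@dataclass
class MiniBatch:  # stand-in mirroring the ForwardBatch.__post_init__ logic
    guidance_scale: float = 1.0
    negative_prompt_embeds: Optional[List[Any]] = None
    do_classifier_free_guidance: bool = False

    def __post_init__(self):
        # CFG turns on automatically once guidance_scale exceeds 1.0, so the
        # caller must supply negative prompt embeddings.
        if self.guidance_scale > 1.0:
            self.do_classifier_free_guidance = True
        if self.negative_prompt_embeds is None:
            self.negative_prompt_embeds = []

assert MiniBatch(guidance_scale=5.0).do_classifier_free_guidance
assert not MiniBatch(guidance_scale=1.0).do_classifier_free_guidance
```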

fastvideo/v1/training/wan_training_pipeline.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -198,8 +198,7 @@ def forward(
         torch.manual_seed(seed)
         torch.cuda.manual_seed_all(seed)

-        noise_random_generator = torch.Generator(device="cpu")
-        noise_random_generator.manual_seed(seed)
+        noise_random_generator = torch.Generator(device="cpu").manual_seed(seed)

         logger.info("Initialized random seeds with seed: %s", seed)

@@ -271,7 +270,7 @@ def forward(
         gpu_memory_usage = torch.cuda.memory_allocated() / 1024**2
         logger.info("GPU memory usage before train_one_step: %s MB",
                     gpu_memory_usage)
-
+        self._log_validation(self.transformer, self.training_args, 1)
         for step in range(self.init_steps + 1,
                           self.training_args.max_train_steps + 1):
             start_time = time.perf_counter()
```
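The two-line generator setup collapses into one because `torch.Generator.manual_seed()` returns the generator itself; the added `_log_validation(...)` call appears to run one validation pass before the training loop starts. A minimal sketch of the chaining:

```python
import torch

# manual_seed() returns the Generator, so construction and seeding chain.
g = torch.Generator(device="cpu").manual_seed(42)
noise = torch.randn(2, 3, generator=g)
```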

scripts/preprocess/preprocess_wan_data_t2v.sh

Lines changed: 2 additions & 2 deletions

```diff
@@ -2,8 +2,8 @@
 GPU_NUM=1 # 2,4,8
 MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
 MODEL_TYPE="wan"
-DATA_MERGE_PATH="your/path/to/Mixkit-Src/merge.txt"
-OUTPUT_DIR="your/path"
+DATA_MERGE_PATH="data/crush-smol/merge.txt"
+OUTPUT_DIR="data/crush-smol/latents"
 VALIDATION_PATH="assets/prompt.txt"

 torchrun --nproc_per_node=$GPU_NUM \
```
