Commit 4aeabbc

[Feature][Training] Add cfg rate for dataset loader (#556)
1 parent 949bb5c commit 4aeabbc

11 files changed (+26 additions, -22 deletions)

fastvideo/v1/dataset/parquet_dataset_map_style.py

Lines changed: 2 additions & 6 deletions
@@ -200,10 +200,6 @@ def __init__(
         super().__init__()
         self.path = path
         self.cfg_rate = cfg_rate
-        if cfg_rate > 0.0:
-            raise ValueError(
-                "cfg_rate > 0.0 is not supported for now because it will trigger bug when num_data_workers > 0"
-            )
         logger.info("Initializing LatentsParquetMapStyleDataset with path: %s",
                     path)
         self.parquet_files, self.lengths = get_parquet_files_and_length(path)
@@ -247,7 +243,7 @@ def get_validation_negative_prompt(
             [self.lengths[0]])

         all_latents_list, all_embs_list, all_masks_list, caption_text_list = collate_latents_embs_masks(
-            [row_dict], self.text_padding_length, self.keys)
+            [row_dict], self.text_padding_length, self.keys, cfg_rate=0.0)
         all_latents, all_embs, all_masks, caption_text = all_latents_list[
             0], all_embs_list[0], all_masks_list[0], caption_text_list[0]
         # add batch dimension
@@ -268,7 +264,7 @@ def __getitems__(self, indices: List[int]):
         ]

         all_latents, all_embs, all_masks, caption_text = collate_latents_embs_masks(
-            rows, self.text_padding_length, self.keys)
+            rows, self.text_padding_length, self.keys, self.cfg_rate)
         return all_latents, all_embs, all_masks, caption_text

     def __len__(self):
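
With the old guard removed, the dataset now accepts cfg_rate > 0.0 and forwards it to collate_latents_embs_masks for training batches in __getitems__, while get_validation_negative_prompt pins cfg_rate=0.0 so the negative-prompt embedding is never dropped. A rough construction sketch follows; only path and cfg_rate are confirmed by this diff, the other keyword arguments are assumptions and further required arguments may be omitted.

    # Sketch only: arguments other than `path` and `cfg_rate` are assumed.
    from fastvideo.v1.dataset.parquet_dataset_map_style import (
        LatentsParquetMapStyleDataset)

    dataset = LatentsParquetMapStyleDataset(
        path="data/preprocessed/train",  # parquet shards of latents + embeddings (assumed layout)
        text_padding_length=512,         # assumed keyword; matches the padded embedding length
        cfg_rate=0.1,                    # ~10% of samples get a zeroed text embedding
    )

    # __getitems__ now forwards self.cfg_rate into collate_latents_embs_masks.
    latents, embs, masks, captions = dataset.__getitems__([0, 1, 2])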

fastvideo/v1/dataset/preprocessing_datasets.py

Lines changed: 1 addition & 1 deletion
@@ -461,7 +461,7 @@ def _init_stages(self, args, transform, transform_topcrop,
         self.text_encoding_stage = TextEncodingStage(
             tokenizer=tokenizer,
             text_max_length=args.text_max_length,
-            cfg_rate=args.cfg)
+            cfg_rate=args.training_cfg_rate)

     def _load_raw_data(self) -> List[Dict]:
         """Load raw data from JSON files."""

fastvideo/v1/dataset/utils.py

Lines changed: 12 additions & 5 deletions
@@ -1,3 +1,4 @@
+import random
 from typing import Any, Dict, List

 import numpy as np
@@ -20,7 +21,7 @@ def pad(t: torch.Tensor, padding_length: int) -> torch.Tensor:
     return t[:padding_length], torch.ones(padding_length)


-def get_torch_tensors_from_row_dict(row_dict, keys) -> Dict[str, Any]:
+def get_torch_tensors_from_row_dict(row_dict, keys, cfg_rate) -> Dict[str, Any]:
     """
     Get the latents and prompts from a row dictionary.
     """
@@ -42,7 +43,10 @@ def get_torch_tensors_from_row_dict(row_dict, keys) -> Dict[str, Any]:
         bytes = row_dict[f"{key}_bytes"]

         # TODO (peiyuan): read precision
-        data = np.frombuffer(bytes, dtype=np.float32).reshape(shape).copy()
+        if key == 'text_embedding' and random.random() < cfg_rate:
+            data = np.zeros((512, 4096), dtype=np.float32)
+        else:
+            data = np.frombuffer(bytes, dtype=np.float32).reshape(shape).copy()
         data = torch.from_numpy(data)
         if len(data.shape) == 3:
             B, L, D = data.shape
@@ -53,8 +57,11 @@ def get_torch_tensors_from_row_dict(row_dict, keys) -> Dict[str, Any]:


 def collate_latents_embs_masks(
-        batch_to_process, text_padding_length,
-        keys) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[str]]:
+        batch_to_process,
+        text_padding_length,
+        keys,
+        cfg_rate=0.0
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[str]]:
     # Initialize tensors to hold padded embeddings and masks
     all_latents = []
     all_embs = []
@@ -63,7 +70,7 @@ def collate_latents_embs_masks(
     # Process each row individually
     for i, row in enumerate(batch_to_process):
         # Get tensors from row
-        data = get_torch_tensors_from_row_dict(row, keys)
+        data = get_torch_tensors_from_row_dict(row, keys, cfg_rate)
         latents, emb = data["vae_latent"], data["text_embedding"]

         padded_emb, mask = pad(emb, text_padding_length)
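
This hunk is the heart of the feature: with probability cfg_rate, the text embedding for a sample is replaced by an all-zero array (the hard-coded (512, 4096) shape presumably matches the padded T5 embedding used here), which is the usual classifier-free-guidance dropout of the conditioning signal during training. A minimal, self-contained sketch of the same idea, not FastVideo's API:

    import random

    import numpy as np

    def load_text_embedding(raw_bytes: bytes, shape, cfg_rate: float) -> np.ndarray:
        """Decode a serialized embedding, dropping it with probability cfg_rate.

        A dropped embedding becomes zeros of the same padded shape, so the model
        occasionally trains on an "unconditional" sample, which is what enables
        classifier-free guidance at inference time.
        """
        if random.random() < cfg_rate:
            return np.zeros(shape, dtype=np.float32)
        return np.frombuffer(raw_bytes, dtype=np.float32).reshape(shape).copy()

    # Example: with cfg_rate=0.1, roughly 10% of loads return a zeroed embedding.
    emb = np.random.randn(512, 4096).astype(np.float32)
    decoded = load_text_embedding(emb.tobytes(), (512, 4096), cfg_rate=0.1)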

fastvideo/v1/fastvideo_args.py

Lines changed: 2 additions & 2 deletions
@@ -384,7 +384,7 @@ class TrainingArgs(FastVideoArgs):
     # diffusion setting
     ema_decay: float = 0.0
     ema_start_step: int = 0
-    cfg: float = 0.0
+    training_cfg_rate: float = 0.0
     precondition_outputs: bool = False

     # validation & logs
@@ -528,7 +528,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                         type=int,
                         default=0,
                         help="Step to start EMA")
-    parser.add_argument("--cfg",
+    parser.add_argument("--training-cfg-rate",
                         type=float,
                         help="Classifier-free guidance scale")
     parser.add_argument(

fastvideo/v1/pipelines/preprocess/v1_preprocess.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def main(args) -> None:
                         type=str,
                         default="google/t5-v1_1-xxl")
     parser.add_argument("--cache_dir", type=str, default="./cache_dir")
-    parser.add_argument("--cfg", type=float, default=0.0)
+    parser.add_argument("--training_cfg_rate", type=float, default=0.0)
     parser.add_argument(
         "--output_dir",
         type=str,

fastvideo/v1/tests/nightly/test_e2e_overfit_single_sample.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def run_training():
         "--checkpoints_total_limit", "3",
         "--allow_tf32",
         "--ema_start_step", "0",
-        "--cfg", "0.0",
+        "--training_cfg_rate", "0.0",
         "--output_dir", LOCAL_OUTPUT_DIR,
         "--tracker_project_name", "wan_finetune_overfit_ci",
         "--num_height", "480",

fastvideo/v1/tests/training/VSA/test_training_loss_VSA.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def run_worker():
         "--checkpoints_total_limit", "3",
         "--allow_tf32",
         "--ema_start_step", "0",
-        "--cfg", "0.0",
+        "--training_cfg_rate", "0.0",
         "--output_dir", "data/wan_finetune_test_VSA",
         "--tracker_project_name", "wan_finetune_ci_VSA",
         "--wandb_run_name", wandb_name,

fastvideo/v1/tests/training/Vanilla/test_training_loss.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def run_worker():
         "--checkpoints_total_limit", "3",
         "--allow_tf32",
         "--ema_start_step", "0",
-        "--cfg", "0.0",
+        "--training_cfg_rate", "0.0",
         "--output_dir", "data/wan_finetune_test",
         "--tracker_project_name", "wan_finetune_ci",
         "--wandb_run_name", wandb_name,

fastvideo/v1/training/training_pipeline.py

Lines changed: 3 additions & 2 deletions
@@ -105,6 +105,7 @@ def initialize_training_pipeline(self, training_args: TrainingArgs):
             training_args.data_path,
             training_args.train_batch_size,
             num_data_workers=training_args.dataloader_num_workers,
+            cfg_rate=training_args.training_cfg_rate,
             drop_last=True,
             text_padding_length=training_args.pipeline_config.
             text_encoder_configs[0].arch_config.
@@ -534,9 +535,9 @@ def _log_validation(self, transformer, training_args, global_step) -> None:
             training_args.validation_preprocessed_path,
             batch_size=1,
             num_data_workers=0,
+            cfg_rate=0.0,
             drop_last=False,
-            drop_first_row=sampling_param.negative_prompt is not None,
-            cfg_rate=training_args.cfg)
+            drop_first_row=sampling_param.negative_prompt is not None)
         if sampling_param.negative_prompt:
             _, negative_prompt_embeds, negative_prompt_attention_mask, _ = validation_dataset.get_validation_negative_prompt(
             )
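
The pipeline wiring makes the split explicit: the training dataloader is built with cfg_rate=training_args.training_cfg_rate, while the validation dataset used in _log_validation is built with cfg_rate=0.0, since validation prompts (the negative prompt in particular) must never be zeroed out. A rough sketch of that pattern, where build_dataloader is a hypothetical stand-in for the actual factory used here:

    # `build_dataloader` is a hypothetical stand-in; only the cfg_rate wiring matters.
    def make_loaders(training_args, build_dataloader):
        train_loader = build_dataloader(
            training_args.data_path,
            cfg_rate=training_args.training_cfg_rate,  # conditioning dropout enabled
            drop_last=True,
        )
        validation_loader = build_dataloader(
            training_args.validation_preprocessed_path,
            cfg_rate=0.0,  # never drop validation / negative-prompt embeddings
            drop_last=False,
        )
        return train_loader, validation_loader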

scripts/finetune/finetune_v1.sh

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ torchrun --nnodes 1 --nproc_per_node $NUM_GPUS\
     --checkpoints_total_limit 3\
     --allow_tf32\
     --ema_start_step 0\
-    --cfg 0.0\
+    --training_cfg_rate 0.0\
     --output_dir="$DATA_DIR/outputs/wan_finetune"\
     --tracker_project_name wan_finetune \
     --num_height 480 \
