Skip to content

Commit 3ab6470

Browse files
authored
[Feat][Preprocess] support merged dataset (#752)
1 parent 989a035 commit 3ab6470

File tree

5 files changed

+76
-9
lines changed

5 files changed

+76
-9
lines changed

examples/training/finetune/wan_i2v_14B_480p/crush_smol/preprocess_wan_data_i2v_new.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
GPU_NUM=1 # 2,4,8
44
MODEL_PATH="Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
5-
# MODEL_PATH="/home/eigensystem/.cache/huggingface/hub/models--Wan-AI--Wan2.1-I2V-14B-480P-Diffusers/snapshots/b184e23a8a16b20f108f727c902e769e873ffc73/"
6-
DATASET_PATH="data/crush-smol-test/"
5+
DATASET_PATH="data/crush-smol/"
76
OUTPUT_DIR="data/crush-smol_processed_i2v/"
87

98
torchrun --nproc_per_node=$GPU_NUM \
109
-m fastvideo.pipelines.preprocess.v1_preprocessing_new \
1110
--model_path $MODEL_PATH \
1211
--mode preprocess \
1312
--workload_type i2v \
13+
--preprocess.dataset_type merged \
1414
--preprocess.dataset_path $DATASET_PATH \
1515
--preprocess.dataset_output_dir $OUTPUT_DIR \
1616
--preprocess.preprocess_video_batch_size 2 \

examples/training/finetune/wan_t2v_1.3B/crush_smol/preprocess_wan_data_t2v_new.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
GPU_NUM=1 # 2,4,8
44
MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
5-
DATASET_PATH="data/crush-smol-test/"
5+
DATASET_PATH="data/crush-smol/"
66
OUTPUT_DIR="data/crush-smol_processed_t2v/"
77

88
torchrun --nproc_per_node=$GPU_NUM \
99
-m fastvideo.pipelines.preprocess.v1_preprocessing_new \
1010
--model_path $MODEL_PATH \
1111
--mode preprocess \
1212
--workload_type t2v \
13+
--preprocess.dataset_type merged \
1314
--preprocess.dataset_path $DATASET_PATH \
1415
--preprocess.dataset_output_dir $OUTPUT_DIR \
1516
--preprocess.preprocess_video_batch_size 2 \

fastvideo/configs/configs.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,42 @@
11
import dataclasses
2+
from enum import Enum
23
from typing import Any, Optional
34

45
from fastvideo.configs.utils import update_config_from_args
56
from fastvideo.utils import FlexibleArgumentParser, StoreBoolean
67

78

9+
class DatasetType(str, Enum):
10+
"""
11+
Enumeration for different dataset types.
12+
"""
13+
HF = "hf"
14+
MERGED = "merged"
15+
16+
@classmethod
17+
def from_string(cls, value: str) -> "DatasetType":
18+
"""Convert string to DatasetType enum."""
19+
try:
20+
return cls(value.lower())
21+
except ValueError:
22+
raise ValueError(
23+
f"Invalid dataset type: {value}. Must be one of: {', '.join([m.value for m in cls])}"
24+
) from None
25+
26+
@classmethod
27+
def choices(cls) -> list[str]:
28+
"""Get all available choices as strings for argparse."""
29+
return [dataset_type.value for dataset_type in cls]
30+
31+
832
@dataclasses.dataclass
933
class PreprocessConfig:
1034
"""Configuration for preprocessing operations."""
1135

1236
# Model and dataset configuration
1337
model_path: str = ""
1438
dataset_path: str = ""
39+
dataset_type: DatasetType = DatasetType.HF
1540
dataset_output_dir: str = "./output"
1641

1742
# Dataloader configuration
@@ -54,6 +79,12 @@ def add_cli_args(parser: FlexibleArgumentParser,
5479
type=str,
5580
default=PreprocessConfig.dataset_path,
5681
help="Path to the dataset directory for preprocessing")
82+
preprocess_args.add_argument(
83+
f"--{prefix_with_dot}dataset-type",
84+
type=str,
85+
choices=DatasetType.choices(),
86+
default=PreprocessConfig.dataset_type.value,
87+
help="Type of the dataset")
5788
preprocess_args.add_argument(
5889
f"--{prefix_with_dot}dataset-output-dir",
5990
type=str,
@@ -136,6 +167,10 @@ def add_cli_args(parser: FlexibleArgumentParser,
136167
def from_kwargs(cls, kwargs: dict[str,
137168
Any]) -> Optional["PreprocessConfig"]:
138169
"""Create PreprocessConfig from keyword arguments."""
170+
if 'dataset_type' in kwargs and isinstance(kwargs['dataset_type'], str):
171+
kwargs['dataset_type'] = DatasetType.from_string(
172+
kwargs['dataset_type'])
173+
139174
preprocess_config = cls()
140175
if not update_config_from_args(
141176
preprocess_config, kwargs, prefix="preprocess", pop_args=True):

fastvideo/workflow/preprocess/components.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
import pyarrow as pa
1212
import pyarrow.parquet as pq
1313
import torch
14+
from datasets import Dataset, Video, load_dataset
1415

16+
from fastvideo.configs.configs import DatasetType, PreprocessConfig
1517
from fastvideo.logger import init_logger
1618
from fastvideo.pipelines.pipeline_batch_info import PreprocessBatch
1719

@@ -395,3 +397,33 @@ def _default_file_writer_fn(self, args_tuple: tuple) -> int:
395397
written_count += len(chunk_table)
396398

397399
return written_count
400+
401+
402+
def build_dataset(preprocess_config: PreprocessConfig, split: str) -> Dataset:
403+
if preprocess_config.dataset_type == DatasetType.HF:
404+
dataset = load_dataset(preprocess_config.dataset_path, split=split)
405+
elif preprocess_config.dataset_type == DatasetType.MERGED:
406+
metadata_json_path = os.path.join(preprocess_config.dataset_path,
407+
"videos2caption.json")
408+
video_folder = os.path.join(preprocess_config.dataset_path, "videos")
409+
dataset = load_dataset("json",
410+
data_files=metadata_json_path,
411+
split=split)
412+
column_names = dataset.column_names
413+
# rename columns to match the schema
414+
if "cap" in column_names:
415+
dataset = dataset.rename_column("cap", "caption")
416+
if "path" in column_names:
417+
dataset = dataset.rename_column("path", "name")
418+
# add video column
419+
def add_video_column(item: dict[str, Any]) -> dict[str, Any]:
420+
item["video"] = os.path.join(video_folder, item["name"])
421+
return item
422+
423+
dataset = dataset.map(add_video_column)
424+
dataset = dataset.cast_column("video", Video())
425+
else:
426+
raise ValueError(
427+
f"Invalid dataset type: {preprocess_config.dataset_type}")
428+
429+
return dataset

fastvideo/workflow/preprocess/preprocess_workflow.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
from typing import cast
33

4-
from datasets import load_dataset
54
from torch.utils.data import DataLoader
65

76
from fastvideo.configs.configs import PreprocessConfig
@@ -11,7 +10,8 @@
1110
from fastvideo.logger import init_logger
1211
from fastvideo.pipelines.pipeline_registry import PipelineType
1312
from fastvideo.workflow.preprocess.components import (
14-
ParquetDatasetSaver, PreprocessingDataValidator, VideoForwardBatchBuilder)
13+
ParquetDatasetSaver, PreprocessingDataValidator, VideoForwardBatchBuilder,
14+
build_dataset)
1515
from fastvideo.workflow.preprocess.record_schema import (
1616
basic_t2v_record_creator, i2v_record_creator)
1717
from fastvideo.workflow.workflow_base import WorkflowBase
@@ -43,8 +43,7 @@ def register_components(self) -> None:
4343
self.add_component("raw_data_validator", raw_data_validator)
4444

4545
# training dataset
46-
training_dataset = load_dataset(preprocess_config.dataset_path,
47-
split="train")
46+
training_dataset = build_dataset(preprocess_config, split="train")
4847
# set load_from_cache_file to False to check filter stats
4948
training_dataset = training_dataset.filter(raw_data_validator)
5049
# we do not use collate_fn here because we use iterable-style Dataset
@@ -59,8 +58,8 @@ def register_components(self) -> None:
5958

6059
# try to load validation dataset if it exists
6160
try:
62-
validation_dataset = load_dataset(preprocess_config.dataset_path,
63-
split="validation")
61+
validation_dataset = build_dataset(preprocess_config,
62+
split="validation")
6463
validation_dataset = validation_dataset.filter(raw_data_validator)
6564
validation_dataloader = DataLoader(
6665
validation_dataset,

0 commit comments

Comments
 (0)