Skip to content

Commit d30879a

Browse files
EigensystemJiayiZhangA
authored and committed
[Feat][Preprocess] support merged dataset (#752)
1 parent 766b931 commit d30879a

File tree

3 files changed

+44
-0
lines changed

3 files changed

+44
-0
lines changed

fastvideo/configs/configs.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,29 @@
99
logger = init_logger(__name__)
1010

1111

12+
class DatasetType(str, Enum):
    """
    Enumeration for different dataset types.
    """

    # Hugging Face-style dataset loaded via `load_dataset`
    HF = "hf"
    # Merged local dataset (presumably with videos2caption.json metadata — see components.py)
    MERGED = "merged"

    @classmethod
    def from_string(cls, value: str) -> "DatasetType":
        """Convert string to DatasetType enum."""
        normalized = value.lower()
        try:
            return cls(normalized)
        except ValueError:
            # Re-raise with the list of valid values; `from None` hides the
            # internal enum-lookup traceback from the user-facing error.
            valid_values = ", ".join(member.value for member in cls)
            raise ValueError(
                f"Invalid dataset type: {value}. Must be one of: {valid_values}"
            ) from None

    @classmethod
    def choices(cls) -> list[str]:
        """Get all available choices as strings for argparse."""
        return [member.value for member in cls]
33+
34+
1235
class DatasetType(str, Enum):
1336
"""
1437
Enumeration for different dataset types.

fastvideo/workflow/preprocess/components.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
from datasets import Dataset, Video, load_dataset
1515

1616
from fastvideo.configs.configs import DatasetType, PreprocessConfig
17+
<<<<<<< HEAD
1718
from fastvideo.distributed.parallel_state import get_world_rank, get_world_size
19+
=======
20+
>>>>>>> 15df36ab ([Feat][Preprocess] support merged dataset (#752))
1821
from fastvideo.logger import init_logger
1922
from fastvideo.pipelines.pipeline_batch_info import PreprocessBatch
2023

@@ -402,13 +405,19 @@ def _default_file_writer_fn(self, args_tuple: tuple) -> int:
402405
return written_count
403406

404407

408+
<<<<<<< HEAD
405409
def build_dataset(preprocess_config: PreprocessConfig, split: str,
406410
validator: Callable[[dict[str, Any]], bool]) -> Dataset:
407411
if preprocess_config.dataset_type == DatasetType.HF:
408412
dataset = load_dataset(preprocess_config.dataset_path, split=split)
409413
dataset = dataset.filter(validator)
410414
dataset = dataset.shard(num_shards=get_world_size(),
411415
index=get_world_rank())
416+
=======
417+
def build_dataset(preprocess_config: PreprocessConfig, split: str) -> Dataset:
418+
if preprocess_config.dataset_type == DatasetType.HF:
419+
dataset = load_dataset(preprocess_config.dataset_path, split=split)
420+
>>>>>>> 15df36ab ([Feat][Preprocess] support merged dataset (#752))
412421
elif preprocess_config.dataset_type == DatasetType.MERGED:
413422
metadata_json_path = os.path.join(preprocess_config.dataset_path,
414423
"videos2caption.json")
@@ -422,11 +431,14 @@ def build_dataset(preprocess_config: PreprocessConfig, split: str,
422431
dataset = dataset.rename_column("cap", "caption")
423432
if "path" in column_names:
424433
dataset = dataset.rename_column("path", "name")
434+
<<<<<<< HEAD
425435

426436
dataset = dataset.filter(validator)
427437
dataset = dataset.shard(num_shards=get_world_size(),
428438
index=get_world_rank())
429439

440+
=======
441+
>>>>>>> 15df36ab ([Feat][Preprocess] support merged dataset (#752))
430442
# add video column
431443
def add_video_column(item: dict[str, Any]) -> dict[str, Any]:
432444
item["video"] = os.path.join(video_folder, item["name"])

fastvideo/workflow/preprocess/preprocess_workflow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,13 @@ def register_components(self) -> None:
4444
self.add_component("raw_data_validator", raw_data_validator)
4545

4646
# training dataset
47+
<<<<<<< HEAD
4748
training_dataset = build_dataset(preprocess_config,
4849
split="train",
4950
validator=raw_data_validator)
51+
=======
52+
training_dataset = build_dataset(preprocess_config, split="train")
53+
>>>>>>> 15df36ab ([Feat][Preprocess] support merged dataset (#752))
5054
# set load_from_cache_file to False to check filter stats
5155
training_dataset = training_dataset.filter(raw_data_validator)
5256
# we do not use collate_fn here because we use iterable-style Dataset
@@ -62,8 +66,13 @@ def register_components(self) -> None:
6266
# try to load validation dataset if it exists
6367
try:
6468
validation_dataset = build_dataset(preprocess_config,
69+
<<<<<<< HEAD
6570
split="validation",
6671
validator=raw_data_validator)
72+
=======
73+
split="validation")
74+
validation_dataset = validation_dataset.filter(raw_data_validator)
75+
>>>>>>> 15df36ab ([Feat][Preprocess] support merged dataset (#752))
6776
validation_dataloader = DataLoader(
6877
validation_dataset,
6978
batch_size=preprocess_config.preprocess_video_batch_size,

0 commit comments

Comments (0)