Commit e55fa6e
[Preprocess] I2V dataset (#473)
1 parent 61b6dde commit e55fa6e

11 files changed: +534 −239 lines changed

fastvideo/data_preprocess/preprocess.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -11,7 +11,8 @@
 from fastvideo.v1.fastvideo_args import FastVideoArgs
 from fastvideo.v1.configs.models.vaes import WanVAEConfig
 from fastvideo import PipelineConfig
-from fastvideo.v1.pipelines.preprocess_pipeline import PreprocessPipeline
+from fastvideo.v1.pipelines.preprocess.preprocess_pipeline_i2v import PreprocessPipeline_I2V
+from fastvideo.v1.pipelines.preprocess.preprocess_pipeline_t2v import PreprocessPipeline_T2V

 logger = init_logger(__name__)

@@ -42,7 +43,7 @@ def main(args):
     )
     fastvideo_args.check_fastvideo_args()
     fastvideo_args.device = torch.device(f"cuda:{local_rank}")
-
+    PreprocessPipeline = PreprocessPipeline_I2V if args.preprocess_task == "i2v" else PreprocessPipeline_T2V
     pipeline = PreprocessPipeline(args.model_path, fastvideo_args)
     pipeline.forward(batch=None, fastvideo_args=fastvideo_args, args=args)

@@ -91,6 +92,7 @@ def main(args):
     parser.add_argument("--group_frame", action="store_true")  # TODO
     parser.add_argument("--group_resolution", action="store_true")  # TODO
     parser.add_argument("--dataset", default="t2v")
+    parser.add_argument("--preprocess_task", type=str, default="t2v")
    parser.add_argument("--train_fps", type=int, default=30)
     parser.add_argument("--use_image_num", type=int, default=0)
     parser.add_argument("--text_max_length", type=int, default=256)
```

fastvideo/v1/dataset/dataloader/schema.py

Lines changed: 39 additions & 1 deletion

```diff
@@ -9,7 +9,7 @@

 import pyarrow as pa

-pyarrow_schema = pa.schema([
+pyarrow_schema_i2v = pa.schema([
     pa.field("id", pa.string()),
     # --- Image/Video VAE latents ---
     # Tensors are stored as raw bytes with shape and dtype info for loading
@@ -30,6 +30,10 @@
     pa.field("text_attention_mask_shape", pa.list_(pa.int64())),
     # e.g., 'bool' or 'int8'
     pa.field("text_attention_mask_dtype", pa.string()),
+    # --- I2V ---
+    pa.field("clip_feature_bytes", pa.binary()),
+    pa.field("clip_feature_shape", pa.list_(pa.int64())),
+    pa.field("clip_feature_dtype", pa.string()),
     # --- Metadata ---
     pa.field("file_name", pa.string()),
     pa.field("caption", pa.string()),
@@ -42,3 +46,37 @@
     pa.field("duration_sec", pa.float64()),
     pa.field("fps", pa.float64()),
 ])
+
+pyarrow_schema_t2v = pa.schema([
+    pa.field("id", pa.string()),
+    # --- Image/Video VAE latents ---
+    # Tensors are stored as raw bytes with shape and dtype info for loading
+    pa.field("vae_latent_bytes", pa.binary()),
+    # e.g., [C, T, H, W] or [C, H, W]
+    pa.field("vae_latent_shape", pa.list_(pa.int64())),
+    # e.g., 'float32'
+    pa.field("vae_latent_dtype", pa.string()),
+    # --- Text encoder output tensor ---
+    # Tensors are stored as raw bytes with shape and dtype info for loading
+    pa.field("text_embedding_bytes", pa.binary()),
+    # e.g., [SeqLen, Dim]
+    pa.field("text_embedding_shape", pa.list_(pa.int64())),
+    # e.g., 'bfloat16' or 'float32'
+    pa.field("text_embedding_dtype", pa.string()),
+    pa.field("text_attention_mask_bytes", pa.binary()),
+    # e.g., [SeqLen]
+    pa.field("text_attention_mask_shape", pa.list_(pa.int64())),
+    # e.g., 'bool' or 'int8'
+    pa.field("text_attention_mask_dtype", pa.string()),
+    # --- Metadata ---
+    pa.field("file_name", pa.string()),
+    pa.field("caption", pa.string()),
+    pa.field("media_type", pa.string()),  # 'image' or 'video'
+    pa.field("width", pa.int64()),
+    pa.field("height", pa.int64()),
+    # --- Video-specific (can be null/default for images) ---
+    # Number of frames processed (e.g., 1 for image, N for video)
+    pa.field("num_frames", pa.int64()),
+    pa.field("duration_sec", pa.float64()),
+    pa.field("fps", pa.float64()),
+])
```
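Every tensor column in both schemas is stored as the same bytes/shape/dtype triple, so a reader can rebuild the array with no side information. A minimal round-trip sketch in numpy (the latent here is random stand-in data):

```python
import numpy as np

latent = np.random.rand(16, 4, 32, 32).astype(np.float32)  # e.g. [C, T, H, W]

# Writing: flatten to raw bytes plus shape and dtype metadata.
record = {
    "vae_latent_bytes": latent.tobytes(),
    "vae_latent_shape": list(latent.shape),
    "vae_latent_dtype": str(latent.dtype),
}

# Reading: rebuild the array from the triple.
restored = np.frombuffer(record["vae_latent_bytes"],
                         dtype=record["vae_latent_dtype"]).reshape(
                             record["vae_latent_shape"])
assert np.array_equal(latent, restored)
```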
Lines changed: 136 additions & 0 deletions (new file)

```python
import argparse
import json
import os
import time
from multiprocessing import Pool, cpu_count
from pathlib import Path

import torchvision
from tqdm import tqdm


def get_video_info(video_path):
    """Get video information using torchvision."""
    # Read video tensor (T, C, H, W)
    video_tensor, _, info = torchvision.io.read_video(str(video_path),
                                                      output_format="TCHW",
                                                      pts_unit="sec")

    num_frames = video_tensor.shape[0]
    height = video_tensor.shape[2]
    width = video_tensor.shape[3]
    fps = info.get("video_fps", 0)
    duration = num_frames / fps if fps > 0 else 0

    # Extract name (assumes a path of the form root/dataset/videos/name.mp4)
    _, _, videos_dir, video_name = str(video_path).split("/")

    return {
        "path": str(video_name),
        "resolution": {
            "width": width,
            "height": height
        },
        "size": os.path.getsize(video_path),
        "fps": fps,
        "duration": duration,
        "num_frames": num_frames
    }


def prepare_dataset_json(folder_path,
                         output_name="videos2caption.json",
                         num_workers=None) -> None:
    """Prepare dataset information from a folder containing videos and prompt.txt."""
    folder_path = Path(folder_path)

    # Read prompt file
    prompt_file = folder_path / "prompt.txt"
    if not prompt_file.exists():
        raise FileNotFoundError(f"prompt.txt not found in {folder_path}")

    with open(prompt_file) as f:
        prompts = [line.strip() for line in f.readlines() if line.strip()]

    # Read videos file
    videos_file = folder_path / "videos.txt"
    if not videos_file.exists():
        raise FileNotFoundError(f"videos.txt not found in {folder_path}")

    with open(videos_file) as f:
        video_paths = [line.strip() for line in f.readlines() if line.strip()]

    if len(prompts) != len(video_paths):
        raise ValueError(
            f"Number of prompts ({len(prompts)}) does not match number of videos ({len(video_paths)})"
        )

    # Prepare arguments for multiprocessing
    process_args = [folder_path / video_path for video_path in video_paths]

    # Determine number of workers
    if num_workers is None:
        num_workers = max(1, cpu_count() - 1)  # Leave one CPU free

    # Process videos in parallel
    start_time = time.time()
    with Pool(num_workers) as pool:
        results = list(
            tqdm(pool.imap(get_video_info, process_args),
                 total=len(process_args),
                 desc="Processing videos",
                 unit="video"))

    # Combine results with prompts
    dataset_info = []
    for result, prompt in zip(results, prompts):
        result["cap"] = [prompt]
        dataset_info.append(result)

    # Calculate total processing time
    total_time = time.time() - start_time
    total_videos = len(dataset_info)
    avg_time_per_video = total_time / total_videos if total_videos > 0 else 0

    print("\nProcessing completed:")
    print(f"Total videos processed: {total_videos}")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"Average time per video: {avg_time_per_video:.2f} seconds")

    # Save to JSON file
    output_file = folder_path / output_name
    with open(output_file, 'w') as f:
        json.dump(dataset_info, f, indent=2)

    # Create merge.txt
    merge_file = folder_path / "merge.txt"
    with open(merge_file, 'w') as f:
        f.write(f"{folder_path}/videos,{output_file}\n")

    print(f"Dataset information saved to {output_file}")
    print(f"Merge file created at {merge_file}")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description='Prepare video dataset information in JSON format')
    parser.add_argument(
        '--folder',
        type=str,
        required=True,
        help='Path to the folder containing videos and prompt.txt')
    parser.add_argument(
        '--output',
        type=str,
        default='videos2caption.json',
        help='Name of the output JSON file (default: videos2caption.json)')
    parser.add_argument('--workers',
                        type=int,
                        default=32,
                        help='Number of worker processes (default: 32)')
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    prepare_dataset_json(args.folder, args.output, args.workers)
```
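A minimal usage sketch. The module name `prepare_dataset_json.py` and the `data/my_videos` root are hypothetical; note that `get_video_info` splits paths on `/` into exactly four parts, so the clips should sit two levels below the working directory (e.g. `data/my_videos/videos/clip.mp4`):

```python
# Hypothetical module name; adjust the import to wherever the script lives.
from prepare_dataset_json import prepare_dataset_json

# Expected layout under the root:
#   data/my_videos/prompt.txt    one caption per line
#   data/my_videos/videos.txt    one relative clip path per line, same order
#   data/my_videos/videos/*.mp4  the clips referenced by videos.txt
prepare_dataset_json("data/my_videos",
                     output_name="videos2caption.json",
                     num_workers=8)
# Writes data/my_videos/videos2caption.json and a merge.txt pairing the videos
# folder with that JSON for the preprocessing entry point.
```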

fastvideo/v1/dataset/t2v_datasets.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -138,6 +138,7 @@ def get_video(self, idx) -> dict:
         video_path = dataset_prog.cap_list[idx]["path"]
         assert os.path.exists(video_path), f"file {video_path} do not exist!"
         frame_indices = dataset_prog.cap_list[idx]["sample_frame_index"]
+
         torchvision_video, _, metadata = torchvision.io.read_video(
             video_path, output_format="TCHW")
         video = torchvision_video[frame_indices]
@@ -270,7 +271,8 @@ def define_frame_index(self, cap_list) -> tuple[list[dict], list[int]]:
                 cnt_resolution_mismatch += 1
                 continue

-            # import ipdb;ipdb.set_trace()
+            # if path == 'finetrainers/3dgs-dissolve/videos/1.mp4':
+            #     from IPython import embed; embed()
             i["num_frames"] = math.ceil(fps * duration)
             # max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration.
             if i["num_frames"] / fps > self.video_length_tolerance_range * (
```

fastvideo/v1/fastvideo_args.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -92,7 +92,7 @@ class FastVideoArgs:
     lora_target_names: Optional[List[
         str]] = None  # can restrict list of layers to adapt, e.g. ["q_proj"]

-    # STA (Spatial-Temporal Attention) parameters
+    # STA parameters
     mask_strategy_file_path: Optional[str] = None
     enable_torch_compile: bool = False
```

fastvideo/v1/pipelines/preprocess/preprocess_pipeline_i2v.py

Lines changed: 92 additions & 0 deletions (new file)

```python
# SPDX-License-Identifier: Apache-2.0
"""
I2V Data Preprocessing pipeline implementation.

This module contains an implementation of the I2V Data Preprocessing pipeline
using the modular pipeline architecture.
"""
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from PIL import Image

from fastvideo.v1.dataset.dataloader.schema import pyarrow_schema_i2v
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.forward_context import set_forward_context
from fastvideo.v1.pipelines.preprocess_pipeline_base import (
    BasePreprocessPipeline)


class PreprocessPipeline_I2V(BasePreprocessPipeline):
    """I2V preprocessing pipeline implementation."""

    _required_config_modules = [
        "text_encoder", "tokenizer", "vae", "image_encoder", "image_processor"
    ]

    def get_schema_fields(self) -> List[str]:
        """Get the schema fields for I2V pipeline."""
        return [f.name for f in pyarrow_schema_i2v]

    def get_extra_features(self, valid_data: Dict[str, Any],
                           fastvideo_args: FastVideoArgs) -> Dict[str, Any]:
        """Get CLIP features from the first frame of each video."""
        first_frame = valid_data["pixel_values"][:, :, 0, :, :].permute(
            0, 2, 3, 1)  # (B, C, T, H, W) -> (B, H, W, C)

        processed_images = []
        for frame in first_frame:
            frame_pil = Image.fromarray(frame.cpu().numpy().astype(np.uint8))
            processed_img = self.get_module("image_processor")(
                images=frame_pil, return_tensors="pt")
            processed_images.append(processed_img)

        # Get CLIP features
        pixel_values = torch.cat(
            [img['pixel_values'] for img in processed_images],
            dim=0).to(fastvideo_args.device)
        with torch.no_grad():
            image_inputs = {'pixel_values': pixel_values}
            with set_forward_context(current_timestep=0, attn_metadata=None):
                clip_features = self.get_module("image_encoder")(
                    **image_inputs)
            clip_features = clip_features.last_hidden_state

        return {"clip_feature": clip_features}

    def create_record(
            self,
            video_name: str,
            vae_latent: np.ndarray,
            text_embedding: np.ndarray,
            text_attention_mask: np.ndarray,
            valid_data: Optional[Dict[str, Any]],
            idx: int,
            extra_features: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Create a record for the Parquet dataset with CLIP features."""
        record = super().create_record(video_name=video_name,
                                       vae_latent=vae_latent,
                                       text_embedding=text_embedding,
                                       text_attention_mask=text_attention_mask,
                                       valid_data=valid_data,
                                       idx=idx,
                                       extra_features=extra_features)

        if extra_features and "clip_feature" in extra_features:
            clip_feature = extra_features["clip_feature"]
            record.update({
                "clip_feature_bytes": clip_feature.tobytes(),
                "clip_feature_shape": list(clip_feature.shape),
                "clip_feature_dtype": str(clip_feature.dtype),
            })
        else:
            record.update({
                "clip_feature_bytes": b"",
                "clip_feature_shape": [],
                "clip_feature_dtype": "",
            })

        return record


EntryClass = PreprocessPipeline_I2V
```
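Downstream consumers can recover the stored CLIP feature from a Parquet row by reversing the bytes/shape/dtype encoding. A minimal sketch, assuming a hypothetical `output.parquet` written with `pyarrow_schema_i2v` and a numpy-compatible dtype such as float32:

```python
import numpy as np
import pyarrow.parquet as pq
import torch

table = pq.read_table("output.parquet")  # hypothetical output file
row = table.slice(0, 1).to_pylist()[0]

# Empty bytes mark records written without a CLIP feature (the fallback
# branch in create_record above).
if row["clip_feature_bytes"]:
    clip = np.frombuffer(row["clip_feature_bytes"],
                         dtype=row["clip_feature_dtype"]).reshape(
                             row["clip_feature_shape"])
    clip = torch.from_numpy(clip.copy())  # copy: frombuffer is read-only
```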
fastvideo/v1/pipelines/preprocess/preprocess_pipeline_t2v.py

Lines changed: 23 additions & 0 deletions (new file)

```python
# SPDX-License-Identifier: Apache-2.0
"""
T2V Data Preprocessing pipeline implementation.

This module contains an implementation of the T2V Data Preprocessing pipeline
using the modular pipeline architecture.
"""
from fastvideo.v1.dataset.dataloader.schema import pyarrow_schema_t2v
from fastvideo.v1.pipelines.preprocess_pipeline_base import (
    BasePreprocessPipeline)


class PreprocessPipeline_T2V(BasePreprocessPipeline):
    """T2V preprocessing pipeline implementation."""

    _required_config_modules = ["text_encoder", "tokenizer", "vae"]

    def get_schema_fields(self):
        """Get the schema fields for T2V pipeline."""
        return [f.name for f in pyarrow_schema_t2v]


EntryClass = PreprocessPipeline_T2V
```
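The two pipelines differ only in the modules they require and the schema they emit. A quick sanity check of the schema relationship, grounded in the schema diff above:

```python
from fastvideo.v1.dataset.dataloader.schema import (pyarrow_schema_i2v,
                                                    pyarrow_schema_t2v)

i2v_fields = {f.name for f in pyarrow_schema_i2v}
t2v_fields = {f.name for f in pyarrow_schema_t2v}

# The I2V schema is the T2V schema plus the three CLIP-feature columns.
assert i2v_fields - t2v_fields == {
    "clip_feature_bytes", "clip_feature_shape", "clip_feature_dtype"
}
assert t2v_fields <= i2v_fields
```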
