
Commit 8741d20

[Training] Refactor and improve validation datasets (#539)
1 parent cdc85f5 commit 8741d20

26 files changed: +993 −412 lines
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+{
+    "data": [
+        {
+            "caption": "A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.",
+            "image_path": null,
+            "video_path": "examples/training/finetune/wan_i2v_14b_480p/crush_smol/validation_dataset/yYcK4nANZz4-Scene-034.mp4",
+            "num_inference_steps": 50,
+            "height": 480,
+            "width": 832,
+            "num_frames": 77
+        },
+        {
+            "caption": "A large metal cylinder is seen compressing colorful clay into a compact shape, demonstrating the power of a hydraulic press.",
+            "image_path": null,
+            "video_path": "examples/training/finetune/wan_i2v_14b_480p/crush_smol/validation_dataset/yYcK4nANZz4-Scene-027.mp4",
+            "num_inference_steps": 50,
+            "height": 480,
+            "width": 832,
+            "num_frames": 77
+        },
+        {
+            "caption": "A large metal cylinder is seen pressing down on a pile of colorful candies, flattening them as if they were under a hydraulic press. The candies are crushed and broken into small pieces, creating a mess on the table.",
+            "image_path": null,
+            "video_path": "examples/training/finetune/wan_i2v_14b_480p/crush_smol/validation_dataset/yYcK4nANZz4-Scene-030.mp4",
+            "num_inference_steps": 50,
+            "height": 480,
+            "width": 832,
+            "num_frames": 77
+        }
+    ]
+}
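
This new JSON is the validation dataset spec for the Wan I2V 480p finetune example: each entry pairs a caption with a reference video and the sampling parameters (steps, resolution, frame count) to use when generating validation clips. A minimal sketch of reading such a file (the helper name is hypothetical, not part of this commit):

import json

def iter_validation_samples(path: str):
    """Yield one dict per validation sample from a spec like the file above."""
    with open(path) as f:
        spec = json.load(f)
    for sample in spec["data"]:
        # "image_path" is null in these I2V samples; the conditioning image is
        # presumably derived from the reference video instead.
        yield {
            "caption": sample["caption"],
            "video_path": sample["video_path"],
            "height": sample["height"],
            "width": sample["width"],
            "num_frames": sample["num_frames"],
            "num_inference_steps": sample["num_inference_steps"],
        }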

fastvideo/v1/dataset/__init__.py

Lines changed: 18 additions & 18 deletions
@@ -1,19 +1,17 @@
-import os
-
+# SPDX-License-Identifier: Apache-2.0
 from torchvision import transforms
 from torchvision.transforms import Lambda
-from transformers import AutoTokenizer
 
-from fastvideo.v1.dataset.t2v_datasets import T2V_dataset
+from fastvideo.v1.dataset.parquet_dataset_map_style import (
+    build_parquet_map_style_dataloader)
+from fastvideo.v1.dataset.preprocessing_datasets import (
+    VideoCaptionMergedDataset)
 from fastvideo.v1.dataset.transform import (CenterCropResizeVideo, Normalize255,
                                             TemporalRandomCrop)
-
-from .parquet_dataset_map_style import build_parquet_map_style_dataloader
-
-__all__ = ["build_parquet_map_style_dataloader"]
+from fastvideo.v1.dataset.validation_dataset import ValidationDataset
 
 
-def getdataset(args, start_idx=0) -> T2V_dataset:
+def getdataset(args) -> VideoCaptionMergedDataset:
     temporal_sample = TemporalRandomCrop(args.num_frames)  # 16 x
     norm_fun = Lambda(lambda x: 2.0 * x - 1.0)
     resize_topcrop = [
@@ -31,15 +29,17 @@ def getdataset(args, start_idx=0) -> T2V_dataset:
         *resize_topcrop,
         norm_fun,
     ])
-    tokenizer_path = os.path.join(args.model_path, "tokenizer")
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
-                                              cache_dir=args.cache_dir)
     if args.dataset == "t2v":
-        return T2V_dataset(args,
-                           transform=transform,
-                           temporal_sample=temporal_sample,
-                           tokenizer=tokenizer,
-                           transform_topcrop=transform_topcrop,
-                           start_idx=start_idx)
+        return VideoCaptionMergedDataset(data_merge_path=args.data_merge_path,
+                                         args=args,
+                                         transform=transform,
+                                         temporal_sample=temporal_sample,
+                                         transform_topcrop=transform_topcrop)
 
     raise NotImplementedError(args.dataset)
+
+
+__all__ = [
+    "build_parquet_map_style_dataloader", "ValidationDataset",
+    "VideoCaptionMergedDataset"
+]
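
A usage sketch for the refactored entry point; `Namespace` stands in for the trainer's parsed CLI arguments, the values are placeholders, and only fields visible in this diff are shown (the resize transforms read further size-related fields not reproduced here):

from argparse import Namespace

from fastvideo.v1.dataset import getdataset

args = Namespace(
    dataset="t2v",                     # any other value raises NotImplementedError
    data_merge_path="data/merge.txt",  # hypothetical path to the merged data list
    num_frames=77,                     # window length for TemporalRandomCrop
)
dataset = getdataset(args)  # returns a VideoCaptionMergedDataset

Note the signature change: the `AutoTokenizer` construction and the `start_idx` plumbing are gone from this path, so `getdataset` now needs only `args`.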

fastvideo/v1/dataset/dataloader/schema.py

Lines changed: 77 additions & 1 deletion
@@ -48,6 +48,48 @@
     pa.field("fps", pa.float64()),
 ])
 
+pyarrow_schema_i2v_validation = pa.schema([
+    pa.field("id", pa.string()),
+    # --- Image/Video VAE latents ---
+    # Tensors are stored as raw bytes with shape and dtype info for loading
+    pa.field("vae_latent_bytes", pa.binary()),
+    # e.g., [C, T, H, W] or [C, H, W]
+    pa.field("vae_latent_shape", pa.list_(pa.int64())),
+    # e.g., 'float32'
+    pa.field("vae_latent_dtype", pa.string()),
+    # --- Text encoder output tensor ---
+    # Tensors are stored as raw bytes with shape and dtype info for loading
+    pa.field("text_embedding_bytes", pa.binary()),
+    # e.g., [SeqLen, Dim]
+    pa.field("text_embedding_shape", pa.list_(pa.int64())),
+    # e.g., 'bfloat16' or 'float32'
+    pa.field("text_embedding_dtype", pa.string()),
+    pa.field("text_attention_mask_bytes", pa.binary()),
+    # e.g., [SeqLen]
+    pa.field("text_attention_mask_shape", pa.list_(pa.int64())),
+    # e.g., 'bool' or 'int8'
+    pa.field("text_attention_mask_dtype", pa.string()),
+    # --- I2V ---
+    pa.field("clip_feature_bytes", pa.binary()),
+    pa.field("clip_feature_shape", pa.list_(pa.int64())),
+    pa.field("clip_feature_dtype", pa.string()),
+    # --- I2V validation ---
+    pa.field("pil_image_bytes", pa.binary()),
+    pa.field("pil_image_shape", pa.list_(pa.int64())),
+    pa.field("pil_image_dtype", pa.string()),
+    # --- Metadata ---
+    pa.field("file_name", pa.string()),
+    pa.field("caption", pa.string()),
+    pa.field("media_type", pa.string()),  # 'image' or 'video'
+    pa.field("width", pa.int64()),
+    pa.field("height", pa.int64()),
+    # --- Video-specific (can be null/default for images) ---
+    # Number of frames processed (e.g., 1 for image, N for video)
+    pa.field("num_frames", pa.int64()),
+    pa.field("duration_sec", pa.float64()),
+    pa.field("fps", pa.float64()),
+])
+
 pyarrow_schema_t2v = pa.schema([
     pa.field("id", pa.string()),
     # --- Image/Video VAE latents ---
@@ -80,4 +122,38 @@
     pa.field("num_frames", pa.int64()),
     pa.field("duration_sec", pa.float64()),
     pa.field("fps", pa.float64()),
-])
+])
+
+pyarrow_schema_t2v_validation = pa.schema([
+    pa.field("id", pa.string()),
+    # --- Image/Video VAE latents ---
+    # Tensors are stored as raw bytes with shape and dtype info for loading
+    pa.field("vae_latent_bytes", pa.binary()),
+    # e.g., [C, T, H, W] or [C, H, W]
+    pa.field("vae_latent_shape", pa.list_(pa.int64())),
+    # e.g., 'float32'
+    pa.field("vae_latent_dtype", pa.string()),
+    # --- Text encoder output tensor ---
+    # Tensors are stored as raw bytes with shape and dtype info for loading
+    pa.field("text_embedding_bytes", pa.binary()),
+    # e.g., [SeqLen, Dim]
+    pa.field("text_embedding_shape", pa.list_(pa.int64())),
+    # e.g., 'bfloat16' or 'float32'
+    pa.field("text_embedding_dtype", pa.string()),
+    pa.field("text_attention_mask_bytes", pa.binary()),
+    # e.g., [SeqLen]
+    pa.field("text_attention_mask_shape", pa.list_(pa.int64())),
+    # e.g., 'bool' or 'int8'
+    pa.field("text_attention_mask_dtype", pa.string()),
+    # --- Metadata ---
+    pa.field("file_name", pa.string()),
+    pa.field("caption", pa.string()),
+    pa.field("media_type", pa.string()),  # 'image' or 'video'
+    pa.field("width", pa.int64()),
+    pa.field("height", pa.int64()),
+    # --- Video-specific (can be null/default for images) ---
+    # Number of frames processed (e.g., 1 for image, N for video)
+    pa.field("num_frames", pa.int64()),
+    pa.field("duration_sec", pa.float64()),
+    pa.field("fps", pa.float64()),
+])
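
Both new validation schemas keep the bytes/shape/dtype triplet convention, which makes serialization a one-liner per tensor. A minimal NumPy round-trip sketch under that convention (helper names are illustrative, not from this commit; 'bfloat16' embeddings would need torch, since NumPy has no such dtype):

import numpy as np

def encode_tensor(arr: np.ndarray) -> dict:
    # Pack an array into the bytes/shape/dtype columns used by the schemas.
    return {"bytes": arr.tobytes(), "shape": list(arr.shape), "dtype": str(arr.dtype)}

def decode_tensor(data: bytes, shape, dtype) -> np.ndarray:
    # Rebuild the array from one row's triplet of columns.
    return np.frombuffer(data, dtype=np.dtype(dtype)).reshape(shape)

latent = np.random.randn(16, 21, 60, 104).astype(np.float32)  # e.g. [C, T, H, W]
row = encode_tensor(latent)
restored = decode_tensor(row["bytes"], row["shape"], row["dtype"])
assert np.array_equal(restored, latent)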
