
Commit 7ab3253

SolitaryThinker, JerryZhou54, and BrianChen1129 authored

[Training] [1/n] Add latent datasets (#438)

Co-authored-by: Wei Zhou <[email protected]>
Co-authored-by: JerryZhou54 <[email protected]>
Co-authored-by: BrianChen1129 <[email protected]>
1 parent 6ef8fcb commit 7ab3253

File tree

7 files changed: +1093 -0 lines changed


fastvideo/v1/dataset/__init__.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
from torchvision import transforms
from torchvision.transforms import Lambda
from transformers import AutoTokenizer

from fastvideo.v1.dataset.t2v_datasets import T2V_dataset
from fastvideo.v1.dataset.transform import (CenterCropResizeVideo, Normalize255,
                                            TemporalRandomCrop)


def getdataset(args, start_idx=0) -> T2V_dataset:
    temporal_sample = TemporalRandomCrop(args.num_frames)  # 16 x
    norm_fun = Lambda(lambda x: 2.0 * x - 1.0)
    resize_topcrop = [
        CenterCropResizeVideo((args.max_height, args.max_width), top_crop=True),
    ]
    resize = [
        CenterCropResizeVideo((args.max_height, args.max_width)),
    ]
    transform = transforms.Compose([
        # Normalize255(),
        *resize,
    ])
    transform_topcrop = transforms.Compose([
        Normalize255(),
        *resize_topcrop,
        norm_fun,
    ])
    # tokenizer = AutoTokenizer.from_pretrained("/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl", cache_dir=args.cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name,
                                              cache_dir=args.cache_dir)
    if args.dataset == "t2v":
        return T2V_dataset(args,
                           transform=transform,
                           temporal_sample=temporal_sample,
                           tokenizer=tokenizer,
                           transform_topcrop=transform_topcrop,
                           start_idx=start_idx)

    raise NotImplementedError(args.dataset)
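
For orientation, a minimal sketch (not part of the commit) of how getdataset might be driven. The argument object and its values are hypothetical; the attribute names simply match what the function reads above, and T2V_dataset will additionally expect its own arguments (data paths, etc.).

from types import SimpleNamespace

from fastvideo.v1.dataset import getdataset

# Hypothetical args; attribute names follow what getdataset reads above.
args = SimpleNamespace(
    num_frames=16,
    max_height=480,
    max_width=848,
    text_encoder_name="google/t5-v1_1-xxl",  # assumed; any HF tokenizer path works
    cache_dir=None,
    dataset="t2v",
    # ...plus whatever T2V_dataset itself requires (data paths, etc.)
)
dataset = getdataset(args, start_idx=0)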
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
# schema.py
"""
Unified data schema and format for saving and loading image/video data after
preprocessing.

It uses the Apache Arrow in-memory format, which can be consumed by modern
data frameworks that handle Parquet or Lance files.
"""

import pyarrow as pa

pyarrow_schema = pa.schema([
    pa.field("id", pa.string()),
    # --- Image/Video VAE latents ---
    # Tensors are stored as raw bytes with shape and dtype info for loading
    pa.field("vae_latent_bytes", pa.binary()),
    # e.g., [C, T, H, W] or [C, H, W]
    pa.field("vae_latent_shape", pa.list_(pa.int64())),
    # e.g., 'float32'
    pa.field("vae_latent_dtype", pa.string()),
    # --- Text encoder output tensor ---
    # Tensors are stored as raw bytes with shape and dtype info for loading
    pa.field("text_embedding_bytes", pa.binary()),
    # e.g., [SeqLen, Dim]
    pa.field("text_embedding_shape", pa.list_(pa.int64())),
    # e.g., 'bfloat16' or 'float32'
    pa.field("text_embedding_dtype", pa.string()),
    pa.field("text_attention_mask_bytes", pa.binary()),
    # e.g., [SeqLen]
    pa.field("text_attention_mask_shape", pa.list_(pa.int64())),
    # e.g., 'bool' or 'int8'
    pa.field("text_attention_mask_dtype", pa.string()),
    # --- Metadata ---
    pa.field("file_name", pa.string()),
    pa.field("caption", pa.string()),
    pa.field("media_type", pa.string()),  # 'image' or 'video'
    pa.field("width", pa.int64()),
    pa.field("height", pa.int64()),
    # --- Video-specific (can be null/default for images) ---
    # Number of frames processed (e.g., 1 for an image, N for a video)
    pa.field("num_frames", pa.int64()),
    pa.field("duration_sec", pa.float64()),
    pa.field("fps", pa.float64()),
])
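
To make the bytes/shape/dtype convention concrete, a hedged round-trip sketch (not part of the commit): it packs one tensor into a record matching pyarrow_schema, writes it to Parquet, and reads it back. File names and values are illustrative; note that 'bfloat16' has no NumPy dtype, so bfloat16 embeddings would need torch-level decoding.

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch


def tensor_to_fields(t: torch.Tensor):
    # Serialize a tensor as raw bytes plus the shape/dtype needed to restore it.
    arr = t.cpu().numpy()
    return arr.tobytes(), list(arr.shape), str(arr.dtype)


latent = torch.randn(4, 8, 30, 53)  # e.g., [C, T, H, W]
lat_bytes, lat_shape, lat_dtype = tensor_to_fields(latent)

record = {
    "id": "sample-0000",
    "vae_latent_bytes": lat_bytes,
    "vae_latent_shape": lat_shape,
    "vae_latent_dtype": lat_dtype,
    # Text fields left empty in this sketch.
    "text_embedding_bytes": b"",
    "text_embedding_shape": [],
    "text_embedding_dtype": "",
    "text_attention_mask_bytes": b"",
    "text_attention_mask_shape": [],
    "text_attention_mask_dtype": "",
    "file_name": "clip0.mp4",
    "caption": "a cat",
    "media_type": "video",
    "width": 424,
    "height": 240,
    "num_frames": 29,
    "duration_sec": 1.2,
    "fps": 24.0,
}
table = pa.Table.from_pylist([record], schema=pyarrow_schema)
pq.write_table(table, "sample.parquet")

# Reading back: rebuild the tensor from bytes + shape + dtype.
row = pq.read_table("sample.parquet").to_pylist()[0]
arr = np.frombuffer(row["vae_latent_bytes"],
                    dtype=row["vae_latent_dtype"]).reshape(row["vae_latent_shape"])
restored = torch.from_numpy(arr.copy())  # copy: frombuffer returns a read-only view
assert restored.shape == latent.shape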
Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
import json
import os
import random

import torch
from torch.utils.data import Dataset


class LatentDataset(Dataset):

    def __init__(
        self,
        json_path,
        num_latent_t,
        cfg_rate,
    ) -> None:
        # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path
        self.json_path = json_path
        self.cfg_rate = cfg_rate
        self.dataset_dir_path = os.path.dirname(json_path)
        self.video_dir = os.path.join(self.dataset_dir_path, "video")
        self.latent_dir = os.path.join(self.dataset_dir_path, "latent")
        self.prompt_embed_dir = os.path.join(self.dataset_dir_path,
                                             "prompt_embed")
        self.prompt_attention_mask_dir = os.path.join(self.dataset_dir_path,
                                                      "prompt_attention_mask")
        with open(self.json_path) as f:
            self.data_anno = json.load(f)
        # json.load(f) already keeps the order
        # self.data_anno = sorted(self.data_anno, key=lambda x: x['latent_path'])
        self.num_latent_t = num_latent_t

        self.uncond_prompt_embed = torch.zeros(256, 4096).to(torch.float32)

        self.uncond_prompt_mask = torch.zeros(256).bool()
        self.lengths = [
            data_item.get("length", 1) for data_item in self.data_anno
        ]

    def __getitem__(self, idx):
        latent_file = self.data_anno[idx]["latent_path"]
        prompt_embed_file = self.data_anno[idx]["prompt_embed_path"]
        prompt_attention_mask_file = self.data_anno[idx][
            "prompt_attention_mask"]
        # load
        latent = torch.load(
            os.path.join(self.latent_dir, latent_file),
            map_location="cpu",
            weights_only=True,
        )
        latent = latent.squeeze(0)[:, -self.num_latent_t:]
        if random.random() < self.cfg_rate:
            prompt_embed = self.uncond_prompt_embed
            prompt_attention_mask = self.uncond_prompt_mask
        else:
            prompt_embed = torch.load(
                os.path.join(self.prompt_embed_dir, prompt_embed_file),
                map_location="cpu",
                weights_only=True,
            )
            prompt_attention_mask = torch.load(
                os.path.join(self.prompt_attention_mask_dir,
                             prompt_attention_mask_file),
                map_location="cpu",
                weights_only=True,
            )
        return latent, prompt_embed, prompt_attention_mask

    def __len__(self):
        return len(self.data_anno)


def latent_collate_function(batch):
    # return latent, prompt, latent_attn_mask, text_attn_mask
    # latent_attn_mask: b t h w
    # text_attn_mask: b 1 l
    # check the latent/prompt sizes and apply padding & attention masks
    latents, prompt_embeds, prompt_attention_masks = zip(*batch)
    # calculate max shape
    max_t = max([latent.shape[1] for latent in latents])
    max_h = max([latent.shape[2] for latent in latents])
    max_w = max([latent.shape[3] for latent in latents])

    # padding: F.pad consumes the pad tuple from the last dim backward,
    # so for a [C, T, H, W] latent the pairs apply to W, then H, then T
    latent_list: list[torch.Tensor] = [
        torch.nn.functional.pad(
            latent,
            (
                0,
                max_w - latent.shape[3],
                0,
                max_h - latent.shape[2],
                0,
                max_t - latent.shape[1],
            ),
        ) for latent in latents
    ]
    # attn mask
    latent_attn_mask = torch.ones(len(latent_list), max_t, max_h, max_w)
    # set to 0 if padding; use the unpadded latents for the true extents
    for i, latent in enumerate(latents):
        latent_attn_mask[i, latent.shape[1]:, :, :] = 0
        latent_attn_mask[i, :, latent.shape[2]:, :] = 0
        latent_attn_mask[i, :, :, latent.shape[3]:] = 0

    prompt_embeds = torch.stack(prompt_embeds, dim=0)
    prompt_attention_masks = torch.stack(prompt_attention_masks, dim=0)
    latents = torch.stack(latent_list, dim=0)
    return latents, prompt_embeds, latent_attn_mask, prompt_attention_masks


if __name__ == "__main__":
    dataset = LatentDataset("data/Mochi-Synthetic-Data/merge.txt",
                            num_latent_t=28,
                            cfg_rate=0.0)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=2,
                                             shuffle=False,
                                             collate_fn=latent_collate_function)
    for latent, prompt_embed, latent_attn_mask, prompt_attention_mask in dataloader:
        print(
            latent.shape,
            prompt_embed.shape,
            latent_attn_mask.shape,
            prompt_attention_mask.shape,
        )
        import pdb
        pdb.set_trace()
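
A quick sanity check for the collate function (again a sketch, not in the commit): two dummy [C, T, H, W] latents with unequal temporal length come back padded to a common shape, with the mask zeroed over the padded frames.

import torch

a = torch.randn(4, 6, 8, 8)  # C=4, T=6
b = torch.randn(4, 4, 8, 8)  # shorter clip: T=4
emb = torch.zeros(256, 4096)
mask = torch.zeros(256).bool()

latents, embeds, latent_mask, text_masks = latent_collate_function(
    [(a, emb, mask), (b, emb, mask)])
print(latents.shape)             # torch.Size([2, 4, 6, 8, 8])
print(latent_mask.shape)         # torch.Size([2, 6, 8, 8])
print(latent_mask[1, 4:].sum())  # tensor(0.) -- padded frames are masked out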
