
Commit f877e53

Update hunyuan video t2v preprocess
1 parent 3b39366 commit f877e53

11 files changed (+61, -76 lines)


HunyuanVideo

Lines changed: 0 additions & 1 deletion
This file was deleted.
Lines changed: 4 additions & 2 deletions
```diff
@@ -2,14 +2,15 @@
 
 GPU_NUM=1 # 2,4,8
 MODEL_PATH="hunyuanvideo-community/HunyuanVideo"
-DATASET_PATH="/FastVideo/data/mini_i2v_dataset/crush-smol_raw"
-OUTPUT_DIR="/FastVideo/data/mini_i2v_dataset/crush-smol_processed_t2v_hunyuan/"
+DATASET_PATH="data/crush-smol"
+OUTPUT_DIR="data/crush-smol_processed_t2v_hunyuan/"
 
 torchrun --nproc_per_node=$GPU_NUM \
     -m fastvideo.pipelines.preprocess.v1_preprocessing_new \
     --model_path $MODEL_PATH \
     --mode preprocess \
     --workload_type t2v \
+    --preprocess.dataset_type merged \
     --preprocess.dataset_path $DATASET_PATH \
     --preprocess.dataset_output_dir $OUTPUT_DIR \
     --preprocess.preprocess_video_batch_size 2 \
@@ -21,3 +22,4 @@ torchrun --nproc_per_node=$GPU_NUM \
     --preprocess.samples_per_file 8 \
     --preprocess.flush_frequency 8 \
     --preprocess.video_length_tolerance_range 5
+
```
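Note the two changes: the dataset now lives under a local `data/crush-smol` folder, and the new `--preprocess.dataset_type merged` flag selects the merged-dataset loader. A hypothetical miniature of the on-disk layout that loader expects (the `videos2caption.json` name and the `cap`/`path`/`fps`/`num_frames` fields come from `fastvideo/workflow/preprocess/components.py` below; the example caption and file name are invented):

```python
import json
import os

# Sketch only: build a minimal "merged" dataset index next to the clips.
os.makedirs("data/crush-smol", exist_ok=True)
index = [
    {"cap": "a toy car crushed by a hydraulic press",  # becomes "caption"
     "path": "0001.mp4",                               # becomes "name"
     "fps": 24, "num_frames": 120},                    # required by the validator
]
with open("data/crush-smol/videos2caption.json", "w") as f:
    json.dump(index, f, indent=2)
```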

fastvideo/configs/configs.py

Lines changed: 0 additions & 23 deletions
```diff
@@ -9,29 +9,6 @@
 logger = init_logger(__name__)
 
 
-class DatasetType(str, Enum):
-    """
-    Enumeration for different dataset types.
-    """
-    HF = "hf"
-    MERGED = "merged"
-
-    @classmethod
-    def from_string(cls, value: str) -> "DatasetType":
-        """Convert string to DatasetType enum."""
-        try:
-            return cls(value.lower())
-        except ValueError:
-            raise ValueError(
-                f"Invalid dataset type: {value}. Must be one of: {', '.join([m.value for m in cls])}"
-            ) from None
-
-    @classmethod
-    def choices(cls) -> list[str]:
-        """Get all available choices as strings for argparse."""
-        return [dataset_type.value for dataset_type in cls]
-
-
 class DatasetType(str, Enum):
     """
     Enumeration for different dataset types.
```
fastvideo/configs/models/encoders/clip.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -74,7 +74,13 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig):
 class CLIPTextConfig(TextEncoderConfig):
     arch_config: TextEncoderArchConfig = field(
         default_factory=CLIPTextArchConfig)
-
+    tokenizer_kwargs: dict = field(
+        default_factory=lambda: {
+            "padding": "max_length",
+            "truncation": True,
+            "max_length": 77,
+            "return_tensors": "pt"
+        })
     num_hidden_layers_override: int | None = None
     require_post_norm: bool | None = None
     prefix: str = "clip"
```
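The 77-token `max_length` matches CLIP's fixed text context window; the Llama config below pins 256 tokens instead. The diff does not show where FastVideo consumes `tokenizer_kwargs`, but the usual Hugging Face pattern is to splat them into the tokenizer call, roughly:

```python
from transformers import CLIPTokenizer

# Sketch under the assumption that tokenizer_kwargs are passed through
# unchanged to a Hugging Face tokenizer; the model id is illustrative.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
kwargs = {"padding": "max_length", "truncation": True,
          "max_length": 77, "return_tensors": "pt"}
batch = tokenizer("a toy car crushed by a hydraulic press", **kwargs)
print(batch["input_ids"].shape)  # torch.Size([1, 77])
```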

fastvideo/configs/models/encoders/llama.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -60,5 +60,11 @@ class LlamaArchConfig(TextEncoderArchConfig):
 @dataclass
 class LlamaConfig(TextEncoderConfig):
     arch_config: TextEncoderArchConfig = field(default_factory=LlamaArchConfig)
-
+    tokenizer_kwargs: dict = field(
+        default_factory=lambda: {
+            "padding": "max_length",
+            "truncation": True,
+            "max_length": 256,
+            "return_tensors": "pt"
+        })
     prefix: str = "llama"
```

fastvideo/layers/rotary_embedding.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -138,14 +138,14 @@ def forward_native(
         cos, sin = cos_sin.chunk(2, dim=-1)
 
         query_shape = query.shape
-        query = query.view(num_tokens, -1, self.head_size)
+        query = query.reshape(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
         query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
-        key = key.view(num_tokens, -1, self.head_size)
+        key = key.reshape(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
         key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
```
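The swap from `.view` to `.reshape` is the substantive fix here: `.view` never copies, so it raises a `RuntimeError` when the tensor's strides are incompatible with the requested shape (typical after a transpose), whereas `.reshape` returns a view when it can and silently copies when it must. A standalone illustration:

```python
import torch

x = torch.randn(4, 8).t()  # the transpose leaves x non-contiguous

try:
    x.view(32)             # cannot reinterpret these strides without a copy
except RuntimeError as err:
    print("view failed:", err)

flat = x.reshape(32)       # falls back to a copy when a view is impossible
print(flat.shape)          # torch.Size([32])
```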

fastvideo/models/vaes/hunyuanvae.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -91,7 +91,7 @@ def forward(self,
         key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
 
-        # Perform scaled dot-product attention
+        # Perform scaled dot-product attentionz
         hidden_states = F.scaled_dot_product_attention(query,
                                                        key,
                                                        value,
@@ -361,7 +361,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                                hidden_states.device,
                                                batch_size=batch_size)
             hidden_states = attn(hidden_states,
-                                 attention_mask=attention_mask)
+                                 attention_mask=attention_mask.unsqueeze(1))
             hidden_states = hidden_states.unflatten(
                 1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
 
@@ -385,7 +385,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                                hidden_states.device,
                                                batch_size=batch_size)
             hidden_states = attn(hidden_states,
-                                 attention_mask=attention_mask)
+                                 attention_mask=attention_mask.unsqueeze(1))
             hidden_states = hidden_states.unflatten(
                 1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
```
fastvideo/pipelines/preprocess/hunyuan/hunyuan_preprocess_pipelines.py

Lines changed: 31 additions & 4 deletions
```diff
@@ -9,7 +9,8 @@
 
 class PreprocessPipelineI2V(ComposedPipelineBase):
     _required_config_modules = [
-        "image_encoder", "image_processor", "text_encoder", "tokenizer", "vae"
+        "image_encoder", "image_processor", "text_encoder", "tokenizer",
+        "text_encoder_2", "tokenizer_2", "vae"
     ]
 
     def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
@@ -51,7 +52,9 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
 
 
 class PreprocessPipelineT2V(ComposedPipelineBase):
-    _required_config_modules = ["text_encoder", "tokenizer", "vae"]
+    _required_config_modules = [
+        "text_encoder", "tokenizer", "text_encoder_2", "tokenizer_2", "vae"
+    ]
 
     def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
         assert fastvideo_args.preprocess_config is not None
@@ -61,10 +64,34 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
                            preprocess_config.training_cfg_rate,
                            seed=fastvideo_args.preprocess_config.seed,
                        ))
+        # llama_tokenizer_kwargs = {
+        #     "padding": "max_length",
+        #     "truncation": True,
+        #     "max_length": 256,
+        #     "return_tensors": "pt"
+        # }
+        # clip_tokenizer_kwargs = {
+        #     "padding": "max_length",
+        #     "truncation": True,
+        #     "max_length": 77,
+        #     "return_tensors": "pt"
+        # }
+        # if len(fastvideo_args.pipeline_config.text_encoder_configs) >= 2:
+        #     fastvideo_args.pipeline_config.text_encoder_configs[0].tokenizer_kwargs = llama_tokenizer_kwargs
+        #     fastvideo_args.pipeline_config.text_encoder_configs[1].tokenizer_kwargs = clip_tokenizer_kwargs
+        text_encoders = [
+            self.get_module("text_encoder"),
+            self.get_module("text_encoder_2")
+        ]
+        tokenizers = [
+            self.get_module("tokenizer"),
+            self.get_module("tokenizer_2")
+        ]
+
         self.add_stage(stage_name="prompt_encoding_stage",
                        stage=TextEncodingStage(
-                           text_encoders=[self.get_module("text_encoder")],
-                           tokenizers=[self.get_module("tokenizer")],
+                           text_encoders=text_encoders,
+                           tokenizers=tokenizers,
                        ))
         self.add_stage(
             stage_name="video_transform_stage",
```
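Both pipelines now require HunyuanVideo's second text encoder (CLIP alongside the Llama-based encoder), and the commented-out kwargs block is dead code: the same defaults were baked into `LlamaConfig` and `CLIPTextConfig` earlier in this commit. For intuition, a hypothetical sketch of the paired-encoder pattern handed to `TextEncodingStage` (this is not the stage's actual implementation):

```python
import torch

def encode_prompt(prompt: str, text_encoders: list, tokenizers: list) -> list:
    """Illustrative only: run one prompt through each tokenizer/encoder pair."""
    embeddings = []
    for tokenizer, encoder in zip(tokenizers, text_encoders):
        tokens = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            embeddings.append(encoder(**tokens).last_hidden_state)
    return embeddings  # e.g. [llama_hidden_states, clip_hidden_states]
```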

fastvideo/workflow/preprocess/components.py

Lines changed: 2 additions & 26 deletions
```diff
@@ -14,10 +14,6 @@
 from datasets import Dataset, Video, load_dataset
 
 from fastvideo.configs.configs import DatasetType, PreprocessConfig
-<<<<<<< HEAD
-from fastvideo.distributed.parallel_state import get_world_rank, get_world_size
-=======
->>>>>>> 15df36ab ([Feat][Preprocess] support merged dataset (#752))
 from fastvideo.logger import init_logger
 from fastvideo.pipelines.pipeline_batch_info import PreprocessBatch
 
@@ -82,10 +78,8 @@ def __call__(self, batch: dict[str, Any]) -> bool:
 
     def _validate_data_type(self, batch: dict[str, Any]) -> bool:
         """Validate basic validity of data items"""
-        print("-------------------------------")
-        print(batch)
-        return not (batch["caption"] is None or batch["caption"] == ""
-                    or "fps" not in batch or batch["fps"] is None or batch["fps"] <= 0
+        return not (batch["caption"] is None or batch["caption"] == "" or "fps"
+                    not in batch or batch["fps"] is None or batch["fps"] <= 0
                     or batch["num_frames"] is None or batch["num_frames"] <= 0)
 
     def _validate_resolution(self, batch: dict[str, Any]) -> bool:
@@ -405,19 +399,9 @@ def _default_file_writer_fn(self, args_tuple: tuple) -> int:
     return written_count
 
 
-<<<<<<< HEAD
-def build_dataset(preprocess_config: PreprocessConfig, split: str,
-                  validator: Callable[[dict[str, Any]], bool]) -> Dataset:
-    if preprocess_config.dataset_type == DatasetType.HF:
-        dataset = load_dataset(preprocess_config.dataset_path, split=split)
-        dataset = dataset.filter(validator)
-        dataset = dataset.shard(num_shards=get_world_size(),
-                                index=get_world_rank())
-=======
 def build_dataset(preprocess_config: PreprocessConfig, split: str) -> Dataset:
     if preprocess_config.dataset_type == DatasetType.HF:
         dataset = load_dataset(preprocess_config.dataset_path, split=split)
->>>>>>> 15df36ab ([Feat][Preprocess] support merged dataset (#752))
     elif preprocess_config.dataset_type == DatasetType.MERGED:
         metadata_json_path = os.path.join(preprocess_config.dataset_path,
                                           "videos2caption.json")
@@ -431,14 +415,6 @@ def build_dataset(preprocess_config: PreprocessConfig, split: str) -> Dataset:
         dataset = dataset.rename_column("cap", "caption")
     if "path" in column_names:
         dataset = dataset.rename_column("path", "name")
-<<<<<<< HEAD
-
-    dataset = dataset.filter(validator)
-    dataset = dataset.shard(num_shards=get_world_size(),
-                            index=get_world_rank())
-
-=======
->>>>>>> 15df36ab ([Feat][Preprocess] support merged dataset (#752))
     # add video column
     def add_video_column(item: dict[str, Any]) -> dict[str, Any]:
         item["video"] = os.path.join(video_folder, item["name"])
```

fastvideo/workflow/preprocess/preprocess_workflow.py

Lines changed: 0 additions & 11 deletions
```diff
@@ -44,13 +44,7 @@ def register_components(self) -> None:
         self.add_component("raw_data_validator", raw_data_validator)
 
         # training dataset
-<<<<<<< HEAD
-        training_dataset = build_dataset(preprocess_config,
-                                         split="train",
-                                         validator=raw_data_validator)
-=======
         training_dataset = build_dataset(preprocess_config, split="train")
->>>>>>> 15df36ab ([Feat][Preprocess] support merged dataset (#752))
         # set load_from_cache_file to False to check filter stats
         training_dataset = training_dataset.filter(raw_data_validator)
         # we do not use collate_fn here because we use iterable-style Dataset
@@ -66,13 +60,8 @@ def register_components(self) -> None:
         # try to load validation dataset if it exists
         try:
             validation_dataset = build_dataset(preprocess_config,
-<<<<<<< HEAD
-                                               split="validation",
-                                               validator=raw_data_validator)
-=======
                                                split="validation")
             validation_dataset = validation_dataset.filter(raw_data_validator)
->>>>>>> 15df36ab ([Feat][Preprocess] support merged dataset (#752))
             validation_dataloader = DataLoader(
                 validation_dataset,
                 batch_size=preprocess_config.preprocess_video_batch_size,
```
