Skip to content

Commit ab01dc4

Browse files
[Feature] [Training] Add i2v training (#559)
1 parent 285a950 commit ab01dc4

18 files changed

+859
-69
lines changed

examples/training/finetune/wan_i2v_14b_480p/crush_smol/finetune_i2v.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ miscellaneous_args=(
6969
--inference_mode False
7070
--allow_tf32
7171
--checkpoints_total_limit 3
72-
--cfg 0.0
72+
--training_cfg_rate 0.1
7373
--multi_phased_distill_schedule "4000-1"
7474
--not_apply_cfg_solver
7575
--dit_precision "fp32"

examples/training/finetune/wan_i2v_14b_480p/crush_smol/finetune_i2v.slurm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ miscellaneous_args=(
106106
--inference_mode False
107107
--allow_tf32
108108
--checkpoints_total_limit 3
109-
--cfg 0.0
109+
--training_cfg_rate 0.1
110110
--multi_phased_distill_schedule "4000-1"
111111
--not_apply_cfg_solver
112112
--dit_precision "fp32"

examples/training/finetune/wan_t2v_1_3b/crush_smol/finetune_t2v.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ miscellaneous_args=(
6969
--inference_mode False
7070
--allow_tf32
7171
--checkpoints_total_limit 3
72-
--cfg 0.0
72+
--training_cfg_rate 0.1
7373
--multi_phased_distill_schedule "4000-1"
7474
--not_apply_cfg_solver
7575
--dit_precision "fp32"

examples/training/finetune/wan_t2v_1_3b/crush_smol/finetune_t2v.slurm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ miscellaneous_args=(
103103
--inference_mode False
104104
--allow_tf32
105105
--checkpoints_total_limit 3
106-
--cfg 0.0
106+
--training_cfg_rate 0.1
107107
--multi_phased_distill_schedule "4000-1"
108108
--not_apply_cfg_solver
109109
--dit_precision "fp32"

fastvideo/v1/dataset/parquet_dataset_iterable_style.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Dict, List, Tuple
55

66
import numpy as np
7+
import pyarrow as pa
78
import pyarrow.parquet as pq
89
import torch
910
import tqdm
@@ -70,10 +71,12 @@ def __init__(self,
7071
drop_last: bool = True,
7172
text_padding_length: int = 512,
7273
seed: int = 42,
73-
read_batch_size: int = 32):
74+
read_batch_size: int = 32,
75+
parquet_schema: pa.Schema = None):
7476
super().__init__()
7577
self.path = str(path)
7678
self.batch_size = batch_size
79+
self.parquet_schema = parquet_schema
7780
self.cfg_rate = cfg_rate
7881
self.text_padding_length = text_padding_length
7982
self.seed = seed

fastvideo/v1/dataset/parquet_dataset_map_style.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,6 @@ def __init__(
201201
self.path = path
202202
self.cfg_rate = cfg_rate
203203
self.parquet_schema = parquet_schema
204-
if cfg_rate > 0.0:
205-
raise ValueError(
206-
"cfg_rate > 0.0 is not supported for now because it will trigger bug when num_data_workers > 0"
207-
)
208204
logger.info("Initializing LatentsParquetMapStyleDataset with path: %s",
209205
path)
210206
self.parquet_files, self.lengths = get_parquet_files_and_length(path)

fastvideo/v1/dataset/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import random
2-
from typing import Any, Dict, List
2+
from typing import Any, Dict, List, cast
33

44
import numpy as np
55
import torch
@@ -108,7 +108,7 @@ def collate_rows_from_parquet_schema(rows,
108108
Dict containing batched tensors and metadata
109109
"""
110110
if not rows:
111-
return {}
111+
return cast(Dict[str, Any], {})
112112

113113
# Initialize containers for different data types
114114
batch_data: Dict[str, Any] = {}

fastvideo/v1/pipelines/pipeline_batch_info.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class ForwardBatch:
3939
image_path: Optional[str] = None
4040
image_embeds: List[torch.Tensor] = field(default_factory=list)
4141
pil_image: Optional[PIL.Image.Image] = None
42+
preprocessed_image: Optional[torch.Tensor] = None
4243

4344
# Text inputs
4445
prompt: Optional[Union[str, List[str]]] = None
@@ -150,6 +151,10 @@ class TrainingBatch:
150151
latents: Optional[torch.Tensor] = None
151152
encoder_hidden_states: Optional[torch.Tensor] = None
152153
encoder_attention_mask: Optional[torch.Tensor] = None
154+
# i2v
155+
preprocessed_image: Optional[torch.Tensor] = None
156+
image_embeds: Optional[torch.Tensor] = None
157+
image_latents: Optional[torch.Tensor] = None
153158
infos: Optional[List[Dict[str, Any]]] = None
154159

155160
# Transformer inputs

0 commit comments

Comments
 (0)