Skip to content

Commit 6f2c09d

Browse files
committed
[WIP][Feat] distill run in single gpu
1 parent 30db272 commit 6f2c09d

File tree

8 files changed

+190
-930
lines changed

8 files changed

+190
-930
lines changed

fastvideo/distill/solver.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,27 @@
77
from diffusers.schedulers.scheduling_utils import SchedulerMixin
88
from diffusers.utils import BaseOutput, logging
99

10-
from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule
11-
1210
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
1311

1412

13+
# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
14+
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
15+
if linear_steps is None:
16+
linear_steps = num_steps // 2
17+
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
18+
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
19+
quadratic_steps = num_steps - linear_steps
20+
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
21+
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
22+
const = quadratic_coef * (linear_steps**2)
23+
quadratic_sigma_schedule = [
24+
quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
25+
]
26+
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
27+
sigma_schedule = [1.0 - x for x in sigma_schedule]
28+
return sigma_schedule
29+
30+
1531
@dataclass
1632
class PCMFMSchedulerOutput(BaseOutput):
1733
prev_sample: torch.FloatTensor

fastvideo/v1/dataset/parquet_datasets.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ def __init__(self,
119119
plan = json.load(f)
120120
self.neg_metadata = plan["negative_prompt"][0]
121121

122+
self.uncond_prompt_embed = torch.zeros(512, 4096).to(torch.float32)
123+
self.uncond_prompt_mask = torch.zeros(1, 512).bool()
124+
122125
def _load_and_cache_negative_prompt(self) -> None:
123126
"""Load and cache the negative prompt. Only rank 0 in each SP group should call this."""
124127
if not self.validation or self.neg_metadata is None:
@@ -188,6 +191,7 @@ def get_validation_negative_prompt(
188191
lat = lat[:, self.rank_in_sp_group, :, :, :]
189192
return lat, emb, mask, info
190193

194+
191195
def __len__(self):
192196
if self.local_indices is None:
193197
try:

fastvideo/v1/fastvideo_args.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,8 @@ class TrainingArgs(FastVideoArgs):
590590
pred_decay_type: str = ""
591591
hunyuan_teacher_disable_cfg: bool = False
592592

593+
use_lora: bool = False
594+
593595
# master_weight_type
594596
master_weight_type: str = ""
595597

fastvideo/v1/models/loader/component_loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ class TransformerLoader(ComponentLoader):
366366
def load(self, model_path: str, architecture: str,
367367
fastvideo_args: FastVideoArgs):
368368
"""Load the transformer based on the model path, architecture, and inference args."""
369+
print(f"Loading transformer from {model_path}")
369370
config = get_diffusers_config(model=model_path)
370371
hf_config = deepcopy(config)
371372
cls_name = config.pop("_class_name")

fastvideo/v1/pipelines/composed_pipeline_base.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self,
5656
use. The pipeline should be stateless and not hold any batch state.
5757
"""
5858

59-
if fastvideo_args.training_mode:
59+
if fastvideo_args.training_mode or fastvideo_args.distill_mode:
6060
assert isinstance(fastvideo_args, TrainingArgs)
6161
self.training_args = fastvideo_args
6262
assert self.training_args is not None
@@ -94,11 +94,12 @@ def __init__(self,
9494
self.initialize_validation_pipeline(self.training_args)
9595
self.initialize_training_pipeline(self.training_args)
9696

97+
# TODO(jinzhe): discuss this
9798
if fastvideo_args.distill_mode:
98-
self.initialize_distillation_pipeline(fastvideo_args)
99-
100-
if fastvideo_args.log_validation:
101-
self.initialize_validation_pipeline(fastvideo_args)
99+
assert self.training_args is not None
100+
if self.training_args.log_validation:
101+
self.initialize_validation_pipeline(self.training_args)
102+
self.initialize_distillation_pipeline(self.training_args)
102103

103104
self.initialize_pipeline(fastvideo_args)
104105

0 commit comments

Comments
 (0)