@@ -4,6 +4,7 @@

 import safetensors
 import torch
+from torch.nn import Module
 from accelerate import init_empty_weights
 from diffusers import (
     FlowMatchEulerDiscreteScheduler,
@@ -14,21 +15,23 @@
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel
 from finetrainers.data._artifact import VideoArtifact
-from finetrainers.models.hunyuan_video import hunyuan_common
 from finetrainers.models.utils import _expand_conv3d_with_zeroed_weights
+from finetrainers.trainer.control_trainer.config import FrameConditioningType
 from finetrainers.utils.serialization import safetensors_torch_save_function

 from ... import data
 from ... import functional as FF
 from ...logging import get_logger
 from ...patches.dependencies.diffusers.control import control_channel_concat
 from ...processors import ProcessorMixin
-from ...typing import ArtifactType, FrameConditioningType, SchedulerType
+from ...typing import ArtifactType, SchedulerType
 from ...utils import get_non_null_items
 from ..modeling_utils import ControlModelSpecification
 from .base_specification import HunyuanLatentEncodeProcessor
 from ...processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin

+from ...utils import _enable_vae_memory_optimizations, get_non_null_items
+
 logger = get_logger()


@@ -88,11 +91,103 @@ def control_injection_layer_name(self) -> str:
     def _resolution_dim_keys(self):
         return {"latents": (2, 3, 4)}

-    load_condition_models = hunyuan_common.load_condition_models
+    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+        if self.tokenizer_id is not None:
+            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
+            )
+
+        if self.tokenizer_2_id is not None:
+            tokenizer_2 = AutoTokenizer.from_pretrained(self.tokenizer_2_id, **common_kwargs)
+        else:
+            tokenizer_2 = CLIPTokenizer.from_pretrained(
+                self.pretrained_model_name_or_path, subfolder="tokenizer_2", **common_kwargs
+            )
+
+        if self.text_encoder_id is not None:
+            text_encoder = LlamaModel.from_pretrained(
+                self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
+            )
+        else:
+            text_encoder = LlamaModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder",
+                torch_dtype=self.text_encoder_dtype,
+                **common_kwargs,
+            )
+
+        if self.text_encoder_2_id is not None:
+            text_encoder_2 = CLIPTextModel.from_pretrained(
+                self.text_encoder_2_id, torch_dtype=self.text_encoder_2_dtype, **common_kwargs
+            )
+        else:
+            text_encoder_2 = CLIPTextModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder_2",
+                torch_dtype=self.text_encoder_2_dtype,
+                **common_kwargs,
+            )
+
+        return {
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+        }
+
+    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+        if self.vae_id is not None:
+            vae = AutoencoderKLHunyuanVideo.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
+        else:
+            vae = AutoencoderKLHunyuanVideo.from_pretrained(
+                self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
+            )
+
+        return {"vae": vae}

-    load_latent_models = hunyuan_common.load_latent_models
+    def load_pipeline(
+        self,
+        tokenizer: Optional[AutoTokenizer] = None,
+        tokenizer_2: Optional[CLIPTokenizer] = None,
+        text_encoder: Optional[LlamaModel] = None,
+        text_encoder_2: Optional[CLIPTextModel] = None,
+        transformer: Optional[Module] = None,
+        vae: Optional[AutoencoderKLHunyuanVideo] = None,
+        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+        enable_slicing: bool = False,
+        enable_tiling: bool = False,
+        enable_model_cpu_offload: bool = False,
+        training: bool = False,
+        **kwargs,
+    ) -> HunyuanVideoPipeline:
+        components = {
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "transformer": transformer,
+            "vae": vae,
+            "scheduler": scheduler,
+        }
+        components = get_non_null_items(components)
+
+        pipe = HunyuanVideoPipeline.from_pretrained(
+            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+        )
+        pipe.text_encoder.to(self.text_encoder_dtype)
+        pipe.text_encoder_2.to(self.text_encoder_2_dtype)
+        pipe.vae.to(self.vae_dtype)

-    load_pipeline = hunyuan_common.load_pipeline
+        _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
+        if not training:
+            pipe.transformer.to(self.transformer_dtype)
+        return pipe

     def load_diffusion_models(self, new_in_features: int) -> Dict[str, torch.nn.Module]:
         common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
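For reference, a minimal usage sketch of how the inlined loaders compose. This is not part of the commit: the spec class name, its import path, and the constructor argument and checkpoint id are assumptions; only the method names and keyword arguments shown in the diff above are taken from the source.

# Hypothetical sketch: class name, import path, constructor kwarg, and checkpoint id are assumed.
from finetrainers.models.hunyuan_video import HunyuanVideoControlModelSpecification

spec = HunyuanVideoControlModelSpecification(
    pretrained_model_name_or_path="hunyuanvideo-community/HunyuanVideo",  # assumed checkpoint id
)

# Methods below come from the diff: they return dicts keyed by component name.
condition_models = spec.load_condition_models()  # tokenizer, tokenizer_2, text_encoder, text_encoder_2
latent_models = spec.load_latent_models()        # vae

# load_pipeline fills any component left as None from the base checkpoint, applies
# _enable_vae_memory_optimizations, and casts the transformer dtype only when not training.
pipe = spec.load_pipeline(
    **condition_models,
    **latent_models,
    enable_slicing=True,
    enable_tiling=True,
    training=False,
)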