
Commit 9d43e8a

remove hunyuan_common.py
1 parent d894b05 commit 9d43e8a

4 files changed: +204 -121 lines changed


examples/formats/hunyuan_video/convert_to_original_format.py

Lines changed: 11 additions & 0 deletions
@@ -108,6 +108,17 @@ def convert_lora_sd(diffusers_lora_sd):
         elif "proj_out" in key:
             new_key = key.replace("proj_out", "linear2").replace(single_block_pattern, prefix + "single_blocks")
             converted_lora_sd[new_key] = diffusers_lora_sd[key]
+        elif "x_embedder" in key:
+            new_key = key.replace("x_embedder", "img_in").replace(double_block_pattern, prefix + "")
+            if "lora_A" in key:
+                embed = diffusers_lora_sd[key]
+                sizes = embed.size()
+                x_reshaped = embed.view(sizes[0], 16, sizes[2], sizes[3], sizes[4], 2)
+                x_meaned = x_reshaped.mean(dim=2)
+                converted_lora_sd[new_key] = x_meaned
+            else:
+                converted_lora_sd[new_key] = diffusers_lora_sd[key]
+            print(new_key, diffusers_lora_sd[key].size())
 
         else:
             print(f"unknown or not implemented: {key}")

finetrainers/models/hunyuan_video/base_specification.py

Lines changed: 94 additions & 5 deletions
@@ -13,8 +13,6 @@
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel
 
-from finetrainers.models.hunyuan_video import hunyuan_common
-
 import finetrainers.functional as FF
 from finetrainers.data import VideoArtifact
 from finetrainers.logging import get_logger
@@ -132,11 +130,102 @@ def __init__(
     def _resolution_dim_keys(self):
         return {"latents": (2, 3, 4)}
 
-    load_condition_models = hunyuan_common.load_condition_models
+    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+        if self.tokenizer_id is not None:
+            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
+            )
+
+        if self.tokenizer_2_id is not None:
+            tokenizer_2 = AutoTokenizer.from_pretrained(self.tokenizer_2_id, **common_kwargs)
+        else:
+            tokenizer_2 = CLIPTokenizer.from_pretrained(
+                self.pretrained_model_name_or_path, subfolder="tokenizer_2", **common_kwargs
+            )
+
+        if self.text_encoder_id is not None:
+            text_encoder = LlamaModel.from_pretrained(
+                self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
+            )
+        else:
+            text_encoder = LlamaModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder",
+                torch_dtype=self.text_encoder_dtype,
+                **common_kwargs,
+            )
+
+        if self.text_encoder_2_id is not None:
+            text_encoder_2 = CLIPTextModel.from_pretrained(
+                self.text_encoder_2_id, torch_dtype=self.text_encoder_2_dtype, **common_kwargs
+            )
+        else:
+            text_encoder_2 = CLIPTextModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder_2",
+                torch_dtype=self.text_encoder_2_dtype,
+                **common_kwargs,
+            )
+
+        return {
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+        }
+
+    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+        if self.vae_id is not None:
+            vae = AutoencoderKLHunyuanVideo.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
+        else:
+            vae = AutoencoderKLHunyuanVideo.from_pretrained(
+                self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
+            )
+
+        return {"vae": vae}
 
-    load_latent_models = hunyuan_common.load_latent_models
+    def load_pipeline(
+        self,
+        tokenizer: Optional[AutoTokenizer] = None,
+        tokenizer_2: Optional[CLIPTokenizer] = None,
+        text_encoder: Optional[LlamaModel] = None,
+        text_encoder_2: Optional[CLIPTextModel] = None,
+        transformer: Optional[torch.nn.Module] = None,
+        vae: Optional[AutoencoderKLHunyuanVideo] = None,
+        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+        enable_slicing: bool = False,
+        enable_tiling: bool = False,
+        enable_model_cpu_offload: bool = False,
+        training: bool = False,
+        **kwargs,
+    ) -> HunyuanVideoPipeline:
+        components = {
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "transformer": transformer,
+            "vae": vae,
+            "scheduler": scheduler,
+        }
+        components = get_non_null_items(components)
+
+        pipe = HunyuanVideoPipeline.from_pretrained(
+            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+        )
+        pipe.text_encoder.to(self.text_encoder_dtype)
+        pipe.text_encoder_2.to(self.text_encoder_2_dtype)
+        pipe.vae.to(self.vae_dtype)
 
-    load_pipeline = hunyuan_common.load_pipeline
+        _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
+        if not training:
+            pipe.transformer.to(self.transformer_dtype)
 
     def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
         common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

finetrainers/models/hunyuan_video/control_specification.py

Lines changed: 99 additions & 5 deletions
@@ -4,6 +4,7 @@
 
 import safetensors
 import torch
+from torch.nn import Module
 from accelerate import init_empty_weights
 from diffusers import (
     FlowMatchEulerDiscreteScheduler,
@@ -14,21 +15,23 @@
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel
 from finetrainers.data._artifact import VideoArtifact
-from finetrainers.models.hunyuan_video import hunyuan_common
 from finetrainers.models.utils import _expand_conv3d_with_zeroed_weights
+from finetrainers.trainer.control_trainer.config import FrameConditioningType
 from finetrainers.utils.serialization import safetensors_torch_save_function
 
 from ... import data
 from ... import functional as FF
 from ...logging import get_logger
 from ...patches.dependencies.diffusers.control import control_channel_concat
 from ...processors import ProcessorMixin
-from ...typing import ArtifactType, FrameConditioningType, SchedulerType
+from ...typing import ArtifactType, SchedulerType
 from ...utils import get_non_null_items
 from ..modeling_utils import ControlModelSpecification
 from .base_specification import HunyuanLatentEncodeProcessor
 from ...processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin
 
+from ...utils import _enable_vae_memory_optimizations, get_non_null_items
+
 logger = get_logger()
 

@@ -88,11 +91,102 @@ def control_injection_layer_name(self) -> str:
     def _resolution_dim_keys(self):
         return {"latents": (2, 3, 4)}
 
-    load_condition_models = hunyuan_common.load_condition_models
+    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+        if self.tokenizer_id is not None:
+            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
+            )
+
+        if self.tokenizer_2_id is not None:
+            tokenizer_2 = AutoTokenizer.from_pretrained(self.tokenizer_2_id, **common_kwargs)
+        else:
+            tokenizer_2 = CLIPTokenizer.from_pretrained(
+                self.pretrained_model_name_or_path, subfolder="tokenizer_2", **common_kwargs
+            )
+
+        if self.text_encoder_id is not None:
+            text_encoder = LlamaModel.from_pretrained(
+                self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
+            )
+        else:
+            text_encoder = LlamaModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder",
+                torch_dtype=self.text_encoder_dtype,
+                **common_kwargs,
+            )
+
+        if self.text_encoder_2_id is not None:
+            text_encoder_2 = CLIPTextModel.from_pretrained(
+                self.text_encoder_2_id, torch_dtype=self.text_encoder_2_dtype, **common_kwargs
+            )
+        else:
+            text_encoder_2 = CLIPTextModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder_2",
+                torch_dtype=self.text_encoder_2_dtype,
+                **common_kwargs,
+            )
+
+        return {
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+        }
+
+    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+        if self.vae_id is not None:
+            vae = AutoencoderKLHunyuanVideo.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
+        else:
+            vae = AutoencoderKLHunyuanVideo.from_pretrained(
+                self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
+            )
+
+        return {"vae": vae}
 
-    load_latent_models = hunyuan_common.load_latent_models
+    def load_pipeline(
+        self,
+        tokenizer: Optional[AutoTokenizer] = None,
+        tokenizer_2: Optional[CLIPTokenizer] = None,
+        text_encoder: Optional[LlamaModel] = None,
+        text_encoder_2: Optional[CLIPTextModel] = None,
+        transformer: Optional[Module] = None,
+        vae: Optional[AutoencoderKLHunyuanVideo] = None,
+        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+        enable_slicing: bool = False,
+        enable_tiling: bool = False,
+        enable_model_cpu_offload: bool = False,
+        training: bool = False,
+        **kwargs,
+    ) -> HunyuanVideoPipeline:
+        components = {
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "transformer": transformer,
+            "vae": vae,
+            "scheduler": scheduler,
+        }
+        components = get_non_null_items(components)
+
+        pipe = HunyuanVideoPipeline.from_pretrained(
+            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+        )
+        pipe.text_encoder.to(self.text_encoder_dtype)
+        pipe.text_encoder_2.to(self.text_encoder_2_dtype)
+        pipe.vae.to(self.vae_dtype)
 
-    load_pipeline = hunyuan_common.load_pipeline
+        _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
+        if not training:
+            pipe.transformer.to(self.transformer_dtype)
 
     def load_diffusion_models(self, new_in_features: int) -> Dict[str, torch.nn.Module]:
         common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

finetrainers/models/hunyuan_video/hunyuan_common.py

Lines changed: 0 additions & 111 deletions
This file was deleted.
