Commit ab2476b

update

1 parent b150276 commit ab2476b

2 files changed: 67 additions, 8 deletions

scripts/convert_hunyuan_video_to_diffusers.py

Lines changed: 57 additions & 6 deletions

@@ -3,7 +3,7 @@
 
 import torch
 from accelerate import init_empty_weights
-from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer
+from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlavaForConditionalGeneration
 
 from diffusers import (
     AutoencoderKLHunyuanVideo,
@@ -134,6 +134,46 @@ def remap_single_transformer_blocks_(key, state_dict):
 VAE_SPECIAL_KEYS_REMAP = {}
 
 
+TRANSFORMER_CONFIGS = {
+    "HYVideo-T/2-cfgdistill": {
+        "in_channels": 16,
+        "out_channels": 16,
+        "num_attention_heads": 24,
+        "attention_head_dim": 128,
+        "num_layers": 20,
+        "num_single_layers": 40,
+        "num_refiner_layers": 2,
+        "mlp_ratio": 4.0,
+        "patch_size": 2,
+        "patch_size_t": 1,
+        "qk_norm": "rms_norm",
+        "guidance_embeds": True,
+        "text_embed_dim": 4096,
+        "pooled_projection_dim": 768,
+        "rope_theta": 256.0,
+        "rope_axes_dim": (16, 56, 56),
+    },
+    "HYVideo-T/2": {
+        "in_channels": 16 * 2 + 1,
+        "out_channels": 16,
+        "num_attention_heads": 24,
+        "attention_head_dim": 128,
+        "num_layers": 20,
+        "num_single_layers": 40,
+        "num_refiner_layers": 2,
+        "mlp_ratio": 4.0,
+        "patch_size": 2,
+        "patch_size_t": 1,
+        "qk_norm": "rms_norm",
+        "guidance_embeds": False,
+        "text_embed_dim": 4096,
+        "pooled_projection_dim": 768,
+        "rope_theta": 256.0,
+        "rope_axes_dim": (16, 56, 56),
+    },
+}
+
+
 def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
     state_dict[new_key] = state_dict.pop(old_key)
 
@@ -149,11 +189,12 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     return state_dict
 
 
-def convert_transformer(ckpt_path: str):
+def convert_transformer(ckpt_path: str, transformer_type: str):
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
+    config = TRANSFORMER_CONFIGS[transformer_type]
 
     with init_empty_weights():
-        transformer = HunyuanVideoTransformer3DModel()
+        transformer = HunyuanVideoTransformer3DModel(**config)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -205,6 +246,10 @@ def get_args():
     parser.add_argument("--save_pipeline", action="store_true")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
+    parser.add_argument(
+        "--transformer_type", type=str, default="HYVideo-T/2-cfgdistill", choices=list(TRANSFORMER_CONFIGS.keys())
+    )
+    parser.add_argument("--flow_shift", type=float, default=7.0)
     return parser.parse_args()
 
 
@@ -228,7 +273,7 @@ def get_args():
     assert args.text_encoder_2_path is not None
 
     if args.transformer_ckpt_path is not None:
-        transformer = convert_transformer(args.transformer_ckpt_path)
+        transformer = convert_transformer(args.transformer_ckpt_path, args.transformer_type)
         transformer = transformer.to(dtype=dtype)
         if not args.save_pipeline:
             transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
@@ -239,11 +284,17 @@ def get_args():
             vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
     if args.save_pipeline:
-        text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
+        if args.transformer_type == "HYVideo-T/2-cfgdistill":
+            text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
+        else:
+            text_encoder = LlavaForConditionalGeneration.from_pretrained(
+                args.text_encoder_path, torch_dtype=torch.float16
+            )
+
         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
         text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
         tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
-        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+        scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
 
         pipe = HunyuanVideoPipeline(
             transformer=transformer,
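
Note (not part of the commit): a minimal sketch of how the new per-variant configs are meant to be consumed, mirroring what convert_transformer() now does. It assumes the TRANSFORMER_CONFIGS dict added above is in scope; the chosen variant is only illustrative.

# Hedged sketch: pick a variant and build the model the way convert_transformer() does.
from accelerate import init_empty_weights
from diffusers import HunyuanVideoTransformer3DModel

config = TRANSFORMER_CONFIGS["HYVideo-T/2"]  # assumed in scope from the script above; the guidance_embeds=False variant
with init_empty_weights():
    # Parameters live on the meta device here; the script then loads the remapped checkpoint weights into the model.
    transformer = HunyuanVideoTransformer3DModel(**config)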

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 10 additions & 2 deletions

@@ -581,7 +581,11 @@ def __init__(
         self.context_embedder = HunyuanVideoTokenRefiner(
             text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
         )
-        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
+
+        if guidance_embeds:
+            self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
+        else:
+            self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
 
         # 2. RoPE
         self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
@@ -708,7 +712,11 @@ def forward(
         image_rotary_emb = self.rope(hidden_states)
 
         # 2. Conditional embeddings
-        temb = self.time_text_embed(timestep, guidance, pooled_projections)
+        if self.config.guidance_embeds:
+            temb = self.time_text_embed(timestep, guidance, pooled_projections)
+        else:
+            temb = self.time_text_embed(timestep, pooled_projections)
+
         hidden_states = self.x_embedder(hidden_states)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
 
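Note (not part of the commit): a hedged sketch of the two embedding paths the guidance_embeds switch selects between. The dimensions follow the configs added in the conversion script; the tensor values are illustrative.

import torch
from diffusers.models.embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
)

inner_dim = 24 * 128               # num_attention_heads * attention_head_dim
pooled_projection_dim = 768

timestep = torch.tensor([999.0])
guidance = torch.tensor([6.0])     # embedded guidance scale consumed only by the cfg-distilled variant
pooled = torch.randn(1, pooled_projection_dim)

# guidance_embeds=True ("HYVideo-T/2-cfgdistill"): timestep, guidance, and pooled text are all embedded into temb.
temb_distilled = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)(
    timestep, guidance, pooled
)
# guidance_embeds=False ("HYVideo-T/2"): the guidance scalar is not consumed at all.
temb_plain = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)(timestep, pooled)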