
Commit 16778b1

rename
1 parent 9039db4 commit 16778b1

2 files changed: 28 additions & 35 deletions

scripts/convert_hunyuan_video_to_diffusers.py

Lines changed: 24 additions & 1 deletion
@@ -3,8 +3,9 @@
 
 import torch
 from accelerate import init_empty_weights
+from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer
 
-from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
+from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline
 
 
 def remap_norm_scale_shift_(key, state_dict):
@@ -76,6 +77,8 @@ def remap_single_transformer_blocks_(key, state_dict):
     # "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
     # "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
     # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+    "txt_in.t_embedder": "txt_in.time_embed",
+    "txt_in.c_embedder": "txt_in.context_embed",
     "double_blocks": "transformer_blocks",
     "individual_token_refiner.blocks": "token_refiner.refiner_blocks",
     "img_attn_q_norm": "attn.norm_q",
@@ -179,6 +182,8 @@ def get_args():
         "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
     )
     parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
+    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
+    parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint")
     parser.add_argument("--save_pipeline", action="store_true")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
@@ -200,6 +205,8 @@ def get_args():
 
     if args.save_pipeline:
         assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
+        assert args.text_encoder_path is not None
+        assert args.text_encoder_2_path is not None
 
     if args.transformer_ckpt_path is not None:
         transformer = convert_transformer(args.transformer_ckpt_path)
@@ -211,3 +218,19 @@ def get_args():
         vae = convert_vae(args.vae_ckpt_path)
         if not args.save_pipeline:
             vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+    if args.save_pipeline:
+        text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
+        tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path, padding_side="right")
+        text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
+        tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
+
+        pipe = HunyuanVideoPipeline(
+            transformer=transformer,
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            text_encoder_2=text_encoder_2,
+            tokenizer_2=tokenizer_2,
+        )
+        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
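
With --save_pipeline plus the new --text_encoder_path and --text_encoder_2_path arguments, the script now assembles and saves a complete HunyuanVideoPipeline rather than only the transformer and VAE. A minimal sketch of loading the converted pipeline back from the output directory (the path is illustrative and assumes a CUDA device is available):

# Sketch only: load the pipeline directory written by --save_pipeline.
import torch
from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained(
    "path/to/output",  # hypothetical --output_path used during conversion
    torch_dtype=torch.float16,
)
pipe.to("cuda")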

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 4 additions & 34 deletions
@@ -168,31 +168,6 @@ def forward(self, caption):
         return hidden_states
 
 
-def timestep_embedding(t, dim, max_period=10000):
-    """
-    Create sinusoidal timestep embeddings.
-
-    Args:
-        t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
-        dim (int): the dimension of the output.
-        max_period (int): controls the minimum frequency of the embeddings.
-
-    Returns:
-        embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
-
-    .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
-    """
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        device=t.device
-    )
-    args = t[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    return embedding
-
-
 class TimestepEmbedder(nn.Module):
     """
     Embeds scalar timesteps into vector representations.
@@ -219,7 +194,6 @@ def __init__(
         )
 
     def forward(self, t):
-        # t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
         t_freq = get_timestep_embedding(t, self.frequency_embedding_size, flip_sin_to_cos=True, max_period=self.max_period, downscale_freq_shift=0).type(self.mlp[0].weight.dtype)
         t_emb = self.mlp(t_freq)
         return t_emb
@@ -340,10 +314,8 @@ def __init__(
         hidden_size = num_attention_heads * attention_head_dim
 
         self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True)
-        # self.time_embed = TimestepEmbedder(hidden_size, nn.SiLU)
-        # self.context_embed = TextProjection(in_channels, hidden_size, nn.SiLU)
-        self.t_embedder = TimestepEmbedder(hidden_size, nn.SiLU)
-        self.c_embedder = TextProjection(in_channels, hidden_size, nn.SiLU)
+        self.time_embed = TimestepEmbedder(hidden_size, nn.SiLU)
+        self.context_embed = TextProjection(in_channels, hidden_size, nn.SiLU)
 
         self.token_refiner = IndividualTokenRefiner(
             num_attention_heads=num_attention_heads,
@@ -361,8 +333,7 @@ def forward(
         attention_mask: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         original_dtype = hidden_states.dtype
-        # temb = self.time_embed(timestep)
-        temb = self.t_embedder(timestep)
+        temb = self.time_embed(timestep)
 
         if attention_mask is None:
             pooled_projections = hidden_states.mean(dim=1)
@@ -371,8 +342,7 @@ def forward(
             pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
             pooled_projections = pooled_projections.to(original_dtype)
 
-        # pooled_projections = self.context_embed(pooled_projections)
-        pooled_projections = self.c_embedder(pooled_projections)
+        pooled_projections = self.context_embed(pooled_projections)
         emb = temb + pooled_projections
 
         hidden_states = self.input_embedder(hidden_states)
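
The local timestep_embedding helper removed above is replaced by diffusers' get_timestep_embedding; with flip_sin_to_cos=True and downscale_freq_shift=0 it produces the same cos-then-sin sinusoidal layout the deleted function computed. A small parity check, restating the removed helper inline for comparison (illustrative only, not part of the commit):

# Sketch only: compare the removed helper with diffusers' get_timestep_embedding.
import math
import torch
from diffusers.models.embeddings import get_timestep_embedding

t = torch.tensor([0.0, 1.0, 250.0, 999.0])
dim, max_period = 256, 10000

# Removed helper, restated: exp-spaced frequencies, cos-then-sin concatenation.
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
args = t[:, None] * freqs[None]
reference = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

# Same settings the new TimestepEmbedder.forward passes in.
candidate = get_timestep_embedding(
    t, dim, flip_sin_to_cos=True, max_period=max_period, downscale_freq_shift=0
)
assert torch.allclose(reference, candidate, atol=1e-6)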
