 
 import torch
 from accelerate import init_empty_weights
-from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer
+from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlavaForConditionalGeneration
 
 from diffusers import (
     AutoencoderKLHunyuanVideo,
@@ -134,6 +134,46 @@ def remap_single_transformer_blocks_(key, state_dict):
 VAE_SPECIAL_KEYS_REMAP = {}
 
 
+TRANSFORMER_CONFIGS = {
+    "HYVideo-T/2-cfgdistill": {
+        "in_channels": 16,
+        "out_channels": 16,
+        "num_attention_heads": 24,
+        "attention_head_dim": 128,
+        "num_layers": 20,
+        "num_single_layers": 40,
+        "num_refiner_layers": 2,
+        "mlp_ratio": 4.0,
+        "patch_size": 2,
+        "patch_size_t": 1,
+        "qk_norm": "rms_norm",
+        "guidance_embeds": True,
+        "text_embed_dim": 4096,
+        "pooled_projection_dim": 768,
+        "rope_theta": 256.0,
+        "rope_axes_dim": (16, 56, 56),
+    },
+    "HYVideo-T/2": {
+        "in_channels": 16 * 2 + 1,
+        "out_channels": 16,
+        "num_attention_heads": 24,
+        "attention_head_dim": 128,
+        "num_layers": 20,
+        "num_single_layers": 40,
+        "num_refiner_layers": 2,
+        "mlp_ratio": 4.0,
+        "patch_size": 2,
+        "patch_size_t": 1,
+        "qk_norm": "rms_norm",
+        "guidance_embeds": False,
+        "text_embed_dim": 4096,
+        "pooled_projection_dim": 768,
+        "rope_theta": 256.0,
+        "rope_axes_dim": (16, 56, 56),
+    },
+}
+
+
 def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
     state_dict[new_key] = state_dict.pop(old_key)
 
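For reference, a minimal sketch of how the new table is meant to be consumed (the channel arithmetic in the comment is an assumption about the non-cfgdistill variant, not something this diff states):

    # Hypothetical lookup, mirroring what convert_transformer does below.
    config = TRANSFORMER_CONFIGS["HYVideo-T/2"]
    # in_channels = 16 * 2 + 1 = 33 -- presumably 16 latent channels,
    # 16 image-conditioning latent channels, and 1 mask channel.
    print(config["in_channels"])  # 33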
@@ -149,11 +189,12 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     return state_dict
 
 
-def convert_transformer(ckpt_path: str):
+def convert_transformer(ckpt_path: str, transformer_type: str):
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
+    config = TRANSFORMER_CONFIGS[transformer_type]
 
     with init_empty_weights():
-        transformer = HunyuanVideoTransformer3DModel()
+        transformer = HunyuanVideoTransformer3DModel(**config)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -205,6 +246,10 @@ def get_args():
     parser.add_argument("--save_pipeline", action="store_true")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
+    parser.add_argument(
+        "--transformer_type", type=str, default="HYVideo-T/2-cfgdistill", choices=list(TRANSFORMER_CONFIGS.keys())
+    )
+    parser.add_argument("--flow_shift", type=float, default=7.0)
     return parser.parse_args()
 
 
@@ -228,7 +273,7 @@ def get_args():
         assert args.text_encoder_2_path is not None
 
     if args.transformer_ckpt_path is not None:
-        transformer = convert_transformer(args.transformer_ckpt_path)
+        transformer = convert_transformer(args.transformer_ckpt_path, args.transformer_type)
         transformer = transformer.to(dtype=dtype)
         if not args.save_pipeline:
             transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
@@ -239,11 +284,17 @@ def get_args():
             vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
     if args.save_pipeline:
-        text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
+        if args.transformer_type == "HYVideo-T/2-cfgdistill":
+            text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
+        else:
+            text_encoder = LlavaForConditionalGeneration.from_pretrained(
+                args.text_encoder_path, torch_dtype=torch.float16
+            )
+
         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
         text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
         tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
-        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+        scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
 
         pipe = HunyuanVideoPipeline(
             transformer=transformer,
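A minimal usage sketch, not part of the diff: converting a checkpoint with the new flags and loading the result. All paths are hypothetical, and the flag values simply echo the defaults added above.

    # Example invocation (hypothetical script name and paths):
    #   python convert_hunyuan_video_to_diffusers.py \
    #       --transformer_ckpt_path ./ckpts/transformer.pt \
    #       --transformer_type "HYVideo-T/2-cfgdistill" --flow_shift 7.0 \
    #       --save_pipeline --output_path ./hunyuan-video-diffusers
    import torch
    from diffusers import HunyuanVideoPipeline

    # Load a pipeline previously saved with --save_pipeline; the directory is
    # whatever --output_path was passed to the conversion script.
    pipe = HunyuanVideoPipeline.from_pretrained("./hunyuan-video-diffusers", torch_dtype=torch.bfloat16)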