@@ -1,15 +1,9 @@
 import argparse
-from contextlib import nullcontext
 
 import safetensors.torch
 import torch
-from accelerate import init_empty_weights
 from huggingface_hub import hf_hub_download
 
-from diffusers.utils.import_utils import is_accelerate_available
-
-
-CTX = init_empty_weights if is_accelerate_available else nullcontext
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
@@ -22,27 +16,13 @@
 dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
 
 
-# Adapted from from the original BFL codebase.
-def optionally_expand_state_dict(name: str, param: torch.Tensor, state_dict: dict) -> dict:
-    if name in state_dict:
-        print(f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}.")
-        # expand with zeros:
-        expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device)
-        # popular with pre-trained param for the first half. Remaining half stays with zeros.
-        slices = tuple(slice(0, dim) for dim in state_dict[name].shape)
-        expanded_state_dict_weight[slices] = state_dict[name]
-        state_dict[name] = expanded_state_dict_weight
-
-    return state_dict
-
-
 def load_original_checkpoint(args):
     if args.original_state_dict_repo_id is not None:
         ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
     elif args.checkpoint_path is not None:
         ckpt_path = args.checkpoint_path
     else:
-        raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
+        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
 
     original_state_dict = safetensors.torch.load_file(ckpt_path)
     return original_state_dict
@@ -60,7 +40,7 @@ def convert_flux_control_lora_checkpoint_to_diffusers(
     original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
 ):
     converted_state_dict = {}
-    original_state_dict_keys = original_state_dict.keys()
+    original_state_dict_keys = list(original_state_dict.keys())
 
     for lora_key in ["lora_A", "lora_B"]:
         ## time_text_embed.timestep_embedder <- time_in
@@ -346,7 +326,8 @@ def convert_flux_control_lora_checkpoint_to_diffusers(
             original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
         )
 
-    print("Remaining:", original_state_dict.keys())
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
 
     for key in list(converted_state_dict.keys()):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
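For context, a minimal sketch (not part of this diff) of how the two functions touched above might be wired together; it only uses names visible in this diff, and the num_layers/num_single_layers/inner_dim values are assumed Flux-dev defaults rather than values stated here:

# Hypothetical driver, assuming the argparse flags shown above.
args = parser.parse_args()
original_ckpt = load_original_checkpoint(args)

# 19 double blocks, 38 single blocks, and inner_dim=3072 are assumptions for illustration.
converted = convert_flux_control_lora_checkpoint_to_diffusers(
    original_ckpt, num_layers=19, num_single_layers=38, inner_dim=3072
)
# Every converted key is expected to carry the "transformer." prefix added in the last loop above.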