@@ -8,7 +8,6 @@
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
-import torch.distributed.checkpoint.stateful
 from einops import rearrange
 from safetensors.torch import save_file
 
@@ -154,13 +153,20 @@ def save_checkpoint(transformer,
 
     if rank == 0:
         # Save model weights (consolidated)
-        weight_path = os.path.join(save_dir,
+        transformer_save_dir = os.path.join(save_dir, "transformer")
+        os.makedirs(transformer_save_dir, exist_ok=True)
+        weight_path = os.path.join(transformer_save_dir,
                                    "diffusion_pytorch_model.safetensors")
         logger.info("rank: %s, saving consolidated checkpoint to %s",
                     rank,
                     weight_path,
                     local_main_process_only=False)
-        save_file(cpu_state, weight_path)
+
+        # Convert training format to diffusers format and save
+        diffusers_state_dict = convert_training_to_diffusers_format(
+            cpu_state, transformer)
+        save_file(diffusers_state_dict, weight_path)
+
         logger.info("rank: %s, consolidated checkpoint saved to %s",
                     rank,
                     weight_path,
@@ -170,7 +176,7 @@ def save_checkpoint(transformer,
         config_dict = transformer.hf_config
         if "dtype" in config_dict:
             del config_dict["dtype"]  # TODO
-        config_path = os.path.join(save_dir, "config.json")
+        config_path = os.path.join(transformer_save_dir, "config.json")
         # save dict as json
         with open(config_path, "w") as f:
             json.dump(config_dict, f, indent=4)
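# A sketch of the checkpoint layout the two hunks above now produce (inferred
# from the paths being joined; the loading call is an assumption about
# downstream use, not part of this change):
#
#   <save_dir>/
#       transformer/
#           config.json
#           diffusion_pytorch_model.safetensors
#
# This mirrors the diffusers convention, so the consolidated weights should be
# loadable via a model class's from_pretrained(save_dir, subfolder="transformer").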
@@ -479,3 +485,68 @@ def _has_foreach_support(tensors: List[torch.Tensor],
                          device: torch.device) -> bool:
     return _device_has_foreach_support(device) and all(
         t is None or type(t) in [torch.Tensor] for t in tensors)
+
+
+def convert_training_to_diffusers_format(state_dict: Dict[str, Any],
+                                         transformer) -> Dict[str, Any]:
+    """
+    Convert a training-format state dict to diffusers format using reverse_param_names_mapping.
+
+    Args:
+        state_dict: State dict in training format.
+        transformer: Transformer model object with _reverse_param_names_mapping.
+
+    Returns:
+        State dict in diffusers format.
+    """
+    new_state_dict = {}
+
+    # Get the reverse mapping from the transformer
+    reverse_param_names_mapping = transformer._reverse_param_names_mapping
+    assert reverse_param_names_mapping != {}, "reverse_param_names_mapping is empty"
+
+    # Group parameters that need to be split (merged parameters)
+    merge_groups: Dict[str, List[Tuple[str, int, int]]] = {}
+
+    # First pass: collect all merge groups
+    for training_key, (
+            diffusers_key, merge_index,
+            num_params_to_merge) in reverse_param_names_mapping.items():
+        if merge_index is not None:
+            # This is a merged parameter that needs to be split
+            if training_key not in merge_groups:
+                merge_groups[training_key] = []
+            merge_groups[training_key].append(
+                (diffusers_key, merge_index, num_params_to_merge))
+
+    # Second pass: handle merged parameters by splitting them
+    used_keys = set()
+    for training_key, splits in merge_groups.items():
+        if training_key in state_dict:
+            v = state_dict[training_key]
+            # Sort by merge_index to ensure correct order
+            splits.sort(key=lambda x: x[1])
+            total = splits[0][2]
+            split_size = v.shape[0] // total  # assumes equal-sized chunks along dim 0
+            split_tensors = torch.split(v, split_size, dim=0)
+
+            for diffusers_key, split_index, _ in splits:
+                new_state_dict[diffusers_key] = split_tensors[split_index]
+            used_keys.add(training_key)
+
+    # Third pass: handle regular parameters (direct mappings)
+    for training_key, v in state_dict.items():
+        if training_key in used_keys:
+            continue
+
+        if training_key in reverse_param_names_mapping:
+            diffusers_key, merge_index, _ = reverse_param_names_mapping[
+                training_key]
+            if merge_index is None:
+                # Direct mapping
+                new_state_dict[diffusers_key] = v
+        else:
+            # No mapping found, keep as is
+            new_state_dict[training_key] = v
+
+    return new_state_dict
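# A minimal usage sketch for convert_training_to_diffusers_format (hypothetical
# keys; assumes _reverse_param_names_mapping maps each training-format key to a
# (diffusers_key, merge_index, num_params_to_merge) tuple, with merge_index
# None for plain renames, as the function above expects):
import types

import torch

_toy_transformer = types.SimpleNamespace(
    _reverse_param_names_mapping={
        # Direct rename: merge_index is None, so no splitting happens
        "blocks.0.ffn.weight": ("transformer_blocks.0.ff.weight", None, 1),
    })
_toy_state = {"blocks.0.ffn.weight": torch.zeros(8, 4)}

_converted = convert_training_to_diffusers_format(_toy_state, _toy_transformer)
assert set(_converted) == {"transformer_blocks.0.ff.weight"}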