@@ -33,6 +33,24 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
3333 # 1. get all state_dict_keys
3434 all_keys = list (state_dict .keys ())
3535 sgm_patterns = ["input_blocks" , "middle_block" , "output_blocks" ]
36+ not_sgm_patterns = ["down_blocks" , "mid_block" , "up_blocks" ]
37+
38+ # check if state_dict contains both patterns
39+ contains_sgm_patterns = False
40+ contains_not_sgm_patterns = False
41+ for key in all_keys :
42+ if any (p in key for p in sgm_patterns ):
43+ contains_sgm_patterns = True
44+ elif any (p in key for p in not_sgm_patterns ):
45+ contains_not_sgm_patterns = True
46+
47+ # if state_dict contains both patterns, remove sgm
48+ # we can then return state_dict immediately
49+ if contains_sgm_patterns and contains_not_sgm_patterns :
50+ for key in all_keys :
51+ if any (p in key for p in sgm_patterns ):
52+ state_dict .pop (key )
53+ return state_dict
3654
3755 # 2. check if needs remapping, if not return original dict
3856 is_in_sgm_format = False
@@ -126,7 +144,7 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
126144 )
127145 new_state_dict [new_key ] = state_dict .pop (key )
128146
129- if len ( state_dict ) > 0 :
147+ if state_dict :
130148 raise ValueError ("At this point all state dict entries have to be converted." )
131149
132150 return new_state_dict
0 commit comments