@@ -33,6 +33,24 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
3333    # 1. get all state_dict_keys 
3434    all_keys  =  list (state_dict .keys ())
3535    sgm_patterns  =  ["input_blocks" , "middle_block" , "output_blocks" ]
36+     not_sgm_patterns  =  ["down_blocks" , "mid_block" , "up_blocks" ]
37+ 
38+     # check if state_dict contains both patterns 
39+     contains_sgm_patterns  =  False 
40+     contains_not_sgm_patterns  =  False 
41+     for  key  in  all_keys :
42+         if  any (p  in  key  for  p  in  sgm_patterns ):
43+             contains_sgm_patterns  =  True 
44+         elif  any (p  in  key  for  p  in  not_sgm_patterns ):
45+             contains_not_sgm_patterns  =  True 
46+ 
47+     # if state_dict contains both patterns, remove sgm 
48+     # we can then return state_dict immediately 
49+     if  contains_sgm_patterns  and  contains_not_sgm_patterns :
50+         for  key  in  all_keys :
51+             if  any (p  in  key  for  p  in  sgm_patterns ):
52+                 state_dict .pop (key )
53+         return  state_dict 
3654
3755    # 2. check if needs remapping, if not return original dict 
3856    is_in_sgm_format  =  False 
@@ -126,7 +144,7 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
126144            )
127145            new_state_dict [new_key ] =  state_dict .pop (key )
128146
129-     if  len ( state_dict )  >   0 :
147+     if  state_dict :
130148        raise  ValueError ("At this point all state dict entries have to be converted." )
131149
132150    return  new_state_dict 
0 commit comments