@@ -35,20 +35,26 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
3535 sgm_patterns = ["input_blocks" , "middle_block" , "output_blocks" ]
3636 not_sgm_patterns = ["down_blocks" , "mid_block" , "up_blocks" ]
3737
38- # Purge out unnecessary blocks.
39- for block in not_sgm_patterns :
40- for k in all_keys :
41- if block in k :
42- state_dict .pop (k )
43-
44- revised_all_keys = []
38+ # check if state_dict contains both patterns
39+ contains_sgm_patterns = False
40+ contains_not_sgm_patterns = False
4541 for key in all_keys :
46- if not any (pattern in key for pattern in not_sgm_patterns ):
47- revised_all_keys .append (key )
42+ if any (p in key for p in sgm_patterns ):
43+ contains_sgm_patterns = True
44+ elif any (p in key for p in not_sgm_patterns ):
45+ contains_not_sgm_patterns = True
46+
47+ # if state_dict contains both patterns, remove sgm
48+ # we can then return state_dict immediately
49+ if contains_sgm_patterns and contains_not_sgm_patterns :
50+ for key in all_keys :
51+ if any (p in key for p in sgm_patterns ):
52+ state_dict .pop (key )
53+ return state_dict
4854
4955 # 2. check if needs remapping, if not return original dict
5056 is_in_sgm_format = False
51- for key in revised_all_keys :
57+ for key in all_keys :
5258 if any (p in key for p in sgm_patterns ):
5359 is_in_sgm_format = True
5460 break
@@ -63,7 +69,7 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
6369 # Retrieves # of down, mid and up blocks
6470 input_block_ids , middle_block_ids , output_block_ids = set (), set (), set ()
6571
66- for layer in revised_all_keys :
72+ for layer in all_keys :
6773 if "text" in layer :
6874 new_state_dict [layer ] = state_dict .pop (layer )
6975 else :
0 commit comments