@@ -795,8 +795,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         quantization_config = kwargs.pop("quantization_config", None)
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
-        state_dict = kwargs.pop("state_dict", None)
-        config = kwargs.pop("config", None)

         allow_pickle = False
         if use_safetensors is None:
@@ -867,39 +865,35 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
                 raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")

-        if (not config and state_dict) or (config and not state_dict):
-            raise ValueError("You need to pass both the config and the state dict to initalize the model.")
-
         user_agent = {
             "diffusers": __version__,
             "file_type": "model",
             "framework": "pytorch",
         }
         unused_kwargs = {}

-        if config is None:
-            # Load config if we don't provide a configuration
-            config_path = pretrained_model_name_or_path
+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path

-            # TODO: We need to let the user pass a config in from_pretrained
-            # load config
-            config, unused_kwargs, commit_hash = cls.load_config(
-                config_path,
-                cache_dir=cache_dir,
-                return_unused_kwargs=True,
-                return_commit_hash=True,
-                force_download=force_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                token=token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-                dduf_entries=dduf_entries,
-                **kwargs,
-            )
-            # no in-place modification of the original config.
-            config = copy.deepcopy(config)
+        # TODO: We need to let the user pass a config in from_pretrained
+        # load config
+        config, unused_kwargs, commit_hash = cls.load_config(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=force_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            dduf_entries=dduf_entries,
+            **kwargs,
+        )
+        # no in-place modification of the original config.
+        config = copy.deepcopy(config)

         # determine initial quantization config.
         #######################################
@@ -951,103 +945,79 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         is_sharded = False
         resolved_archive_file = None
-        if state_dict is None:
-            # Determine if we're loading from a directory of sharded checkpoints.
-            sharded_metadata = None
-            index_file = None
-            is_local = os.path.isdir(pretrained_model_name_or_path)
-            index_file_kwargs = {
-                "is_local": is_local,
-                "pretrained_model_name_or_path": pretrained_model_name_or_path,
-                "subfolder": subfolder or "",
-                "use_safetensors": use_safetensors,
-                "cache_dir": cache_dir,
-                "variant": variant,
-                "force_download": force_download,
-                "proxies": proxies,
-                "local_files_only": local_files_only,
-                "token": token,
-                "revision": revision,
-                "user_agent": user_agent,
-                "commit_hash": commit_hash,
-                "dduf_entries": dduf_entries,
-            }
-            index_file = _fetch_index_file(**index_file_kwargs)
-            # In case the index file was not found we still have to consider the legacy format.
-            # this becomes applicable when the variant is not None.
-            if variant is not None and (index_file is None or not os.path.exists(index_file)):
-                index_file = _fetch_index_file_legacy(**index_file_kwargs)
-            if index_file is not None and (dduf_entries or index_file.is_file()):
-                is_sharded = True
-
-            if is_sharded and from_flax:
-                raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
-
-            # load model
-            if from_flax:
-                resolved_archive_file = _get_model_file(
+
+        # Determine if we're loading from a directory of sharded checkpoints.
+        sharded_metadata = None
+        index_file = None
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        index_file_kwargs = {
+            "is_local": is_local,
+            "pretrained_model_name_or_path": pretrained_model_name_or_path,
+            "subfolder": subfolder or "",
+            "use_safetensors": use_safetensors,
+            "cache_dir": cache_dir,
+            "variant": variant,
+            "force_download": force_download,
+            "proxies": proxies,
+            "local_files_only": local_files_only,
+            "token": token,
+            "revision": revision,
+            "user_agent": user_agent,
+            "commit_hash": commit_hash,
+            "dduf_entries": dduf_entries,
+        }
+        index_file = _fetch_index_file(**index_file_kwargs)
+        # In case the index file was not found we still have to consider the legacy format.
+        # this becomes applicable when the variant is not None.
+        if variant is not None and (index_file is None or not os.path.exists(index_file)):
+            index_file = _fetch_index_file_legacy(**index_file_kwargs)
+        if index_file is not None and (dduf_entries or index_file.is_file()):
+            is_sharded = True
+
+        if is_sharded and from_flax:
+            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
+
+        # load model
+        if from_flax:
+            resolved_archive_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=FLAX_WEIGHTS_NAME,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+                commit_hash=commit_hash,
+            )
+            model = cls.from_config(config, **unused_kwargs)
+
+            # Convert the weights
+            from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
+
+            model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
+        else:
+            # in the case it is sharded, we have already the index
+            if is_sharded:
+                resolved_archive_file, sharded_metadata = _get_checkpoint_shard_files(
                     pretrained_model_name_or_path,
-                    weights_name=FLAX_WEIGHTS_NAME,
+                    index_file,
                     cache_dir=cache_dir,
-                    force_download=force_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
-                    revision=revision,
-                    subfolder=subfolder,
                     user_agent=user_agent,
-                    commit_hash=commit_hash,
+                    revision=revision,
+                    subfolder=subfolder or "",
+                    dduf_entries=dduf_entries,
                 )
-                model = cls.from_config(config, **unused_kwargs)
-
-                # Convert the weights
-                from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
-
-                model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
-            else:
-                # in the case it is sharded, we have already the index
-                if is_sharded:
-                    resolved_archive_file, sharded_metadata = _get_checkpoint_shard_files(
-                        pretrained_model_name_or_path,
-                        index_file,
-                        cache_dir=cache_dir,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        token=token,
-                        user_agent=user_agent,
-                        revision=revision,
-                        subfolder=subfolder or "",
-                        dduf_entries=dduf_entries,
-                    )
-                elif use_safetensors:
-                    try:
-                        resolved_archive_file = _get_model_file(
-                            pretrained_model_name_or_path,
-                            weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
-                            cache_dir=cache_dir,
-                            force_download=force_download,
-                            proxies=proxies,
-                            local_files_only=local_files_only,
-                            token=token,
-                            revision=revision,
-                            subfolder=subfolder,
-                            user_agent=user_agent,
-                            commit_hash=commit_hash,
-                            dduf_entries=dduf_entries,
-                        )
-
-                    except IOError as e:
-                        logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
-                        if not allow_pickle:
-                            raise
-                        logger.warning(
-                            "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
-                        )
-
-                if resolved_archive_file is None and not is_sharded:
+            elif use_safetensors:
+                try:
                     resolved_archive_file = _get_model_file(
                         pretrained_model_name_or_path,
-                        weights_name=_add_variant(WEIGHTS_NAME, variant),
+                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                         cache_dir=cache_dir,
                         force_download=force_download,
                         proxies=proxies,
@@ -1060,6 +1030,30 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         dduf_entries=dduf_entries,
                     )

+                except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+                    if not allow_pickle:
+                        raise
+                    logger.warning(
+                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                    )
+
+            if resolved_archive_file is None and not is_sharded:
+                resolved_archive_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=_add_variant(WEIGHTS_NAME, variant),
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                    dduf_entries=dduf_entries,
+                )
+
         if not isinstance(resolved_archive_file, list):
             resolved_archive_file = [resolved_archive_file]

@@ -1084,7 +1078,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)

-        if not is_sharded and state_dict is None:
+        state_dict = None
+        if not is_sharded:
             # Time to load the checkpoint
             state_dict = load_state_dict(
                 resolved_archive_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
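Taken together, these hunks drop the short-lived `config=` / `state_dict=` keyword arguments and make `from_pretrained` always resolve both the config and the weights from `pretrained_model_name_or_path`. As a minimal sketch of the call pattern that remains supported after this change (the model class, repo id, and subfolder below are illustrative placeholders, not taken from this PR):

```python
# Hedged usage sketch, not part of the diff: after this change the config and the
# state dict are always fetched from the checkpoint location, so no `config=` or
# `state_dict=` kwargs are passed in.
import torch
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel.from_pretrained(
    "some-org/some-diffusion-checkpoint",  # hypothetical repo id or local directory
    subfolder="unet",                      # illustrative subfolder layout
    torch_dtype=torch.float16,
    use_safetensors=True,
)
```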