1919
2020from ..configuration_utils import ConfigMixin
2121from ..utils import logging
22+ from ..utils .dynamic_modules_utils import get_class_from_dynamic_module , resolve_trust_remote_code
2223
2324
2425logger = logging .get_logger (__name__ )
@@ -114,6 +115,8 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
114115 disable_mmap ('bool', *optional*, defaults to 'False'):
115116 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
116117 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
118+ trust_remote_cocde (`bool`, *optional*, defaults to `False`):
119+ Whether to trust remote code
117120
118121 <Tip>
119122
@@ -140,22 +143,22 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
140143 You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
141144 ```
142145 """
143- cache_dir = kwargs .pop ("cache_dir" , None )
144- force_download = kwargs .pop ("force_download" , False )
145- proxies = kwargs .pop ("proxies" , None )
146- token = kwargs .pop ("token" , None )
147- local_files_only = kwargs .pop ("local_files_only" , False )
148- revision = kwargs .pop ("revision" , None )
149146 subfolder = kwargs .pop ("subfolder" , None )
150-
151- load_config_kwargs = {
152- "cache_dir" : cache_dir ,
153- "force_download" : force_download ,
154- "proxies" : proxies ,
155- "token" : token ,
156- "local_files_only" : local_files_only ,
157- "revision" : revision ,
158- }
147+ trust_remote_code = kwargs .pop ("trust_remote_code" , False )
148+
149+ hub_kwargs_names = [
150+ "cache_dir" ,
151+ "force_download" ,
152+ "local_files_only" ,
153+ "proxies" ,
154+ "resume_download" ,
155+ "revision" ,
156+ "token" ,
157+ ]
158+ hub_kwargs = {name : kwargs .pop (name , None ) for name in hub_kwargs_names }
159+
160+ # load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
161+ load_config_kwargs = {k : v for k , v in hub_kwargs .items () if k not in ["subfolder" , "resume_download" ]}
159162
160163 library = None
161164 orig_class_name = None
@@ -189,15 +192,35 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
189192 else :
190193 raise ValueError (f"Couldn't find model associated with the config file at { pretrained_model_or_path } ." )
191194
192- from ..pipelines .pipeline_loading_utils import ALL_IMPORTABLE_CLASSES , get_class_obj_and_candidates
193-
194- model_cls , _ = get_class_obj_and_candidates (
195- library_name = library ,
196- class_name = orig_class_name ,
197- importable_classes = ALL_IMPORTABLE_CLASSES ,
198- pipelines = None ,
199- is_pipeline_module = False ,
200- )
195+ has_remote_code = "auto_map" in config and cls .__name__ in config ["auto_map" ]
196+ trust_remote_code = resolve_trust_remote_code (trust_remote_code , pretrained_model_or_path , has_remote_code )
197+ if not (has_remote_code and trust_remote_code ):
198+ raise ValueError (
199+ "Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
200+ )
201+
202+ if has_remote_code and trust_remote_code :
203+ class_ref = config ["auto_map" ][cls .__name__ ]
204+ module_file , class_name = class_ref .split ("." )
205+ module_file = module_file + ".py"
206+ model_cls = get_class_from_dynamic_module (
207+ pretrained_model_or_path ,
208+ subfolder = subfolder ,
209+ module_file = module_file ,
210+ class_name = class_name ,
211+ ** hub_kwargs ,
212+ ** kwargs ,
213+ )
214+ else :
215+ from ..pipelines .pipeline_loading_utils import ALL_IMPORTABLE_CLASSES , get_class_obj_and_candidates
216+
217+ model_cls , _ = get_class_obj_and_candidates (
218+ library_name = library ,
219+ class_name = orig_class_name ,
220+ importable_classes = ALL_IMPORTABLE_CLASSES ,
221+ pipelines = None ,
222+ is_pipeline_module = False ,
223+ )
201224
202225 if model_cls is None :
203226 raise ValueError (f"AutoModel can't find a model linked to { orig_class_name } ." )
0 commit comments