| 
51 | 51 | 
 
  | 
52 | 52 | logger = logging.get_logger(__name__)  | 
53 | 53 | 
 
  | 
 | 54 | +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"  | 
 | 55 | +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"  | 
 | 56 | + | 
54 | 57 | 
 
  | 
55 | 58 | def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):  | 
56 | 59 |     """  | 
@@ -181,6 +184,119 @@ def _remove_text_encoder_monkey_patch(text_encoder):  | 
181 | 184 |         text_encoder._hf_peft_config_loaded = None  | 
182 | 185 | 
 
  | 
183 | 186 | 
 
  | 
 | 187 | +def _fetch_state_dict(  | 
 | 188 | +    pretrained_model_name_or_path_or_dict,  | 
 | 189 | +    weight_name,  | 
 | 190 | +    use_safetensors,  | 
 | 191 | +    local_files_only,  | 
 | 192 | +    cache_dir,  | 
 | 193 | +    force_download,  | 
 | 194 | +    proxies,  | 
 | 195 | +    token,  | 
 | 196 | +    revision,  | 
 | 197 | +    subfolder,  | 
 | 198 | +    user_agent,  | 
 | 199 | +    allow_pickle,  | 
 | 200 | +):  | 
 | 201 | +    model_file = None  | 
 | 202 | +    if not isinstance(pretrained_model_name_or_path_or_dict, dict):  | 
 | 203 | +        # Let's first try to load .safetensors weights  | 
 | 204 | +        if (use_safetensors and weight_name is None) or (  | 
 | 205 | +            weight_name is not None and weight_name.endswith(".safetensors")  | 
 | 206 | +        ):  | 
 | 207 | +            try:  | 
 | 208 | +                # Here we're relaxing the loading check to enable more Inference API  | 
 | 209 | +                # friendliness where sometimes, it's not at all possible to automatically  | 
 | 210 | +                # determine `weight_name`.  | 
 | 211 | +                if weight_name is None:  | 
 | 212 | +                    weight_name = _best_guess_weight_name(  | 
 | 213 | +                        pretrained_model_name_or_path_or_dict,  | 
 | 214 | +                        file_extension=".safetensors",  | 
 | 215 | +                        local_files_only=local_files_only,  | 
 | 216 | +                    )  | 
 | 217 | +                model_file = _get_model_file(  | 
 | 218 | +                    pretrained_model_name_or_path_or_dict,  | 
 | 219 | +                    weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,  | 
 | 220 | +                    cache_dir=cache_dir,  | 
 | 221 | +                    force_download=force_download,  | 
 | 222 | +                    proxies=proxies,  | 
 | 223 | +                    local_files_only=local_files_only,  | 
 | 224 | +                    token=token,  | 
 | 225 | +                    revision=revision,  | 
 | 226 | +                    subfolder=subfolder,  | 
 | 227 | +                    user_agent=user_agent,  | 
 | 228 | +                )  | 
 | 229 | +                state_dict = safetensors.torch.load_file(model_file, device="cpu")  | 
 | 230 | +            except (IOError, safetensors.SafetensorError) as e:  | 
 | 231 | +                if not allow_pickle:  | 
 | 232 | +                    raise e  | 
 | 233 | +                # try loading non-safetensors weights  | 
 | 234 | +                model_file = None  | 
 | 235 | +                pass  | 
 | 236 | + | 
 | 237 | +        if model_file is None:  | 
 | 238 | +            if weight_name is None:  | 
 | 239 | +                weight_name = _best_guess_weight_name(  | 
 | 240 | +                    pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only  | 
 | 241 | +                )  | 
 | 242 | +            model_file = _get_model_file(  | 
 | 243 | +                pretrained_model_name_or_path_or_dict,  | 
 | 244 | +                weights_name=weight_name or LORA_WEIGHT_NAME,  | 
 | 245 | +                cache_dir=cache_dir,  | 
 | 246 | +                force_download=force_download,  | 
 | 247 | +                proxies=proxies,  | 
 | 248 | +                local_files_only=local_files_only,  | 
 | 249 | +                token=token,  | 
 | 250 | +                revision=revision,  | 
 | 251 | +                subfolder=subfolder,  | 
 | 252 | +                user_agent=user_agent,  | 
 | 253 | +            )  | 
 | 254 | +            state_dict = load_state_dict(model_file)  | 
 | 255 | +    else:  | 
 | 256 | +        state_dict = pretrained_model_name_or_path_or_dict  | 
 | 257 | + | 
 | 258 | +    return state_dict  | 
 | 259 | + | 
 | 260 | + | 
 | 261 | +def _best_guess_weight_name(  | 
 | 262 | +    pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False  | 
 | 263 | +):  | 
 | 264 | +    if local_files_only or HF_HUB_OFFLINE:  | 
 | 265 | +        raise ValueError("When using the offline mode, you must specify a `weight_name`.")  | 
 | 266 | + | 
 | 267 | +    targeted_files = []  | 
 | 268 | + | 
 | 269 | +    if os.path.isfile(pretrained_model_name_or_path_or_dict):  | 
 | 270 | +        return  | 
 | 271 | +    elif os.path.isdir(pretrained_model_name_or_path_or_dict):  | 
 | 272 | +        targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]  | 
 | 273 | +    else:  | 
 | 274 | +        files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings  | 
 | 275 | +        targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]  | 
 | 276 | +    if len(targeted_files) == 0:  | 
 | 277 | +        return  | 
 | 278 | + | 
 | 279 | +    # "scheduler" does not correspond to a LoRA checkpoint.  | 
 | 280 | +    # "optimizer" does not correspond to a LoRA checkpoint  | 
 | 281 | +    # only top-level checkpoints are considered and not the other ones, hence "checkpoint".  | 
 | 282 | +    unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}  | 
 | 283 | +    targeted_files = list(  | 
 | 284 | +        filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)  | 
 | 285 | +    )  | 
 | 286 | + | 
 | 287 | +    if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):  | 
 | 288 | +        targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))  | 
 | 289 | +    elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):  | 
 | 290 | +        targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))  | 
 | 291 | + | 
 | 292 | +    if len(targeted_files) > 1:  | 
 | 293 | +        raise ValueError(  | 
 | 294 | +            f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one  `.safetensors` or `.bin` file in  {pretrained_model_name_or_path_or_dict}."  | 
 | 295 | +        )  | 
 | 296 | +    weight_name = targeted_files[0]  | 
 | 297 | +    return weight_name  | 
 | 298 | + | 
 | 299 | + | 
184 | 300 | class LoraBaseMixin:  | 
185 | 301 |     """Utility class for handling LoRAs."""  | 
186 | 302 | 
 
  | 
@@ -234,124 +350,16 @@ def _optionally_disable_offloading(cls, _pipeline):  | 
234 | 350 |         return (is_model_cpu_offload, is_sequential_cpu_offload)  | 
235 | 351 | 
 
  | 
236 | 352 |     @classmethod  | 
237 |  | -    def _fetch_state_dict(  | 
238 |  | -        cls,  | 
239 |  | -        pretrained_model_name_or_path_or_dict,  | 
240 |  | -        weight_name,  | 
241 |  | -        use_safetensors,  | 
242 |  | -        local_files_only,  | 
243 |  | -        cache_dir,  | 
244 |  | -        force_download,  | 
245 |  | -        proxies,  | 
246 |  | -        token,  | 
247 |  | -        revision,  | 
248 |  | -        subfolder,  | 
249 |  | -        user_agent,  | 
250 |  | -        allow_pickle,  | 
251 |  | -    ):  | 
252 |  | -        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE  | 
253 |  | - | 
254 |  | -        model_file = None  | 
255 |  | -        if not isinstance(pretrained_model_name_or_path_or_dict, dict):  | 
256 |  | -            # Let's first try to load .safetensors weights  | 
257 |  | -            if (use_safetensors and weight_name is None) or (  | 
258 |  | -                weight_name is not None and weight_name.endswith(".safetensors")  | 
259 |  | -            ):  | 
260 |  | -                try:  | 
261 |  | -                    # Here we're relaxing the loading check to enable more Inference API  | 
262 |  | -                    # friendliness where sometimes, it's not at all possible to automatically  | 
263 |  | -                    # determine `weight_name`.  | 
264 |  | -                    if weight_name is None:  | 
265 |  | -                        weight_name = cls._best_guess_weight_name(  | 
266 |  | -                            pretrained_model_name_or_path_or_dict,  | 
267 |  | -                            file_extension=".safetensors",  | 
268 |  | -                            local_files_only=local_files_only,  | 
269 |  | -                        )  | 
270 |  | -                    model_file = _get_model_file(  | 
271 |  | -                        pretrained_model_name_or_path_or_dict,  | 
272 |  | -                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,  | 
273 |  | -                        cache_dir=cache_dir,  | 
274 |  | -                        force_download=force_download,  | 
275 |  | -                        proxies=proxies,  | 
276 |  | -                        local_files_only=local_files_only,  | 
277 |  | -                        token=token,  | 
278 |  | -                        revision=revision,  | 
279 |  | -                        subfolder=subfolder,  | 
280 |  | -                        user_agent=user_agent,  | 
281 |  | -                    )  | 
282 |  | -                    state_dict = safetensors.torch.load_file(model_file, device="cpu")  | 
283 |  | -                except (IOError, safetensors.SafetensorError) as e:  | 
284 |  | -                    if not allow_pickle:  | 
285 |  | -                        raise e  | 
286 |  | -                    # try loading non-safetensors weights  | 
287 |  | -                    model_file = None  | 
288 |  | -                    pass  | 
289 |  | - | 
290 |  | -            if model_file is None:  | 
291 |  | -                if weight_name is None:  | 
292 |  | -                    weight_name = cls._best_guess_weight_name(  | 
293 |  | -                        pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only  | 
294 |  | -                    )  | 
295 |  | -                model_file = _get_model_file(  | 
296 |  | -                    pretrained_model_name_or_path_or_dict,  | 
297 |  | -                    weights_name=weight_name or LORA_WEIGHT_NAME,  | 
298 |  | -                    cache_dir=cache_dir,  | 
299 |  | -                    force_download=force_download,  | 
300 |  | -                    proxies=proxies,  | 
301 |  | -                    local_files_only=local_files_only,  | 
302 |  | -                    token=token,  | 
303 |  | -                    revision=revision,  | 
304 |  | -                    subfolder=subfolder,  | 
305 |  | -                    user_agent=user_agent,  | 
306 |  | -                )  | 
307 |  | -                state_dict = load_state_dict(model_file)  | 
308 |  | -        else:  | 
309 |  | -            state_dict = pretrained_model_name_or_path_or_dict  | 
310 |  | - | 
311 |  | -        return state_dict  | 
 | 353 | +    def _fetch_state_dict(cls, *args, **kwargs):  | 
 | 354 | +        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."  | 
 | 355 | +        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)  | 
 | 356 | +        return _fetch_state_dict(*args, **kwargs)  | 
312 | 357 | 
 
  | 
313 | 358 |     @classmethod  | 
314 |  | -    def _best_guess_weight_name(  | 
315 |  | -        cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False  | 
316 |  | -    ):  | 
317 |  | -        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE  | 
318 |  | - | 
319 |  | -        if local_files_only or HF_HUB_OFFLINE:  | 
320 |  | -            raise ValueError("When using the offline mode, you must specify a `weight_name`.")  | 
321 |  | - | 
322 |  | -        targeted_files = []  | 
323 |  | - | 
324 |  | -        if os.path.isfile(pretrained_model_name_or_path_or_dict):  | 
325 |  | -            return  | 
326 |  | -        elif os.path.isdir(pretrained_model_name_or_path_or_dict):  | 
327 |  | -            targeted_files = [  | 
328 |  | -                f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)  | 
329 |  | -            ]  | 
330 |  | -        else:  | 
331 |  | -            files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings  | 
332 |  | -            targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]  | 
333 |  | -        if len(targeted_files) == 0:  | 
334 |  | -            return  | 
335 |  | - | 
336 |  | -        # "scheduler" does not correspond to a LoRA checkpoint.  | 
337 |  | -        # "optimizer" does not correspond to a LoRA checkpoint  | 
338 |  | -        # only top-level checkpoints are considered and not the other ones, hence "checkpoint".  | 
339 |  | -        unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}  | 
340 |  | -        targeted_files = list(  | 
341 |  | -            filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)  | 
342 |  | -        )  | 
343 |  | - | 
344 |  | -        if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):  | 
345 |  | -            targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))  | 
346 |  | -        elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):  | 
347 |  | -            targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))  | 
348 |  | - | 
349 |  | -        if len(targeted_files) > 1:  | 
350 |  | -            raise ValueError(  | 
351 |  | -                f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one  `.safetensors` or `.bin` file in  {pretrained_model_name_or_path_or_dict}."  | 
352 |  | -            )  | 
353 |  | -        weight_name = targeted_files[0]  | 
354 |  | -        return weight_name  | 
 | 359 | +    def _best_guess_weight_name(cls, *args, **kwargs):  | 
 | 360 | +        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."  | 
 | 361 | +        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)  | 
 | 362 | +        return _best_guess_weight_name(*args, **kwargs)  | 
355 | 363 | 
 
  | 
356 | 364 |     def unload_lora_weights(self):  | 
357 | 365 |         """  | 
@@ -725,8 +733,6 @@ def write_lora_layers(  | 
725 | 733 |         save_function: Callable,  | 
726 | 734 |         safe_serialization: bool,  | 
727 | 735 |     ):  | 
728 |  | -        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE  | 
729 |  | - | 
730 | 736 |         if os.path.isfile(save_directory):  | 
731 | 737 |             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")  | 
732 | 738 |             return  | 
 | 
0 commit comments