| 28 | 28 | from ..utils import ( | 
| 29 | 29 |     USE_PEFT_BACKEND, | 
| 30 | 30 |     _get_model_file, | 
|  | 31 | +    convert_state_dict_to_diffusers, | 
|  | 32 | +    convert_state_dict_to_peft, | 
| 31 | 33 |     delete_adapter_layers, | 
| 32 | 34 |     deprecate, | 
|  | 35 | +    get_adapter_name, | 
|  | 36 | +    get_peft_kwargs, | 
| 33 | 37 |     is_accelerate_available, | 
| 34 | 38 |     is_peft_available, | 
|  | 39 | +    is_peft_version, | 
| 35 | 40 |     is_transformers_available, | 
|  | 41 | +    is_transformers_version, | 
| 36 | 42 |     logging, | 
| 37 | 43 |     recurse_remove_peft_layers, | 
|  | 44 | +    scale_lora_layers, | 
| 38 | 45 |     set_adapter_layers, | 
| 39 | 46 |     set_weights_and_activate_adapters, | 
| 40 | 47 | ) | 
|  | 
| 43 | 50 | if is_transformers_available(): | 
| 44 | 51 |     from transformers import PreTrainedModel | 
| 45 | 52 |  | 
|  | 53 | +    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules | 
|  | 54 | + | 
| 46 | 55 | if is_peft_available(): | 
| 47 | 56 |     from peft.tuners.tuners_utils import BaseTunerLayer | 
| 48 | 57 |  | 
| @@ -297,6 +306,152 @@ def _best_guess_weight_name( | 
| 297 | 306 |     return weight_name | 
| 298 | 307 |  | 
| 299 | 308 |  | 
|  | 309 | +def _load_lora_into_text_encoder( | 
|  | 310 | +    state_dict, | 
|  | 311 | +    network_alphas, | 
|  | 312 | +    text_encoder, | 
|  | 313 | +    prefix=None, | 
|  | 314 | +    lora_scale=1.0, | 
|  | 315 | +    text_encoder_name="text_encoder", | 
|  | 316 | +    adapter_name=None, | 
|  | 317 | +    _pipeline=None, | 
|  | 318 | +    low_cpu_mem_usage=False, | 
|  | 319 | +): | 
|  | 320 | +    if not USE_PEFT_BACKEND: | 
|  | 321 | +        raise ValueError("PEFT backend is required for this method.") | 
|  | 322 | + | 
|  | 323 | +    peft_kwargs = {} | 
|  | 324 | +    if low_cpu_mem_usage: | 
|  | 325 | +        if not is_peft_version(">=", "0.13.1"): | 
|  | 326 | +            raise ValueError( | 
|  | 327 | +                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." | 
|  | 328 | +            ) | 
|  | 329 | +        if not is_transformers_version(">", "4.45.2"): | 
|  | 330 | +            # Note from sayakpaul: It's not in `transformers` stable yet. | 
|  | 331 | +            # https://github.com/huggingface/transformers/pull/33725/ | 
|  | 332 | +            raise ValueError( | 
|  | 333 | +                "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." | 
|  | 334 | +            ) | 
|  | 335 | +        peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage | 
|  | 336 | + | 
|  | 337 | +    from peft import LoraConfig | 
|  | 338 | + | 
|  | 339 | +    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), | 
|  | 340 | +    # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as | 
|  | 341 | +    # their prefixes. | 
|  | 342 | +    keys = list(state_dict.keys()) | 
|  | 343 | +    prefix = text_encoder_name if prefix is None else prefix | 
|  | 344 | + | 
|  | 345 | +    # Safe prefix to check with. | 
|  | 346 | +    if any(text_encoder_name in key for key in keys): | 
|  | 347 | +        # Load the layers corresponding to text encoder and make necessary adjustments. | 
|  | 348 | +        text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] | 
|  | 349 | +        text_encoder_lora_state_dict = { | 
|  | 350 | +            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys | 
|  | 351 | +        } | 
|  | 352 | + | 
|  | 353 | +        if len(text_encoder_lora_state_dict) > 0: | 
|  | 354 | +            logger.info(f"Loading {prefix}.") | 
|  | 355 | +            rank = {} | 
|  | 356 | +            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) | 
|  | 357 | + | 
|  | 358 | +            # convert state dict | 
|  | 359 | +            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) | 
|  | 360 | + | 
|  | 361 | +            for name, _ in text_encoder_attn_modules(text_encoder): | 
|  | 362 | +                for module in ("out_proj", "q_proj", "k_proj", "v_proj"): | 
|  | 363 | +                    rank_key = f"{name}.{module}.lora_B.weight" | 
|  | 364 | +                    if rank_key not in text_encoder_lora_state_dict: | 
|  | 365 | +                        continue | 
|  | 366 | +                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] | 
|  | 367 | + | 
|  | 368 | +            for name, _ in text_encoder_mlp_modules(text_encoder): | 
|  | 369 | +                for module in ("fc1", "fc2"): | 
|  | 370 | +                    rank_key = f"{name}.{module}.lora_B.weight" | 
|  | 371 | +                    if rank_key not in text_encoder_lora_state_dict: | 
|  | 372 | +                        continue | 
|  | 373 | +                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] | 
|  | 374 | + | 
|  | 375 | +            if network_alphas is not None: | 
|  | 376 | +                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] | 
|  | 377 | +                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} | 
|  | 378 | + | 
|  | 379 | +            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) | 
|  | 380 | + | 
|  | 381 | +            if "use_dora" in lora_config_kwargs: | 
|  | 382 | +                if lora_config_kwargs["use_dora"]: | 
|  | 383 | +                    if is_peft_version("<", "0.9.0"): | 
|  | 384 | +                        raise ValueError( | 
|  | 385 | +                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." | 
|  | 386 | +                        ) | 
|  | 387 | +                else: | 
|  | 388 | +                    if is_peft_version("<", "0.9.0"): | 
|  | 389 | +                        lora_config_kwargs.pop("use_dora") | 
|  | 390 | + | 
|  | 391 | +            if "lora_bias" in lora_config_kwargs: | 
|  | 392 | +                if lora_config_kwargs["lora_bias"]: | 
|  | 393 | +                    if is_peft_version("<=", "0.13.2"): | 
|  | 394 | +                        raise ValueError( | 
|  | 395 | +                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." | 
|  | 396 | +                        ) | 
|  | 397 | +                else: | 
|  | 398 | +                    if is_peft_version("<=", "0.13.2"): | 
|  | 399 | +                        lora_config_kwargs.pop("lora_bias") | 
|  | 400 | + | 
|  | 401 | +            lora_config = LoraConfig(**lora_config_kwargs) | 
|  | 402 | + | 
|  | 403 | +            # adapter_name | 
|  | 404 | +            if adapter_name is None: | 
|  | 405 | +                adapter_name = get_adapter_name(text_encoder) | 
|  | 406 | + | 
|  | 407 | +            is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline) | 
|  | 408 | + | 
|  | 409 | +            # inject LoRA layers and load the state dict | 
|  | 410 | +            # in transformers we automatically check whether the adapter name is already in use or not | 
|  | 411 | +            text_encoder.load_adapter( | 
|  | 412 | +                adapter_name=adapter_name, | 
|  | 413 | +                adapter_state_dict=text_encoder_lora_state_dict, | 
|  | 414 | +                peft_config=lora_config, | 
|  | 415 | +                **peft_kwargs, | 
|  | 416 | +            ) | 
|  | 417 | + | 
|  | 418 | +            # scale LoRA layers with `lora_scale` | 
|  | 419 | +            scale_lora_layers(text_encoder, weight=lora_scale) | 
|  | 420 | + | 
|  | 421 | +            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) | 
|  | 422 | + | 
|  | 423 | +            # Offload back. | 
|  | 424 | +            if is_model_cpu_offload: | 
|  | 425 | +                _pipeline.enable_model_cpu_offload() | 
|  | 426 | +            elif is_sequential_cpu_offload: | 
|  | 427 | +                _pipeline.enable_sequential_cpu_offload() | 
|  | 428 | +            # Unsafe code /> | 
|  | 429 | + | 
|  | 430 | + | 
|  | 431 | +def _func_optionally_disable_offloading(_pipeline): | 
|  | 432 | +    is_model_cpu_offload = False | 
|  | 433 | +    is_sequential_cpu_offload = False | 
|  | 434 | + | 
|  | 435 | +    if _pipeline is not None and _pipeline.hf_device_map is None: | 
|  | 436 | +        for _, component in _pipeline.components.items(): | 
|  | 437 | +            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): | 
|  | 438 | +                if not is_model_cpu_offload: | 
|  | 439 | +                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) | 
|  | 440 | +                if not is_sequential_cpu_offload: | 
|  | 441 | +                    is_sequential_cpu_offload = ( | 
|  | 442 | +                        isinstance(component._hf_hook, AlignDevicesHook) | 
|  | 443 | +                        or hasattr(component._hf_hook, "hooks") | 
|  | 444 | +                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) | 
|  | 445 | +                    ) | 
|  | 446 | + | 
|  | 447 | +                logger.info( | 
|  | 448 | +                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." | 
|  | 449 | +                ) | 
|  | 450 | +                remove_hook_from_module(component, recurse=is_sequential_cpu_offload) | 
|  | 451 | + | 
|  | 452 | +    return (is_model_cpu_offload, is_sequential_cpu_offload) | 
|  | 453 | + | 
|  | 454 | + | 
| 300 | 455 | class LoraBaseMixin: | 
| 301 | 456 |     """Utility class for handling LoRAs.""" | 
| 302 | 457 |  | 
| @@ -327,27 +482,7 @@ def _optionally_disable_offloading(cls, _pipeline): | 
| 327 | 482 |             tuple: | 
| 328 | 483 |                 A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. | 
| 329 | 484 |         """ | 
| 330 |  | -        is_model_cpu_offload = False | 
| 331 |  | -        is_sequential_cpu_offload = False | 
| 332 |  | - | 
| 333 |  | -        if _pipeline is not None and _pipeline.hf_device_map is None: | 
| 334 |  | -            for _, component in _pipeline.components.items(): | 
| 335 |  | -                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): | 
| 336 |  | -                    if not is_model_cpu_offload: | 
| 337 |  | -                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) | 
| 338 |  | -                    if not is_sequential_cpu_offload: | 
| 339 |  | -                        is_sequential_cpu_offload = ( | 
| 340 |  | -                            isinstance(component._hf_hook, AlignDevicesHook) | 
| 341 |  | -                            or hasattr(component._hf_hook, "hooks") | 
| 342 |  | -                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) | 
| 343 |  | -                        ) | 
| 344 |  | - | 
| 345 |  | -                    logger.info( | 
| 346 |  | -                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." | 
| 347 |  | -                    ) | 
| 348 |  | -                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload) | 
| 349 |  | - | 
| 350 |  | -        return (is_model_cpu_offload, is_sequential_cpu_offload) | 
|  | 485 | +        return _func_optionally_disable_offloading(_pipeline=_pipeline) | 
| 351 | 486 |  | 
| 352 | 487 |     @classmethod | 
| 353 | 488 |     def _fetch_state_dict(cls, *args, **kwargs): | 
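
The two module-level helpers added above let the per-pipeline LoRA mixins share a single text-encoder loading path instead of duplicating it. The sketch below shows roughly how a caller might drive `_load_lora_into_text_encoder` directly; it assumes the helpers live in `diffusers.loaders.lora_base`, the model id and LoRA path are placeholders, and the `lora_state_dict()` call is assumed from the existing Stable Diffusion LoRA loader API rather than taken from this diff.

```python
# Rough usage sketch only -- model id and LoRA path are placeholders, and
# _load_lora_into_text_encoder is private API, shown here purely for illustration.
import torch

from diffusers import StableDiffusionPipeline
from diffusers.loaders.lora_base import _load_lora_into_text_encoder

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# For SD-style checkpoints, lora_state_dict() returns the raw LoRA weights
# together with the parsed network alphas.
state_dict, network_alphas = pipe.lora_state_dict("path/to/lora.safetensors")

_load_lora_into_text_encoder(
    state_dict=state_dict,
    network_alphas=network_alphas,
    text_encoder=pipe.text_encoder,
    prefix="text_encoder",    # only keys under this prefix are loaded
    lora_scale=1.0,           # applied via scale_lora_layers() after injection
    adapter_name=None,        # falls back to get_adapter_name(text_encoder)
    _pipeline=pipe,           # lets the helper pause CPU-offload hooks and restore them
    low_cpu_mem_usage=False,  # True requires peft>=0.13.1 and transformers>4.45.2
)
```

`_func_optionally_disable_offloading` is the same hook-detection logic that previously lived in `LoraBaseMixin._optionally_disable_offloading`; as the second hunk shows, the classmethod now simply delegates to it, so its public behavior is unchanged.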