@@ -21,7 +21,6 @@
     USE_PEFT_BACKEND,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
-    convert_unet_state_dict_to_peft,
     deprecate,
     get_adapter_name,
     get_peft_kwargs,
@@ -1845,92 +1844,15 @@ def load_lora_into_transformer( |
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )
 
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
-        keys = list(state_dict.keys())
-
-        transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
-        state_dict = {
-            k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
-        }
-
-        if len(state_dict.keys()) > 0:
-            # check with first key if is not in peft format
-            first_key = next(iter(state_dict.keys()))
-            if "lora_A" not in first_key:
-                state_dict = convert_unet_state_dict_to_peft(state_dict)
-
-            if adapter_name in getattr(transformer, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            if network_alphas is not None and len(network_alphas) >= 1:
-                prefix = cls.transformer_name
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-                else:
-                    lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(transformer)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            peft_kwargs = {}
-            if is_peft_version(">=", "0.13.1"):
-                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
-            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
-            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
-
-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
-
-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
-
-            if warn_msg:
-                logger.warning(warn_msg)
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
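For context on the replacement: the nine added lines route everything the removed block did inline (prefix filtering, PEFT-format conversion, rank and alpha extraction, adapter injection via peft, and the CPU-offload juggling) through the model-level load_lora_adapter helper, which is also why the convert_unet_state_dict_to_peft import is dropped in the first hunk. The snippet below is a minimal usage sketch, not part of this commit: the model id, file name, and adapter name are hypothetical placeholders, and it assumes a transformer-backed pipeline on a diffusers build that includes this change.

import torch
from safetensors.torch import load_file
from diffusers import DiffusionPipeline

# Hypothetical model id and local LoRA file, for illustration only.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
state_dict = load_file("pytorch_lora_weights.safetensors")

# The pipeline-level API is unchanged; internally, load_lora_into_transformer
# now reduces to the single delegated call shown in this diff, and the same
# entry point can be exercised directly on the model:
pipe.transformer.load_lora_adapter(
    state_dict,
    adapter_name="my_adapter",
    low_cpu_mem_usage=True,  # per the guard above, needs peft >= 0.13.1
)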
|