Commit 5b4548a

feat: save_lora_adapter.
1 parent 13e8fde commit 5b4548a

2 files changed, +133 -1 lines changed

src/diffusers/loaders/peft.py

Lines changed: 51 additions & 1 deletion
@@ -13,9 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import os
 from functools import partial
+from pathlib import Path
 from typing import Dict, List, Optional, Union
 
+import safetensors
+import torch
 import torch.nn as nn
 
 from ..utils import (
@@ -203,7 +207,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
 
         if adapter_name in getattr(self, "peft_config", {}):
             raise ValueError(
-                f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+                f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
             )
 
         rank = {}
@@ -276,6 +280,52 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 _pipeline.enable_sequential_cpu_offload()
             # Unsafe code />
 
+    def save_lora_adapter(
+        self,
+        save_directory,
+        adapter_name: str = "default",
+        upcast_before_saving: bool = False,
+        safe_serialization: bool = True,
+        weight_name: str = None,
+    ):
+        """TODO"""
+        from peft.utils import get_peft_model_state_dict
+
+        from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+        if adapter_name is None:
+            adapter_name = get_adapter_name(self)
+
+        if adapter_name not in getattr(self, "peft_config", {}):
+            raise ValueError(f"Adapter name {adapter_name} not found in the model.")
+
+        lora_layers_to_save = get_peft_model_state_dict(
+            self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
+        )
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        if safe_serialization:
+
+            def save_function(weights, filename):
+                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+        else:
+            save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if weight_name is None:
+            if safe_serialization:
+                weight_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weight_name = LORA_WEIGHT_NAME
+
+        save_path = Path(save_directory, weight_name).as_posix()
+        save_function(lora_layers_to_save, save_path)
+        logger.info(f"Model weights saved in {save_path}")
+
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
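
Usage note: a minimal round-trip sketch of the new method, mirroring the tests below. The checkpoint ID and target modules here are illustrative, not part of this commit; any model that inherits PeftAdapterMixin should work.

import tempfile

from peft import LoraConfig

from diffusers import UNet2DConditionModel

# Illustrative checkpoint; any diffusers model inheriting PeftAdapterMixin works.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.add_adapter(
    LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
)

with tempfile.TemporaryDirectory() as tmpdir:
    # With safe_serialization=True (the default), this writes the file named by
    # LORA_WEIGHT_NAME_SAFE into tmpdir.
    unet.save_lora_adapter(tmpdir)
    unet.unload_lora()
    # Reload the adapter that was just serialized.
    unet.load_lora_adapter(tmpdir, use_safetensors=True)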

tests/models/test_modeling_common.py

Lines changed: 82 additions & 0 deletions
@@ -44,6 +44,7 @@
 from diffusers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     WEIGHTS_INDEX_NAME,
+    is_peft_available,
     is_torch_npu_available,
     is_xformers_available,
     logging,
@@ -53,6 +54,7 @@
     CaptureLogger,
     get_python_version,
     is_torch_compile,
+    require_peft_backend,
     require_torch_2,
     require_torch_accelerator_with_training,
     require_torch_gpu,
@@ -65,6 +67,13 @@
 from ..others.test_utils import TOKEN, USER, is_staging_test
 
 
+if is_peft_available():
+    from peft import LoraConfig
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    from diffusers.loaders import PeftAdapterMixin
+
+
 def caculate_expected_num_shards(index_map_path):
     with open(index_map_path) as f:
         weight_map_dict = json.load(f)["weight_map"]
@@ -74,6 +83,16 @@ def caculate_expected_num_shards(index_map_path):
     return expected_num_shards
 
 
+def check_if_lora_correctly_set(model) -> bool:
+    """
+    Checks if the LoRA layers are correctly set with peft
+    """
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            return True
+    return False
+
+
 # Will be run via run_test_in_subprocess
 def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
     error = None
@@ -902,6 +921,69 @@ def test_deprecated_kwargs(self):
             " from `_deprecated_kwargs = [<deprecated_argument>]`"
         )
 
+    @require_peft_backend
+    @parameterized.expand([True, False])
+    def test_load_save_lora_adapter(self, use_dora=False):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        torch.manual_seed(0)
+        output_no_lora = model(**inputs_dict).sample
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        torch.manual_seed(0)
+        outputs_with_lora = model(**inputs_dict).sample
+
+        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model.unload_lora()
+            model.load_lora_adapter(tmpdir, use_safetensors=True)
+            self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+            torch.manual_seed(0)
+            outputs_with_lora_2 = model(**inputs_dict).sample
+
+            self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+            self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+
+    def test_wrong_adapter_name_raises_error(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            wrong_name = "foo"
+            with self.assertRaises(ValueError) as err_context:
+                model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
+
+            self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
+
     @require_torch_gpu
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
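
For reference, a quick way to inspect what save_lora_adapter wrote. This is a sketch: the path is illustrative, the file name assumes safe_serialization=True (i.e. whatever LORA_WEIGHT_NAME_SAFE resolves to in lora_base.py), and the printed key names depend on the model and adapter config.

import safetensors.torch

# Illustrative path; the file name assumes the safetensors default was used.
state_dict = safetensors.torch.load_file("path/to/adapter/pytorch_lora_weights.safetensors")
for key in sorted(state_dict):
    # Each entry is one LoRA (or DoRA) parameter tensor for a targeted module.
    print(key, tuple(state_dict[key].shape))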
