Commit 2d5b305

Merge branch 'main' into benchmarking-overhaul
2 parents 7d4f459 + a5f4cc7 commit 2d5b305

File tree

4 files changed: +54 -3 lines


src/diffusers/loaders/lora_pipeline.py

Lines changed: 1 addition & 2 deletions
@@ -2545,14 +2545,13 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
         if unexpected_modules:
             logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
 
-        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for k in lora_module_names:
             if k in unexpected_modules:
                 continue
 
             base_param_name = (
                 f"{k.replace(prefix, '')}.base_layer.weight"
-                if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
+                if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
                 else f"{k.replace(prefix, '')}.weight"
             )
             base_weight_param = transformer_state_dict[base_param_name]
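The change drops the `peft_config` gate: whether to read the PEFT-wrapped `base_layer.weight` or the plain `.weight` is now decided purely by which key is present in the transformer's state dict. A minimal standalone sketch of that key selection (the state dict contents and module names are illustrative, not diffusers API):

    # Sketch of the key-selection logic above; "proj_out"/"norm" are
    # placeholder module names, not taken from any real model.
    prefix = "transformer."

    def pick_base_param_name(k: str, state_dict: dict) -> str:
        # Prefer the PEFT-wrapped weight if a `.base_layer.weight` entry
        # exists; otherwise fall back to the plain `.weight` key. After
        # this commit the check no longer depends on `peft_config`.
        name = k.replace(prefix, "")
        wrapped = f"{name}.base_layer.weight"
        return wrapped if wrapped in state_dict else f"{name}.weight"

    state_dict = {"proj_out.base_layer.weight": 0, "norm.weight": 0}
    assert pick_base_param_name("transformer.proj_out", state_dict) == "proj_out.base_layer.weight"
    assert pick_base_param_name("transformer.norm", state_dict) == "norm.weight"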

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 0 deletions
@@ -1665,6 +1665,8 @@ def _get_signature_types(cls):
             signature_types[k] = (v.annotation,)
         elif get_origin(v.annotation) == Union:
             signature_types[k] = get_args(v.annotation)
+        elif get_origin(v.annotation) in [List, Dict, list, dict]:
+            signature_types[k] = (v.annotation,)
         else:
             logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
         return signature_types
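The new branch is needed because `get_origin()` on a parameterized container such as `List[int]` returns the bare `list` class, which is neither a class annotation (`inspect.isclass` is False for `List[int]`) nor a `Union`, so such parameters previously fell through to the warning. A self-contained sketch of the dispatch, using an illustrative signature rather than a real pipeline `__init__`:

    import inspect
    from typing import Dict, List, Optional, Union, get_args, get_origin

    # Placeholder signature standing in for a pipeline __init__.
    def example_init(images: List[int], config: Dict[str, int], size: Optional[int] = None):
        ...

    for k, v in inspect.signature(example_init).parameters.items():
        if inspect.isclass(v.annotation):
            print(k, (v.annotation,))                 # plain classes
        elif get_origin(v.annotation) == Union:
            print(k, get_args(v.annotation))          # Optional[X] is Union[X, None]
        elif get_origin(v.annotation) in [List, Dict, list, dict]:
            # get_origin(List[int]) is the plain `list` class, so without
            # this branch `images` and `config` would hit the warning path.
            print(k, (v.annotation,))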

src/diffusers/utils/torch_utils.py

Lines changed: 3 additions & 1 deletion
@@ -38,7 +38,7 @@ def maybe_allow_in_graph(cls):
 def randn_tensor(
     shape: Union[Tuple, List],
     generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
-    device: Optional["torch.device"] = None,
+    device: Optional[Union[str, "torch.device"]] = None,
     dtype: Optional["torch.dtype"] = None,
     layout: Optional["torch.layout"] = None,
 ):
@@ -47,6 +47,8 @@ def randn_tensor(
     is always created on the CPU.
     """
     # device on which tensor is created defaults to device
+    if isinstance(device, str):
+        device = torch.device(device)
     rand_device = device
     batch_size = shape[0]
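With the normalization in place, a plain string is accepted wherever a `torch.device` was required before, presumably so later checks against `device.type` see a real device object. A hedged usage sketch, assuming an installed diffusers build that includes this change:

    import torch
    from diffusers.utils.torch_utils import randn_tensor

    # `device` may now be a str; it is normalized via torch.device(...)
    latents = randn_tensor(
        (1, 4, 64, 64),
        generator=torch.manual_seed(0),
        device="cpu",
        dtype=torch.float32,
    )
    print(latents.device)  # device(type='cpu')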

tests/lora/utils.py

Lines changed: 48 additions & 0 deletions
@@ -2149,3 +2149,51 @@ def check_module(denoiser):
 
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
         pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+    def test_inference_load_delete_load_adapters(self):
+        "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
+        for scheduler_cls in self.scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config)
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                if "text_encoder_2" in lora_loadable_components:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
+
+            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+                self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
+                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+
+                # First, delete adapter and compare.
+                pipe.delete_adapters(pipe.get_active_adapters()[0])
+                output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
+                self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))
+
+                # Then load adapter and compare.
+                pipe.load_lora_weights(tmpdirname)
+                output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
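Outside the test harness, the regression being guarded is the plain load → delete → reload cycle on a pipeline. A sketch of that user-facing flow, with placeholder model and LoRA identifiers:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("some/base-model")   # placeholder model id

    pipe.load_lora_weights("some/lora-checkpoint")                # placeholder LoRA id
    with_lora = pipe(prompt="...", generator=torch.manual_seed(0)).images[0]

    pipe.delete_adapters(pipe.get_active_adapters()[0])           # back to the base model
    without_lora = pipe(prompt="...", generator=torch.manual_seed(0)).images[0]

    pipe.load_lora_weights("some/lora-checkpoint")                # reloading must still work
    reloaded = pipe(prompt="...", generator=torch.manual_seed(0)).images[0]
    # `reloaded` should match `with_lora`, and `without_lora` the base output.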
