
Commit 7e8ae22

Merge branch 'main' into xpu
2 parents d389758 + 3f3f0c1 commit 7e8ae22

12 files changed (+204, -27 lines)

docs/source/en/tutorials/using_peft_for_inference.md

Lines changed: 2 additions & 0 deletions
@@ -315,6 +315,8 @@ pipeline.load_lora_weights(
 > [!TIP]
 > Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
 
+If you expect varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
+
 There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
 
 ## Merge
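
For context on the tip added above, here is a minimal sketch of what dynamic-shape compilation looks like in practice. It is an illustration rather than the documented recipe: the Flux checkpoint id is real, the LoRA repository id is a placeholder, and if recompilations still show up, the linked dynamic-shape-compilation guide covers the additional compiler settings.

import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipeline.load_lora_weights("some-user/some-flux-lora")  # placeholder repo id

# Keep spatial dimensions symbolic so switching resolutions later does not
# force the transformer to be recompiled.
pipeline.transformer = torch.compile(pipeline.transformer, dynamic=True)

image_square = pipeline("a photo of a cat", height=1024, width=1024).images[0]
image_wide = pipeline("a photo of a cat", height=768, width=1360).images[0]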

src/diffusers/loaders/peft.py

Lines changed: 10 additions & 3 deletions
@@ -244,13 +244,20 @@ def load_lora_adapter(
                 k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
             }
 
-            # create LoraConfig
-            lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
-
             # adapter_name
             if adapter_name is None:
                 adapter_name = get_adapter_name(self)
 
+            # create LoraConfig
+            lora_config = _create_lora_config(
+                state_dict,
+                network_alphas,
+                metadata,
+                rank,
+                model_state_dict=self.state_dict(),
+                adapter_name=adapter_name,
+            )
+
             # <Unsafe code
             # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
             # Now we remove any existing hooks to `_pipeline`.

src/diffusers/loaders/unet_loader_utils.py

Lines changed: 5 additions & 2 deletions
@@ -14,6 +14,8 @@
 import copy
 from typing import TYPE_CHECKING, Dict, List, Union
 
+from torch import nn
+
 from ..utils import logging
 
 
@@ -52,7 +54,7 @@ def _maybe_expand_lora_scales(
             weight_for_adapter,
             blocks_with_transformer,
             transformer_per_block,
-            unet.state_dict(),
+            model=unet,
             default_scale=default_scale,
         )
         for weight_for_adapter in weight_scales
@@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
     scales: Union[float, Dict],
     blocks_with_transformer: Dict[str, int],
     transformer_per_block: Dict[str, int],
-    state_dict: None,
+    model: nn.Module,
     default_scale: float = 1.0,
 ):
     """
@@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
 
         del scales[updown]
 
+    state_dict = model.state_dict()
     for layer in scales.keys():
         if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
             raise ValueError(

src/diffusers/schedulers/scheduling_scm.py

Lines changed: 0 additions & 1 deletion
@@ -168,7 +168,6 @@ def set_timesteps(
         else:
             # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
             self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
-            print(f"Set timesteps: {self.timesteps}")
 
         self._step_index = None
         self._begin_index = None

src/diffusers/utils/peft_utils.py

Lines changed: 49 additions & 8 deletions
@@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
             module.set_scale(adapter_name, 1.0)
 
 
-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+def get_peft_kwargs(
+    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
+):
     rank_pattern = {}
     alpha_pattern = {}
     r = lora_alpha = list(rank_dict.values())[0]
@@ -180,7 +182,6 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
     else:
         lora_alpha = set(network_alpha_dict.values()).pop()
 
-    # layer names without the Diffusers specific
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
     use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
     # for now we know that the "bias" keys are only associated with `lora_B`.
@@ -195,6 +196,21 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
         "use_dora": use_dora,
         "lora_bias": lora_bias,
     }
+
+    # Example: trying to load a FusionX LoRA into Wan VACE
+    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
+    if exclude_modules:
+        if not is_peft_version(">=", "0.14.0"):
+            msg = """
+It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
+version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
+peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
+https://github.com/huggingface/diffusers/issues/new
+"""
+            logger.debug(msg)
+        else:
+            lora_config_kwargs.update({"exclude_modules": exclude_modules})
+
     return lora_config_kwargs
 
 
@@ -294,19 +310,20 @@ def check_peft_version(min_version: str) -> None:
 
 
 def _create_lora_config(
-    state_dict,
-    network_alphas,
-    metadata,
-    rank_pattern_dict,
-    is_unet: bool = True,
+    state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
 ):
     from peft import LoraConfig
 
     if metadata is not None:
         lora_config_kwargs = metadata
     else:
         lora_config_kwargs = get_peft_kwargs(
-            rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
+            rank_pattern_dict,
+            network_alpha_dict=network_alphas,
+            peft_state_dict=state_dict,
+            is_unet=is_unet,
+            model_state_dict=model_state_dict,
+            adapter_name=adapter_name,
         )
 
     _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
@@ -371,3 +388,27 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
 
     if warn_msg:
         logger.warning(warn_msg)
+
+
+def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
+    """
+    Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
+    `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
+    doesn't exist in `peft_state_dict`.
+    """
+    if model_state_dict is None:
+        return
+    all_modules = set()
+    string_to_replace = f"{adapter_name}." if adapter_name else ""
+
+    for name in model_state_dict.keys():
+        if string_to_replace:
+            name = name.replace(string_to_replace, "")
+        if "." in name:
+            module_name = name.rsplit(".", 1)[0]
+            all_modules.add(module_name)
+
+    target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
+    exclude_modules = list(all_modules - target_modules_set)
+
+    return exclude_modules
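
To make the new helper concrete, here is a small sketch with made-up state-dict keys (only the keys matter, so the values are left as None); it assumes a diffusers build that already contains `_derive_exclude_modules`:

from diffusers.utils.peft_utils import _derive_exclude_modules

# Two attention projections plus one extra module that has no corresponding
# LoRA weights (for example, a VACE-specific block).
model_state_dict = {
    "blocks.0.attn.to_q.weight": None,
    "blocks.0.attn.to_k.weight": None,
    "vace_blocks.0.proj_out.weight": None,
}
peft_state_dict = {
    "blocks.0.attn.to_q.lora_A.weight": None,
    "blocks.0.attn.to_k.lora_A.weight": None,
}

# Only the module missing from the LoRA state dict ends up excluded.
print(_derive_exclude_modules(model_state_dict, peft_state_dict))
# ['vace_blocks.0.proj_out']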

tests/lora/test_lora_layers_wan.py

Lines changed: 5 additions & 1 deletion
@@ -24,7 +24,11 @@
     WanPipeline,
     WanTransformer3DModel,
 )
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
+from diffusers.utils.testing_utils import (
+    floats_tensor,
+    require_peft_backend,
+    skip_mps,
+)
 
 
 sys.path.append(".")

tests/lora/utils.py

Lines changed: 67 additions & 0 deletions
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import inspect
 import os
 import re
@@ -291,6 +292,20 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
 
         return modules_to_save
 
+    def _get_exclude_modules(self, pipe):
+        from diffusers.utils.peft_utils import _derive_exclude_modules
+
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        denoiser = "unet" if self.unet_kwargs is not None else "transformer"
+        modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
+        denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
+        pipe.unload_lora_weights()
+        denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
+        exclude_modules = _derive_exclude_modules(
+            denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
+        )
+        return exclude_modules
+
     def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -2326,6 +2341,58 @@ def test_lora_unload_add_adapter(self):
         )
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
+    @require_peft_version_greater("0.13.2")
+    def test_lora_exclude_modules(self):
+        """
+        Test to check if `exclude_modules` works or not. It works in the following way:
+        we first create a pipeline and insert LoRA config into it. We then derive a `set`
+        of modules to exclude by investigating its denoiser state dict and denoiser LoRA
+        state dict.
+
+        We then create a new LoRA config to include the `exclude_modules` and perform tests.
+        """
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # only supported for `denoiser` now
+        pipe_cp = copy.deepcopy(pipe)
+        pipe_cp, _ = self.add_adapters_to_pipeline(
+            pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
+        pipe_cp.to("cpu")
+        del pipe_cp
+
+        denoiser_lora_config.exclude_modules = denoiser_exclude_modules
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(tmpdir)
+
+            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertTrue(
+            not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
+            "LoRA should change outputs.",
+        )
+        self.assertTrue(
+            np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
+            "Lora outputs should match.",
+        )
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:
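
For readers unfamiliar with the `exclude_modules` attribute the new test sets on `denoiser_lora_config`, the following toy sketch (module names invented, assuming `peft>=0.14.0`) shows the effect of excluding a module that would otherwise match `target_modules`:

import torch.nn as nn
from peft import LoraConfig, get_peft_model


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(8, 8)
        self.proj_out = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj_out(self.to_q(x))


config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "proj_out"],
    exclude_modules=["proj_out"],  # needs peft >= 0.14.0
)
model = get_peft_model(ToyBlock(), config)

# Only `to_q` receives lora_A/lora_B parameters; `proj_out` is left untouched.
print([name for name, _ in model.named_parameters() if "lora_" in name])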

tests/models/test_modeling_common.py

Lines changed: 40 additions & 6 deletions
@@ -1349,7 +1349,6 @@ def test_model_parallelism(self):
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
-            print(f" new_model.hf_device_map:{new_model.hf_device_map}")
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
@@ -2018,6 +2017,8 @@ class LoraHotSwappingForModelTesterMixin:
 
     """
 
+    different_shapes_for_compilation = None
+
     def tearDown(self):
         # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
         # there will be recompilation errors, as torch caches the model when run in the same process.
@@ -2055,11 +2056,13 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
        - hotswap the second adapter
        - check that the outputs are correct
        - optionally compile the model
+        - optionally check if recompilations happen on different shapes
 
        Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
        fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
        fine.
        """
+        different_shapes = self.different_shapes_for_compilation
        # create 2 adapters with different ranks and alphas
        torch.manual_seed(0)
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -2109,19 +2112,30 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
            model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
 
            if do_compile:
-                model = torch.compile(model, mode="reduce-overhead")
+                model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
 
            with torch.inference_mode():
-                output0_after = model(**inputs_dict)["sample"]
-                assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+                # additionally check if dynamic compilation works.
+                if different_shapes is not None:
+                    for height, width in different_shapes:
+                        new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+                        _ = model(**new_inputs_dict)
+                else:
+                    output0_after = model(**inputs_dict)["sample"]
+                    assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
 
            # hotswap the 2nd adapter
            model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
 
            # we need to call forward to potentially trigger recompilation
            with torch.inference_mode():
-                output1_after = model(**inputs_dict)["sample"]
-                assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+                if different_shapes is not None:
+                    for height, width in different_shapes:
+                        new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+                        _ = model(**new_inputs_dict)
+                else:
+                    output1_after = model(**inputs_dict)["sample"]
+                    assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
 
            # check error when not passing valid adapter name
            name = "does-not-exist"
@@ -2239,3 +2253,23 @@ def test_hotswap_second_adapter_targets_more_layers_raises(self):
                do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
            )
        assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
+
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])
+    @require_torch_version_greater("2.7.1")
+    def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
+        different_shapes_for_compilation = self.different_shapes_for_compilation
+        if different_shapes_for_compilation is None:
+            pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+        # Specifying `use_duck_shape=False` instructs the compiler whether it should use the same symbolic
+        # variable to represent input sizes that are the same. For more details,
+        # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
+        torch.fx.experimental._config.use_duck_shape = False
+
+        target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(
+                do_compile=True,
+                rank0=rank0,
+                rank1=rank1,
+                target_modules0=target_modules,
+            )
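
The compiler settings the new test relies on can be sketched in isolation with a stand-in module (a plain `nn.Linear` rather than a LoRA-equipped transformer; recent PyTorch assumed, and the private `use_duck_shape` flag is the one the test itself flips): duck shaping is disabled so equal-but-unrelated sizes get their own symbols, and `error_on_recompile` turns any recompilation into a hard failure.

import torch
import torch._dynamo
import torch.fx.experimental._config
import torch.nn as nn

# Give equal-but-unrelated dimensions their own symbolic variables instead of a
# shared one ("duck shaping"); otherwise compiling on the square input below
# would bake in a height == width guard and the non-square inputs would recompile.
torch.fx.experimental._config.use_duck_shape = False

model = nn.Linear(16, 16)  # stand-in for the LoRA-equipped transformer
compiled = torch.compile(model, dynamic=True)

with torch._dynamo.config.patch(error_on_recompile=True):
    for height, width in [(8, 8), (8, 16), (16, 8)]:  # analogous to varying resolutions
        _ = compiled(torch.randn(height, width, 16))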

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 4 additions & 0 deletions
@@ -186,6 +186,10 @@ def prepare_dummy_input(self, height, width):
 
 class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
+    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
 
     def prepare_init_args_and_inputs_for_common(self):
         return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
+
+    def prepare_dummy_input(self, height, width):
+        return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
