
Conversation

@BenjaminBossan
Member

What does this PR do?

LoRA hotswapping has been available in PEFT since 0.15.0. There is already a diffusers
integration (huggingface/diffusers#9453), but the transformers integration was still missing this feature. This PR remedies this. It sticks closely to the diffusers PR, both implementation-wise and API-wise.

Hotswapping allows swapping different LoRA adapters in place instead of loading multiple adapters and switching between them. Not only can this be advantageous for saving memory and potentially for quicker loading, but the biggest advantage is that if the model is compiled, we can hotswap without triggering recompilation (loading a separate adapter would require recompilation).

There are some caveats to using this feature, most notably that only LoRA is supported. This was fine for diffusers, as it only works with LoRA, but the transformers integration works with other PEFT methods too. However, LoRA should be by far the most common method, so this should be fine for now. This and other caveats have been documented.

Note that testing is not super deep; more edge cases could be covered. However, as this is mainly about calling PEFT functionality, which is extensively tested in PEFT, I focused the tests mostly on the integration, with only a couple of tests for the functionality itself (namely the tests that call _check_model_hotswap).
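
For illustration, here is a rough usage sketch of the feature described above. The model id and adapter repo ids are placeholders, and the exact keyword arguments may differ slightly from the final API:

import torch
from transformers import AutoModelForCausalLM

# Placeholder ids, purely for illustration.
model = AutoModelForCausalLM.from_pretrained("base-model-id")

# The first adapter cannot be hotswapped; it is loaded normally.
model.load_adapter("user/lora-adapter-0", adapter_name="default")

# If the adapters differ in rank or scaling, model.enable_peft_hotswap() should be
# called here, before compiling, so that the LoRA weights are prepared accordingly.
compiled_model = torch.compile(model)
# ... run inference with compiled_model ...

# Swap the LoRA weights of "default" in place; this should not trigger recompilation.
model.load_adapter("user/lora-adapter-1", adapter_name="default", hotswap=True)
# ... run inference again, now with the second adapter's weights ...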

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@sayakpaul sayakpaul left a comment


LGTM 🔥

Let's maybe update diffusers to lift our current constraint that this feature isn't available for text encoder LoRAs once this is merged.

@require_peft
@require_torch
@slow
class PeftHotswapIntegrationTester(unittest.TestCase):
Member


@ydshieh okay for you?

Member Author


Just a note, these are the same decorators as for PeftIntegrationTester, but I didn't want to use it here (nor PeftTesterMixin), as the hotswap tests don't make use of the test matrix defined there.

def tearDown(self):
    # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
    # there will be recompilation errors, as torch caches the model when run in the same process.
    torch._dynamo.reset()
Member


Suggested change
torch._dynamo.reset()
torch.compiler.reset()

We could also introduce torch._inductor.utils.fresh_inductor_cache(). Example: https://github.com/huggingface/diffusers/blob/7242b5ff627fad93dd85834b0278267b6cbe2d6d/tests/models/test_modeling_common.py#L2061C13-L2061C57
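
For illustration, a minimal sketch of how the suggested pattern could look in a test; the class and test names here are hypothetical, not the actual tests in this PR:

import unittest

import torch
from torch._inductor.utils import fresh_inductor_cache


class HotswapCompileTester(unittest.TestCase):
    def tearDown(self):
        # Reset the compiler caches between tests so that reusing the same model
        # in one process does not lead to recompilation errors.
        torch.compiler.reset()

    def test_hotswap_does_not_trigger_recompilation(self):
        # Use a fresh inductor cache so that artifacts from other tests cannot leak in.
        with fresh_inductor_cache():
            ...  # build the model, compile it, hotswap adapters, check for recompilation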

Member Author


Done, also added the use of torch._inductor.utils.fresh_inductor_cache().

@BenjaminBossan
Member Author

Let's maybe update diffusers to lift our current constraint that this feature isn't available for text encoder LoRAs once this is merged.

Yes, that would be the next step once this PR is merged.

@BenjaminBossan
Member Author

@ydshieh How to proceed with this PR? Sayak approved but I guess we need a transformers dev's approval too?

To make the usage more intuitive, hotswap is now auto-enabled after
calling model.enable_peft_hotswap(). For this, we detect if
enable_peft_hotswap() was called *and* if the adapter being loaded
is *not* the first adapter (because the first adapter cannot be
hotswapped, it needs to be loaded normally).
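
For illustration, a rough sketch of the auto-enabled flow described in this commit; the model and adapter ids are placeholders, and enable_peft_hotswap() may accept additional arguments not shown here:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("base-model-id")
model.enable_peft_hotswap()

# First adapter: nothing to hotswap yet, so it is loaded normally.
model.load_adapter("user/lora-adapter-0", adapter_name="default")
compiled_model = torch.compile(model)

# Second adapter with the same adapter_name: hotswap is enabled automatically,
# because enable_peft_hotswap() was called and "default" already exists.
model.load_adapter("user/lora-adapter-1", adapter_name="default")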
@ArthurZucker
Collaborator

Sorry, cc @Cyrilvallez, can you review?

Member

@Cyrilvallez Cyrilvallez left a comment


Very nice! Just a bit concerned about the default, which seems to be able to crash code that is currently valid?

Comment on lines 200 to 201
from peft import PeftType

Member


Does that import exist even for the MIN_VERSION below? If not, let's import after the check, so that we fail gracefully.

Member Author


Done, the import should work with any PEFT version but I agree it's cleaner that way.

if hotswap == "auto":
    # if user called model.enable_peft_hotswap and this is not the first adapter, enable hotswap
    hotswap_enabled = getattr(self, "_hotswap_enabled", False)
    not_first_adapter = bool(self._hf_peft_config_loaded and (adapter_name in self.peft_config))
Member


We shouldn't need the bool cast here, should we?

Suggested change
not_first_adapter = bool(self._hf_peft_config_loaded and (adapter_name in self.peft_config))
not_first_adapter = self._hf_peft_config_loaded and adapter_name in self.peft_config

Member Author


Right now, _hf_peft_config_loaded is a bool, so coercing is not needed. But if its type changes in the future, the type of not_first_adapter could also change, so the cast makes this line future-proof. Here is an example:

>>> _hf_peft_config_loaded = {"foo": 1}
>>> adapter_name = "default"
>>> peft_config = {"default": 2}
>>> not_first_adapter = _hf_peft_config_loaded and adapter_name in peft_config
>>> not_first_adapter
True
>>> _hf_peft_config_loaded = {}  # falsy value short-circuits the conditional
>>> not_first_adapter = _hf_peft_config_loaded and adapter_name in peft_config
>>> not_first_adapter
{}

As you can see, in the last line, we suddenly have a different type for not_first_adapter.

Comment on lines +218 to +220
if any(conf.peft_type != PeftType.LORA for conf in self.peft_config.values()):
    raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")

Member


Just concerned about that, maybe False would be a better default, no? It looks like the current code would crash without reason when loading more adapters that are not LoRA.
Or maybe a None value that becomes auto if we use LoRA, and False otherwise?

Member Author


This would only crash if the user tries to load a non-LoRA adapter and they either:

  • passed hotswap=True, or
  • called enable_peft_hotswap

In either case, the user intent is to use hotswapping. Therefore, I think that raising is the better choice, otherwise the user would think they used hotswapping successfully when it's not actually being used.

Comment on lines +340 to +350
if self._prepare_peft_hotswap_kwargs is not None:
    # For hotswapping of compiled models or adapters with different ranks.
    # If the user called enable_peft_hotswap, we need to ensure it is called:
    # - after the first adapter was loaded
    # - before the model is compiled and the 2nd adapter is being hotswapped in
    # Therefore, it needs to be called here
    from peft.utils.hotswap import prepare_model_for_compiled_hotswap

    prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs)
    # We only want to call prepare_model_for_compiled_hotswap once
    self._prepare_peft_hotswap_kwargs = None
Member


Should this be under the condition if not hotswap? 🤔

Member Author


It is, see line 335.

Member Author

@BenjaminBossan BenjaminBossan left a comment


@Cyrilvallez Thanks a lot for your review. I addressed your comments, please check again.


@sayakpaul sayakpaul requested a review from Cyrilvallez October 27, 2025 09:07