Merged
55 commits
d3fbd7b
[WIP][LoRA] Implement hot-swapping of LoRA
BenjaminBossan Sep 17, 2024
84bae62
Reviewer feedback
BenjaminBossan Sep 18, 2024
63ece9d
Reviewer feedback, adjust test
BenjaminBossan Oct 16, 2024
94c669c
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Oct 16, 2024
c7378ed
Fix, doc
BenjaminBossan Oct 16, 2024
7c67b38
Make fix
BenjaminBossan Oct 16, 2024
ea12e0d
Fix for possible g++ error
BenjaminBossan Oct 16, 2024
ec4b0d5
Add test for recompilation w/o hotswapping
BenjaminBossan Oct 18, 2024
e07323a
Merge branch 'main' into lora-hot-swapping
sayakpaul Oct 18, 2024
529a523
Merge branch 'main' into lora-hot-swapping
sayakpaul Oct 22, 2024
ac1346d
Merge branch 'main' into lora-hot-swapping
sayakpaul Oct 25, 2024
58b35ba
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 6, 2025
d21a988
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 6, 2025
488f2f0
Make hotswap work
BenjaminBossan Feb 7, 2025
ece3d0f
Merge branch 'main' into lora-hot-swapping
sayakpaul Feb 8, 2025
5ab1460
Address reviewer feedback:
BenjaminBossan Feb 10, 2025
bc157e6
Change order of test decorators
BenjaminBossan Feb 10, 2025
bd1da66
Split model and pipeline tests
BenjaminBossan Feb 11, 2025
119a8ed
Reviewer feedback: Move decorator to test classes
BenjaminBossan Feb 12, 2025
53c2f84
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 12, 2025
a715559
Apply suggestions from code review
BenjaminBossan Feb 13, 2025
e40390d
Reviewer feedback: version check, TODO comment
BenjaminBossan Feb 13, 2025
1b834ec
Add enable_lora_hotswap method
BenjaminBossan Feb 14, 2025
4b01401
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 14, 2025
2cd3665
Reviewer feedback: check _lora_loadable_modules
BenjaminBossan Feb 17, 2025
efbd820
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 18, 2025
e735ac2
Revert changes in unet.py
BenjaminBossan Feb 18, 2025
69b637d
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 21, 2025
3a6677c
Add possibility to ignore enabled at wrong time
BenjaminBossan Feb 21, 2025
a96f3fd
Fix docstrings
BenjaminBossan Feb 21, 2025
deab0eb
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 27, 2025
2c6b435
Log possible PEFT error, test
BenjaminBossan Feb 27, 2025
ccb45f7
Raise helpful error if hotswap not supported
BenjaminBossan Feb 27, 2025
09e2ec7
Formatting
BenjaminBossan Feb 27, 2025
67ab6bf
More linter
BenjaminBossan Feb 27, 2025
f03fe6b
More ruff
BenjaminBossan Feb 27, 2025
2d407ca
Doc-builder complaint
BenjaminBossan Feb 27, 2025
6b59ecf
Update docstring:
BenjaminBossan Mar 3, 2025
f14146f
Merge branch 'main' into lora-hot-swapping
yiyixuxu Mar 3, 2025
a79876d
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 5, 2025
c3c1bdf
Fix error in docstring
BenjaminBossan Mar 5, 2025
387ddf6
Update more methods with hotswap argument
BenjaminBossan Mar 7, 2025
7f72d0b
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 7, 2025
dec4d10
Add hotswap argument to load_lora_into_transformer
BenjaminBossan Mar 11, 2025
204f521
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 11, 2025
716f446
Extend docstrings
BenjaminBossan Mar 12, 2025
4d82111
Add version guards to tests
BenjaminBossan Mar 12, 2025
425cb39
Formatting
BenjaminBossan Mar 12, 2025
115c77d
Fix LoRA loading call to add prefix=None
BenjaminBossan Mar 12, 2025
5d90753
Run make fix-copies
BenjaminBossan Mar 12, 2025
62c1c13
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 12, 2025
d6d23b8
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 17, 2025
366632d
Add hot swap documentation to the docs
BenjaminBossan Mar 17, 2025
b181a47
Apply suggestions from code review
BenjaminBossan Mar 18, 2025
f2a6146
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Apr 8, 2025
93 changes: 90 additions & 3 deletions src/diffusers/loaders/lora_pipeline.py
@@ -77,7 +77,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
text_encoder_name = TEXT_ENCODER_NAME

def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name=None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
@@ -103,6 +107,28 @@ def load_lora_weights(
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*, defaults to `False`):
Whether to substitute an existing adapter with the newly loaded adapter in-place.
This means that, instead of loading an additional adapter, this will take the existing adapter weights
and replace them with the weights of the new adapter. This can be faster and more memory efficient.
However, the main advantage of hotswapping is that when the model is compiled with torch.compile,
loading the new adapter does not require recompilation of the model.

If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:

```py
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

model = ... # load diffusers model with first LoRA adapter
max_rank = ... # the highest rank among all LoRAs that you want to load
prepare_model_for_compiled_hotswap(model, target_rank=max_rank) # call *before* compiling
model = torch.compile(model)
model.load_lora_adapter(..., hotswap=True) # now hotswap the 2nd adapter
```

There are some limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -133,6 +159,7 @@ def load_lora_weights(
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -263,7 +290,14 @@ def lora_state_dict(

@classmethod
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls,
state_dict,
network_alphas,
unet,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -285,6 +319,28 @@ def load_lora_into_unet(
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*, defaults to `False`):
Whether to substitute an existing adapter with the newly loaded adapter in-place.
This means that, instead of loading an additional adapter, this will take the existing adapter weights
and replace them with the weights of the new adapter. This can be faster and more memory efficient.
However, the main advantage of hotswapping is that when the model is compiled with torch.compile,
loading the new adapter does not require recompilation of the model.

If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:

```py
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

model = ... # load diffusers model with first LoRA adapter
max_rank = ... # the highest rank among all LoRAs that you want to load
prepare_model_for_compiled_hotswap(model, target_rank=max_rank) # call *before* compiling
model = torch.compile(model)
model.load_lora_adapter(..., hotswap=True) # now hotswap the 2nd adapter
```

There are some limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -309,6 +365,7 @@ def load_lora_into_unet(
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
@@ -703,7 +760,14 @@ def lora_state_dict(
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
cls,
state_dict,
network_alphas,
unet,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -725,6 +789,28 @@ def load_lora_into_unet(
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*, defaults to `False`):
Whether to substitute an existing adapter with the newly loaded adapter in-place.
This means that, instead of loading an additional adapter, this will take the existing adapter weights
and replace them with the weights of the new adapter. This can be faster and more memory efficient.
However, the main advantage of hotswapping is that when the model is compiled with torch.compile,
loading the new adapter does not require recompilation of the model.

If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:

```py
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

model = ... # load diffusers model with first LoRA adapter
max_rank = ... # the highest rank among all LoRAs that you want to load
prepare_model_for_compiled_hotswap(model, target_rank=max_rank) # call *before* compiling
model = torch.compile(model)
model.load_lora_adapter(..., hotswap=True) # now hotswap the 2nd adapter
```

There are some limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -749,6 +835,7 @@ def load_lora_into_unet(
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
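The docstrings above require `prepare_model_for_compiled_hotswap` when adapters differ in rank or alpha. The constraint exists because `torch.compile` specializes on tensor shapes, so swapped-in LoRA weights must keep the shapes of the weights they replace. A toy, dependency-free sketch of the zero-padding idea (this is not the PEFT implementation, just an illustration of why padding is harmless):

```python
def matmul(X, Y):
    # Plain-Python matrix multiply for the demonstration below.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]

def pad_lora_to_rank(A, B, target_rank):
    # A: r x in_features, B: out_features x r. Zero-pad both factors to
    # target_rank so every adapter exposes the same tensor shapes; the extra
    # zero rows/columns contribute nothing, so B @ A is unchanged.
    r = len(A)
    in_features = len(A[0])
    A_pad = A + [[0.0] * in_features for _ in range(target_rank - r)]
    B_pad = [row + [0.0] * (target_rank - r) for row in B]
    return A_pad, B_pad

A = [[1.0, 2.0], [3.0, 4.0]]               # rank 2, in_features 2
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # out_features 3, rank 2
A_pad, B_pad = pad_lora_to_rank(A, B, target_rank=4)
print(matmul(B, A) == matmul(B_pad, A_pad))  # True
```

Because the product is preserved, a compiled model can hot-swap between adapters of different ranks once all of them are padded to the same maximum rank.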
73 changes: 69 additions & 4 deletions src/diffusers/loaders/peft.py
@@ -138,7 +138,9 @@ def _optionally_disable_offloading(cls, _pipeline):
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)

def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
def load_lora_adapter(
self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
):
r"""
Loads a LoRA adapter into the underlying model.

@@ -182,6 +184,28 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*, defaults to `False`):
Whether to substitute an existing adapter with the newly loaded adapter in-place.
This means that, instead of loading an additional adapter, this will take the existing adapter weights
and replace them with the weights of the new adapter. This can be faster and more memory efficient.
However, the main advantage of hotswapping is that when the model is compiled with torch.compile,
loading the new adapter does not require recompilation of the model.

If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:

```py
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

model = ... # load diffusers model with first LoRA adapter
max_rank = ... # the highest rank among all LoRAs that you want to load
prepare_model_for_compiled_hotswap(model, target_rank=max_rank) # call *before* compiling
model = torch.compile(model)
model.load_lora_adapter(..., hotswap=True) # now hotswap the 2nd adapter
```

There are some limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -235,10 +259,15 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}

if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}):
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
raise ValueError(
f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
)
elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
raise ValueError(
f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. "
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
)

# check with the first key whether the state dict is in PEFT format
first_key = next(iter(state_dict.keys()))
@@ -296,11 +325,47 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

if hotswap:
try:
from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
except ImportError as exc:
msg = (
"Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
"from source."
)
raise ImportError(msg) from exc
Contributor:
Can we check the version instead of relying on exception? We have is_peft_version:

def is_peft_version(operation: str, version: str):

Member Author:
I did this on purpose, as it allows testing the feature by installing PEFT from main. Otherwise, we'd have to wait for the next PEFT release. Normally, I'd also avoid `try: import ...` for the side effect, but at this point, PEFT is already imported, so that's not a factor.

If you still want me to change this, LMK.

Member:
Can't we do is_peft_version(">", "0.14.0")?

Member Author:
Ah yes, that should work, I'll fix.
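The thread above settles on a version check instead of a try/except import. The comparison semantics can be sketched without any dependencies; note that diffusers' real `is_peft_version` consults the installed PEFT distribution via `packaging`, whereas this toy takes the current version as an argument:

```python
def parse_version(v: str):
    # Toy parse: "0.14.0" -> (0, 14, 0). The real helper uses `packaging`,
    # which also handles pre-release and dev suffixes; this sketch does not.
    return tuple(int(part) for part in v.split("."))

def is_version(current: str, operation: str, required: str) -> bool:
    # Compare two version strings under the given operator, mirroring the
    # is_peft_version(operation, version) call signature discussed above.
    ops = {
        ">": lambda a, b: a > b,
        ">=": lambda a, b: a >= b,
        "<": lambda a, b: a < b,
        "<=": lambda a, b: a <= b,
        "==": lambda a, b: a == b,
    }
    return ops[operation](parse_version(current), parse_version(required))

print(is_version("0.14.1", ">", "0.14.0"))  # True
print(is_version("0.14.0", ">", "0.14.0"))  # False
```

Tuple comparison gives the expected lexicographic ordering, which is why `is_peft_version(">", "0.14.0")` correctly admits a source install newer than the 0.14.0 release.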


if hotswap:

def map_state_dict_for_hotswap(sd):
# For hotswapping, we need the adapter name to be present in the state dict keys
new_sd = {}
for k, v in sd.items():
if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"):
k = k[: -len(".weight")] + f".{adapter_name}.weight"
elif k.endswith("lora_B.bias"): # lora_bias=True option
k = k[: -len(".bias")] + f".{adapter_name}.bias"
new_sd[k] = v
return new_sd

# To handle scenarios where we cannot successfully set the state dict. If it's unsuccessful,
# we should also delete the `peft_config` associated with the `adapter_name`.
try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if hotswap:
state_dict = map_state_dict_for_hotswap(state_dict)
check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
hotswap_adapter_from_state_dict(
model=self,
state_dict=state_dict,
adapter_name=adapter_name,
config=lora_config,
)
# the hotswap function raises if there are incompatible keys, so if we reach this point we can set
# it to None
incompatible_keys = None
else:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
except Exception as e:
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
if hasattr(self, "peft_config"):
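The key-mapping helper in the diff above can be exercised in isolation. A minimal standalone sketch (toy state dict; the adapter name `default_0` is illustrative, and strings stand in for the weight tensors):

```python
def map_state_dict_for_hotswap(sd, adapter_name):
    # Same renaming rule as the helper in the diff: hotswapping needs the
    # adapter name embedded in the keys, e.g.
    # "...lora_A.weight" -> "...lora_A.<adapter_name>.weight".
    new_sd = {}
    for k, v in sd.items():
        if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"):
            k = k[: -len(".weight")] + f".{adapter_name}.weight"
        elif k.endswith("lora_B.bias"):  # lora_bias=True option
            k = k[: -len(".bias")] + f".{adapter_name}.bias"
        new_sd[k] = v
    return new_sd

sd = {
    "blocks.0.attn.to_q.lora_A.weight": "A",
    "blocks.0.attn.to_q.lora_B.weight": "B",
    "blocks.0.attn.to_q.lora_B.bias": "bias",
}
print(sorted(map_state_dict_for_hotswap(sd, "default_0")))
```

Keys that match none of the suffixes pass through unchanged, so the helper is safe to apply to a full state dict.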
72 changes: 68 additions & 4 deletions src/diffusers/loaders/unet.py
@@ -281,7 +281,14 @@ def _process_custom_diffusion(self, state_dict):
return attn_processors

def _process_lora(
self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
self,
state_dict,
unet_identifier_key,
network_alphas,
adapter_name,
_pipeline,
low_cpu_mem_usage,
hotswap: bool = False,
):
# This method does the following things:
# 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
@@ -294,6 +301,7 @@ def _process_lora(
raise ValueError("PEFT backend is required for this method.")

from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer

keys = list(state_dict.keys())

@@ -313,10 +321,15 @@ def _process_lora(
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

if len(state_dict_to_be_used) > 0:
if adapter_name in getattr(self, "peft_config", {}):
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
raise ValueError(
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
)
elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
raise ValueError(
f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. "
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
)

state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)

@@ -364,8 +377,59 @@ def _process_lora(
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if hotswap:
try:
from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
except ImportError as exc:
msg = (
"Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
"from source."
)
raise ImportError(msg) from exc
Contributor:
Can we check the version instead of relying on exception? We have is_peft_version:

def is_peft_version(operation: str, version: str):

Member Author:
Same comment as above.


if hotswap:

def map_state_dict_for_hotswap(sd):
# For hotswapping, we need the adapter name to be present in the state dict keys
new_sd = {}
for k, v in sd.items():
if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"):
k = k[:-7] + f".{adapter_name}.weight"
elif k.endswith("lora_B.bias"): # lora_bias=True option
k = k[:-5] + f".{adapter_name}.bias"
new_sd[k] = v
return new_sd

# To handle scenarios where we cannot successfully set the state dict. If it's unsuccessful,
# we should also delete the `peft_config` associated with the `adapter_name`.
try:
if hotswap:
check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
hotswap_adapter_from_state_dict(
model=self,
state_dict=state_dict,
adapter_name=adapter_name,
config=lora_config,
)
# the hotswap function raises if there are incompatible keys, so if we reach this point we can set
# it to None
incompatible_keys = None
else:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
except Exception as e:
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
if hasattr(self, "peft_config"):
for module in self.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
for active_adapter in active_adapters:
if adapter_name in active_adapter:
module.delete_adapter(adapter_name)

self.peft_config.pop(adapter_name)
logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
raise
Contributor:
Can we raise the exception properly instead of logging an error?

Member:
Also, are we testing if this error is raised?

Member Author:
Note that I just copied the pattern from here:

try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
except Exception as e:
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
if hasattr(self, "peft_config"):
for module in self.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
for active_adapter in active_adapters:
if adapter_name in active_adapter:
module.delete_adapter(adapter_name)
self.peft_config.pop(adapter_name)
logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
raise

So this is just for consistency.

Member:
Good point. Let me provide some reasoning as to why I added the error (referenced) that way.

PEFT already raises a nice error when the code reaches that part in PEFT. So, it didn't make sense to craft another error message on top of it and instead, we decided to just propagate it to the users coming via diffusers.

I think that is okay to do here.

Regardless, @BenjaminBossan are we testing for the error that should be raised here in case hotswap fails? Or no need?

Member Author:
I think we don't need a test that is specific to hotswap failing, as it doesn't really matter why loading the adapter fails. If a test is added, it should probably be something similar to the test that was added when the change was introduced in peft.py: https://github.com/huggingface/diffusers/pull/10188/files#diff-b544edcc938e163009735ef4fa963abd0a41615c175552160c9e0f94ceb7f552.

Not sure if it's possible, but maybe that test can be adjusted to trigger this code path?

Member:
Yeah I can look into it in a future PR. Possible to add a note?

Member Author:
Added a TODO comment.
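The cleanup pattern under discussion (roll back a half-registered adapter before re-raising, so the model is not left in a broken state) can be reduced to a small standalone sketch. The names here are illustrative, not diffusers API:

```python
def load_adapter(registry, adapter_name, state_dict):
    # Register the adapter config first, then apply the weights; if applying
    # fails partway, remove the partial registration before re-raising, so a
    # failed load leaves no trace in the registry.
    registry[adapter_name] = {"status": "loading"}
    try:
        if not state_dict:
            raise RuntimeError("state dict is empty")
        registry[adapter_name] = {"status": "ready", "weights": state_dict}
    except Exception:
        registry.pop(adapter_name, None)
        raise

registry = {}
load_adapter(registry, "good", {"w": 1.0})
try:
    load_adapter(registry, "bad", {})
except RuntimeError:
    pass
print(sorted(registry))  # ['good']
```

Re-raising after the cleanup, rather than swallowing the error, is what lets the caller (or PEFT's own error message) surface the real failure reason, as the reviewers note above.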


warn_msg = ""
if incompatible_keys is not None: