Merged

Changes from 5 commits

8 changes: 6 additions & 2 deletions src/diffusers/loaders/lora_pipeline.py
@@ -2460,13 +2460,17 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
         if unexpected_modules:
             logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")

-        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        transformer_base_layer_keys = {
+            k[: -len(".base_layer.weight")] for k in transformer_state_dict.keys() if ".base_layer.weight" in k
+        }
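
A minimal illustration of what this set comprehension collects, with made-up key names (it maps each module that already has a LoRA attached to its bare module name):

    # Hypothetical state-dict keys, only for illustration.
    transformer_state_dict = {
        "x_embedder.base_layer.weight": ...,   # a LoRA is already attached to this layer
        "context_embedder.weight": ...,        # no LoRA touches this layer yet
    }
    transformer_base_layer_keys = {
        k[: -len(".base_layer.weight")] for k in transformer_state_dict.keys() if ".base_layer.weight" in k
    }
    assert transformer_base_layer_keys == {"x_embedder"}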
Member:
Note that the base_layer substring can only be present when the underlying pipeline already has at least one LoRA loaded that affects the layer under consideration. So perhaps it's better to have an is_peft_loaded check?

Member:
In your PR description you mention:

> If the first loaded LoRA model does not have weights for layer n, and the second one does, loading the second model will lead to an error since the transformer state dict currently does not have the key n.base_layer.weight.

Note that we may also have the opposite situation, i.e., the first LoRA checkpoint may have the params while the second LoRA may not. This is what I considered in #10388.

Contributor:
Writing the condition as if is_peft_loaded and ".base_layer.weight" in k might make it clearer that this branch only applies when a LoRA is already loaded.

Contributor:
The case where the first LoRA has extra weights that the second doesn't is fine on main:

  1. Hyper-FLUX.1-dev-8steps-lora.safetensors
  2. Purz/choose-your-own-adventure

or

  1. alimama-creative/FLUX.1-Turbo-Alpha
  2. TTPlanet/Migration_Lora_flux

In this case, base_param_name is set to f"{k.replace(prefix, '')}.base_layer.weight" for the second LoRA and all of those keys exist.

If the LoRAs are loaded in the reverse order, f"{k.replace(prefix, '')}.base_layer.weight" doesn't exist for the extra weights:

  1. Purz/choose-your-own-adventure
  2. Hyper-FLUX.1-dev-8steps-lora.safetensors

or

  1. TTPlanet/Migration_Lora_flux
  2. alimama-creative/FLUX.1-Turbo-Alpha

Loading then fails with KeyError: 'context_embedder.base_layer.weight'.

So for the extra weights we use f"{k.replace(prefix, '')}.weight". If another LoRA were later loaded with context_embedder, it would then use context_embedder.base_layer.weight.

We could simply continue when f"{k.replace(prefix, '')}.base_layer.weight" is not found, but the extra weights may still need to be expanded.
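
A rough repro sketch of the failing order described above. The repository ids are the ones named in this thread; the base model id, dtype, and any weight_name arguments are assumptions and may need adjusting:

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

    # First LoRA: no weights for some modules (e.g. context_embedder).
    pipe.load_lora_weights("TTPlanet/Migration_Lora_flux", adapter_name="migration")

    # Second LoRA: carries the extra weights. Before this fix, this step raised
    # KeyError: 'context_embedder.base_layer.weight' on main.
    pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name="turbo")

    pipe.set_adapters(["migration", "turbo"])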

Member:
Here we are considering the case where LoRA params for certain modules exist in the first checkpoint but not in the second (or any subsequent) checkpoint.

In that case we don't want to expand, no? Or am I missing something? Perhaps this is better expressed through a short test case like the one I added here?

Contributor:
That test case passes on main; the loads should be done in the reverse order:

        with tempfile.TemporaryDirectory() as tmpdirname:
            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            pipe.unload_lora_weights()
            # Modify the state dict to exclude "x_embedder" related LoRA params.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
            lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
            pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")

            # Load state dict with `x_embedder`.
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")

On main, this ordering then fails with:

            base_param_name = (
                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
            )
>           base_weight_param = transformer_state_dict[base_param_name]
E           KeyError: 'x_embedder.base_layer.weight'

src\diffusers\loaders\lora_pipeline.py:2471: KeyError

I think we still want to check whether the param needs to be expanded.
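
For context on "needs to be expanded": further down in _maybe_expand_lora_state_dict (below the shown hunk), the LoRA A matrix is zero-padded when the base layer has more input features than the LoRA was trained with. A simplified sketch of that idea, not the exact library code:

    import torch

    def expand_lora_A_if_needed(base_weight_param: torch.Tensor, lora_A_param: torch.Tensor) -> torch.Tensor:
        # The base layer is wider than the LoRA's A matrix: zero-pad the missing
        # input columns so the LoRA can still be applied to the expanded layer.
        if base_weight_param.shape[1] > lora_A_param.shape[1]:
            expanded = torch.zeros(
                (lora_A_param.shape[0], base_weight_param.shape[1]),
                device=lora_A_param.device,
                dtype=lora_A_param.dtype,
            )
            expanded[:, : lora_A_param.shape[1]].copy_(lora_A_param)
            return expanded
        return lora_A_param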

Member:
Cool, I understand it better now. Thanks!

Might be better to ship this PR with proper testing then. Okay with me.

         for k in lora_module_names:
             if k in unexpected_modules:
                 continue

             base_param_name = (
-                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+                f"{k.replace(prefix, '')}.base_layer.weight"
+                if k in transformer_base_layer_keys
+                else f"{k.replace(prefix, '')}.weight"
             )
Contributor:
base_param_name = f"{k.replace(prefix, '')}.weight"
base_layer_name = f"{k.replace(prefix, '')}.base_layer.weight"
if is_peft_loaded and base_layer_name in transformer_state_dict:
    base_param_name = base_layer_name

Something like this might be better.

             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
51 changes: 51 additions & 0 deletions tests/lora/test_lora_layers_flux.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import gc
 import os
 import sys
@@ -162,6 +163,56 @@ def test_with_alpha_in_state_dict(self):
         )
         self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))

+    def test_lora_expansion_works_for_absent_keys(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # Modify the config to have a layer which won't be present in the second LoRA we will load.
+        modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
+        modified_denoiser_lora_config.target_modules.add("x_embedder")
+
+        pipe.transformer.add_adapter(modified_denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertFalse(
+            np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+            # Modify the state dict to exclude "x_embedder" related LoRA params.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
+            pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
+
+            # Load state dict with `x_embedder`.
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
+
+            pipe.set_adapters(["one", "two"])
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+            images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        self.assertFalse(
+            np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+            "Different LoRAs should lead to different results.",
+        )
+        self.assertFalse(
+            np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
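
To exercise just the new test locally (assuming a source checkout of diffusers with the test dependencies installed), something like this should do:

    import pytest

    # Select only the new absent-keys test from the Flux LoRA suite.
    pytest.main(["tests/lora/test_lora_layers_flux.py", "-k", "lora_expansion_works_for_absent_keys"])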