diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index d0ca40213b14..51a406b2f6a3 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -636,10 +636,15 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1) new_key = f"transformer.single_transformer_blocks.{block_num}" - if "proj_lora1" in old_key or "proj_lora2" in old_key: + if "proj_lora" in old_key: new_key += ".proj_out" - elif "qkv_lora1" in old_key or "qkv_lora2" in old_key: - new_key += ".norm.linear" + elif "qkv_lora" in old_key and "up" not in old_key: + handle_qkv( + old_state_dict, + new_state_dict, + old_key, + [f"transformer.single_transformer_blocks.{block_num}.norm.linear"], + ) if "down" in old_key: new_key += ".lora_A.weight" diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index b58525cc7a6f..e6e87c7ba939 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -282,3 +282,28 @@ def test_flux_xlabs(self): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) assert max_diff < 1e-3 + + def test_flux_xlabs_load_lora_with_single_blocks(self): + self.pipeline.load_lora_weights( + "salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors" + ) + self.pipeline.fuse_lora() + self.pipeline.unload_lora_weights() + self.pipeline.enable_model_cpu_offload() + + prompt = "a wizard mouse playing chess" + + out = self.pipeline( + prompt, + num_inference_steps=self.num_inference_steps, + guidance_scale=3.5, + output_type="np", + generator=torch.manual_seed(self.seed), + ).images + out_slice = out[0, -3:, -3:, -1].flatten() + expected_slice = np.array( + [0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625] + ) + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) + + assert max_diff < 1e-3