Commit 3a4f8a4

test_lora_expansion_works_for_absent_keys
1 parent bdc5de5 commit 3a4f8a4

File tree

1 file changed: +51 -0 lines changed

tests/lora/test_lora_layers_flux.py

Lines changed: 51 additions & 0 deletions
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import gc
 import os
 import sys
@@ -162,6 +163,56 @@ def test_with_alpha_in_state_dict(self):
         )
         self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
 
+    def test_lora_expansion_works_for_absent_keys(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # Modify the config to have a layer which won't be present in the second LoRA we will load.
+        modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
+        modified_denoiser_lora_config.target_modules.add("x_embedder")
+
+        pipe.transformer.add_adapter(modified_denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertFalse(
+            np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+            # Modify the state dict to exclude "x_embedder" related LoRA params.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
+            pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
+
+            # Load state dict with `x_embedder`.
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
+
+        pipe.set_adapters(["one", "two"])
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+        images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        self.assertFalse(
+            np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+            "Different LoRAs should lead to different results.",
+        )
+        self.assertFalse(
+            np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
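
Beyond the diff itself, the scenario this test guards is: a pipeline already holding one adapter that targets `x_embedder` must still accept a second adapter whose state dict contains no `x_embedder` entries at all. A minimal sketch of the same pattern against a real Flux checkpoint might look as follows; the checkpoint ID, the local LoRA filename, and the adapter names are illustrative assumptions, not part of the commit.

import safetensors.torch
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Adapter "two": the LoRA with every x_embedder entry filtered out, mimicking
# a checkpoint trained without that layer (the same filtering the test does).
state_dict = safetensors.torch.load_file("pytorch_lora_weights.safetensors")
state_dict = {k: v for k, v in state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(state_dict, adapter_name="two")

# Adapter "one": the full LoRA file, x_embedder entries included.
pipe.load_lora_weights("pytorch_lora_weights.safetensors", adapter_name="one")

# Activating both together is what the new test asserts works: adapter "two"
# has no entries for x_embedder, yet inference must still run.
pipe.set_adapters(["one", "two"])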
