|
28 | 28 | AutoencoderKL, |
29 | 29 | UNet2DConditionModel, |
30 | 30 | ) |
31 | | -from diffusers.hooks.group_offloading import apply_group_offloading |
| 31 | +from diffusers.hooks.group_offloading import _GROUP_OFFLOADING, apply_group_offloading |
32 | 32 | from diffusers.utils import logging |
33 | 33 | from diffusers.utils.import_utils import is_peft_available |
34 | 34 |
|
@@ -2381,30 +2381,38 @@ def test_lora_group_offloading_delete_adapters(self): |
2381 | 2381 | denoiser.add_adapter(denoiser_lora_config) |
2382 | 2382 | self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") |
2383 | 2383 |
|
2384 | | - with tempfile.TemporaryDirectory() as tmpdirname: |
2385 | | - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) |
2386 | | - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) |
2387 | | - self.pipeline_class.save_lora_weights( |
2388 | | - save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts |
2389 | | - ) |
| 2384 | + try: |
| 2385 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 2386 | + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) |
| 2387 | + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) |
| 2388 | + self.pipeline_class.save_lora_weights( |
| 2389 | + save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts |
| 2390 | + ) |
2390 | 2391 |
|
2391 | | - components, _, _ = self.get_dummy_components() |
2392 | | - pipe = self.pipeline_class(**components) |
2393 | | - pipe.to(torch_device) |
| 2392 | + components, _, _ = self.get_dummy_components() |
| 2393 | + pipe = self.pipeline_class(**components) |
| 2394 | + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet |
| 2395 | + pipe.to(torch_device) |
| 2396 | + |
| 2397 | + # Enable Group Offloading (leaf_level for more granular testing) |
| 2398 | + apply_group_offloading( |
| 2399 | + denoiser, |
| 2400 | + onload_device=torch_device, |
| 2401 | + offload_device="cpu", |
| 2402 | + offload_type="leaf_level", |
| 2403 | + ) |
2394 | 2404 |
|
2395 | | - denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet |
| 2405 | + pipe.load_lora_weights(tmpdirname, adapter_name="default") |
2396 | 2406 |
|
2397 | | - # Enable Group Offloading (leaf_level) |
2398 | | - apply_group_offloading( |
2399 | | - denoiser, |
2400 | | - onload_device=torch_device, |
2401 | | - offload_device="cpu", |
2402 | | - offload_type="leaf_level", |
2403 | | - ) |
| 2407 | + out_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 2408 | + |
| 2409 | + # Delete the adapter |
| 2410 | + pipe.delete_adapters("default") |
| 2411 | + |
| 2412 | + out_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] |
2404 | 2413 |
|
2405 | | - pipe.load_lora_weights(tmpdirname, adapter_name="default") |
2406 | | - out_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] |
2407 | | - # Delete the adapter |
2408 | | - pipe.delete_adapters("default") |
2409 | | - out_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] |
2410 | | - self.assertFalse(np.allclose(out_lora, out_no_lora, atol=1e-3, rtol=1e-3)) |
| 2414 | + self.assertFalse(np.allclose(out_lora, out_no_lora, atol=1e-3, rtol=1e-3)) |
| 2415 | + finally: |
| 2416 | + # Clean up the hooks to prevent state leak |
| 2417 | + if hasattr(denoiser, "_diffusers_hook"): |
| 2418 | + denoiser._diffusers_hook.remove_hook(_GROUP_OFFLOADING, recurse=True) |
0 commit comments