Commit 4f0691f

Test: Add try/finally to clean up group offloading hooks
1 parent e53ec15 commit 4f0691f

1 file changed: tests/lora/utils.py (+32 −24)

tests/lora/utils.py

Lines changed: 32 additions & 24 deletions
@@ -28,7 +28,7 @@
     AutoencoderKL,
     UNet2DConditionModel,
 )
-from diffusers.hooks.group_offloading import apply_group_offloading
+from diffusers.hooks.group_offloading import _GROUP_OFFLOADING, apply_group_offloading
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_peft_available
 
@@ -2381,30 +2381,38 @@ def test_lora_group_offloading_delete_adapters(self):
         denoiser.add_adapter(denoiser_lora_config)
         self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
 
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(
-                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
-            )
+        try:
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+                self.pipeline_class.save_lora_weights(
+                    save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+                )
 
-            components, _, _ = self.get_dummy_components()
-            pipe = self.pipeline_class(**components)
-            pipe.to(torch_device)
+                components, _, _ = self.get_dummy_components()
+                pipe = self.pipeline_class(**components)
+                denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+                pipe.to(torch_device)
+
+                # Enable Group Offloading (leaf_level for more granular testing)
+                apply_group_offloading(
+                    denoiser,
+                    onload_device=torch_device,
+                    offload_device="cpu",
+                    offload_type="leaf_level",
+                )
 
-            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+                pipe.load_lora_weights(tmpdirname, adapter_name="default")
 
-            # Enable Group Offloading (leaf_level)
-            apply_group_offloading(
-                denoiser,
-                onload_device=torch_device,
-                offload_device="cpu",
-                offload_type="leaf_level",
-            )
+                out_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+                # Delete the adapter
+                pipe.delete_adapters("default")
+
+                out_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-            pipe.load_lora_weights(tmpdirname, adapter_name="default")
-            out_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            # Delete the adapter
-            pipe.delete_adapters("default")
-            out_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertFalse(np.allclose(out_lora, out_no_lora, atol=1e-3, rtol=1e-3))
+                self.assertFalse(np.allclose(out_lora, out_no_lora, atol=1e-3, rtol=1e-3))
+        finally:
+            # Clean up the hooks to prevent state leak
+            if hasattr(denoiser, "_diffusers_hook"):
+                denoiser._diffusers_hook.remove_hook(_GROUP_OFFLOADING, recurse=True)
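
Note: for reference, the cleanup pattern this commit wraps around the test body can be shown in isolation. The snippet below is a minimal sketch, not part of the commit; it assumes a CUDA device and a toy torch.nn.Sequential standing in for the pipeline's denoiser, and reuses only the calls that appear in the diff (apply_group_offloading, _GROUP_OFFLOADING, and _diffusers_hook.remove_hook).

# Minimal sketch (not from the commit): the try/finally hook-cleanup pattern on its own.
import torch

from diffusers.hooks.group_offloading import _GROUP_OFFLOADING, apply_group_offloading

# Toy module standing in for the denoiser; purely illustrative.
module = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

try:
    # Attach leaf-level group offloading hooks, as the test does for the denoiser.
    apply_group_offloading(
        module,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
    )
    # ... exercise the module while the offloading hooks are active ...
finally:
    # Hooks registered on the module would otherwise outlive this block and
    # leak offloading state into whatever reuses the module next.
    if hasattr(module, "_diffusers_hook"):
        module._diffusers_hook.remove_hook(_GROUP_OFFLOADING, recurse=True)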
