@@ -39,6 +39,7 @@
     is_torch_version,
     require_peft_backend,
     require_peft_version_greater,
+    require_torch_accelerator,
     require_transformers_version_greater,
     skip_mps,
     torch_device,
@@ -2355,3 +2356,73 @@ def test_inference_load_delete_load_adapters(self):
         pipe.load_lora_weights(tmpdirname)
         output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
+
+    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
+
+        onload_device = torch_device
+        offload_device = torch.device("cpu")
+
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+
+            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            # Test group offloading with load_lora_weights
+            denoiser.enable_group_offload(
+                onload_device=onload_device,
+                offload_device=offload_device,
+                offload_type=offload_type,
+                num_blocks_per_group=1,
+                use_stream=use_stream,
+            )
+            group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_1 is not None)
+            output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            # Test group offloading after removing the lora
+            pipe.unload_lora_weights()
+            group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_2 is not None)
+            output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]  # noqa: F841
+
+            # Add the lora again and check if group offloading works
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+            group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_3 is not None)
+            output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3))
+
+    @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        for cls in inspect.getmro(self.__class__):
+            if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
+                # Skip this test if it is overridden by a child class. We need to do this because parameterized
+                # materializes the test methods on invocation, and those materialized methods cannot be overridden.
+                return
+        self._test_group_offloading_inference_denoiser(offload_type, use_stream)
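
For context, here is a minimal usage sketch of the pattern the new test exercises: loading LoRA weights into a pipeline whose denoiser has group offloading enabled, then unloading and re-loading them. The checkpoint id and LoRA path below are placeholders, not values from this PR; the enable_group_offload arguments mirror the ones the test passes.

import torch
from diffusers import DiffusionPipeline

# Placeholder model id and LoRA file, for illustration only.
pipe = DiffusionPipeline.from_pretrained("<model-id>", torch_dtype=torch.float16)
pipe = pipe.to("cuda")  # the test also moves the pipeline to the accelerator first
pipe.load_lora_weights("<path/to/pytorch_lora_weights.safetensors>")

# Keep the denoiser's weights on the CPU and move them onto the accelerator one
# group at a time during the forward pass (same arguments as the test above).
denoiser = pipe.transformer if getattr(pipe, "transformer", None) is not None else pipe.unet
denoiser.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
)

image = pipe("a prompt", generator=torch.manual_seed(0)).images[0]

# What the test verifies: the group offload hook stays in place across
# unloading and re-loading the LoRA weights, and outputs stay consistent.
pipe.unload_lora_weights()
pipe.load_lora_weights("<path/to/pytorch_lora_weights.safetensors>")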