 from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    backend_empty_cache,
+    floats_tensor,
     get_python_version,
     is_torch_compile,
     numpy_cosine_similarity_distance,
+    require_peft_backend,
     require_torch_2,
     require_torch_accelerator,
     require_torch_accelerator_with_training,
     require_torch_gpu,
     require_torch_multi_gpu,
     run_test_in_subprocess,
+    slow,
     torch_all_close,
     torch_device,
 )
@@ -1519,3 +1523,188 @@ def test_push_to_hub_library_name(self):
 
         # Reset repo
         delete_repo(self.repo_id, token=TOKEN)
+
+
+class TestLoraHotSwappingForModel(unittest.TestCase):
+    """Test that hotswapping does not result in recompilation on the model directly.
+
+    We do not extensively test the hotswapping functionality itself, since it is implemented in PEFT and thoroughly
+    tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require
+    recompilation.
+
+    See
+    https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
+    for the analogous PEFT test.
+    """
+
+    def tearDown(self):
+        # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
+        # there will be recompilation errors, as torch caches the model when run in the same process.
+        super().tearDown()
+        torch._dynamo.reset()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def get_small_unet(self):
+        # from diffusers UNet2DConditionModelTests
+        torch.manual_seed(0)
+        init_dict = {
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 4,
+            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
+            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
+            "cross_attention_dim": 8,
+            "attention_head_dim": 2,
+            "out_channels": 4,
+            "in_channels": 4,
+            "layers_per_block": 1,
+            "sample_size": 16,
+        }
+        model = UNet2DConditionModel(**init_dict)
+        return model.to(torch_device)
+
+    def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
+        # from diffusers test_models_unet_2d_condition.py
+        from peft import LoraConfig
+
+        unet_lora_config = LoraConfig(
+            r=lora_rank,
+            lora_alpha=lora_alpha,
+            target_modules=target_modules,
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        return unet_lora_config
+
+    def get_dummy_input(self):
+        # from UNet2DConditionModelTests
+        batch_size = 4
+        num_channels = 4
+        sizes = (16, 16)
+
+        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+        time_step = torch.tensor([10]).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
+
+        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+
+    def check_model_hotswap(self, do_compile, rank0, rank1, target_modules):
+        """
+        Check that hotswapping works on a small unet.
+
+        Steps:
+        - create 2 LoRA adapters and save them
+        - load the first adapter
+        - optionally compile the model
+        - hotswap the second adapter
+        - check that the outputs are correct
+
+        Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings; the test would
+        therefore fail if the two values differed. Since rank != alpha does not matter for the purpose of this test,
+        this is fine.
+        """
+        from peft.utils.hotswap import prepare_model_for_compiled_hotswap
+
+        dummy_input = self.get_dummy_input()
+        alpha0, alpha1 = rank0, rank1
+        max_rank = max([rank0, rank1])
+        lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules)
+        lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules)
+
+        unet = self.get_small_unet()
+        unet.add_adapter(lora_config0, adapter_name="adapter0")
+        with torch.inference_mode():
+            output0_before = unet(**dummy_input)["sample"]
+
+        unet.add_adapter(lora_config1, adapter_name="adapter1")
+        unet.set_adapter("adapter1")
+        with torch.inference_mode():
+            output1_before = unet(**dummy_input)["sample"]
+
+        # sanity checks:
+        tol = 5e-3
+        assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol)
+        assert not (output0_before == 0).all()
+        assert not (output1_before == 0).all()
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
+            unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
+            del unet
+
+            unet = self.get_small_unet()
+            file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
+            file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
+            unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0")
+
+            if do_compile or (rank0 != rank1):
+                # no need to prepare if the model is not compiled and the ranks are identical
+                prepare_model_for_compiled_hotswap(
+                    unet,
+                    config={"adapter0": lora_config0, "adapter1": lora_config1},
+                    target_rank=max_rank,
+                )
+            if do_compile:
+                unet = torch.compile(unet, mode="reduce-overhead")
+
+            with torch.inference_mode():
+                output0_after = unet(**dummy_input)["sample"]
+            assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+
+            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True)
+
+            # we need to call forward to potentially trigger recompilation
+            with torch.inference_mode():
+                output1_after = unet(**dummy_input)["sample"]
+            assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+
+            # check error when not passing valid adapter name
+            name = "does-not-exist"
+            msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
+            with self.assertRaisesRegex(ValueError, msg):
+                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True)
+
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    @slow
+    @require_torch_2
+    @require_torch_accelerator
+    @require_peft_backend
+    def test_hotswapping_model(self, rank0, rank1):
+        self.check_model_hotswap(
+            do_compile=False, rank0=rank0, rank1=rank1, target_modules=["to_q", "to_k", "to_v", "to_out.0"]
+        )
+
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    @slow
+    @require_torch_2
+    @require_torch_accelerator
+    @require_peft_backend
+    def test_hotswapping_compiled_model_linear(self, rank0, rank1):
+        # It's important to add this context to raise an error on recompilation
+        target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules)
+
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    @slow
+    @require_torch_2
+    @require_torch_accelerator
+    @require_peft_backend
+    def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
+        # It's important to add this context to raise an error on recompilation
+        target_modules = ["conv", "conv1", "conv2"]
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules)
+
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    @slow
+    @require_torch_2
+    @require_torch_accelerator
+    @require_peft_backend
+    def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
+        # It's important to add this context to raise an error on recompilation
+        target_modules = ["to_q", "conv"]
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules)
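
For reference, here is a minimal sketch (not part of the diff) of the hotswap flow these tests exercise: load one LoRA adapter, pad the LoRA layers and compile the model, then swap a second adapter's weights in place without triggering recompilation. The checkpoint and adapter paths, the adapter name, and the `target_rank` value are placeholders; the calls themselves (`load_lora_adapter(..., hotswap=True)`, `prepare_model_for_compiled_hotswap`, `torch.compile`) are the ones used in the test above.

```python
# Sketch only: paths, adapter name, and target_rank below are placeholders.
import torch

from diffusers import UNet2DConditionModel
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

unet = UNet2DConditionModel.from_pretrained("path/to/checkpoint", subfolder="unet").to("cuda")

# Load the first LoRA adapter the usual way.
unet.load_lora_adapter("lora0/pytorch_lora_weights.safetensors", adapter_name="default")

# Pad all LoRA layers to a common rank so that adapters of different ranks can be swapped in
# later without changing tensor shapes (which would otherwise force recompilation). The test
# above additionally passes the LoraConfig objects via `config=` to keep the scalings in sync.
prepare_model_for_compiled_hotswap(unet, target_rank=16)

unet = torch.compile(unet, mode="reduce-overhead")
# ... run inference with the first adapter; this triggers the one-time compilation ...

# Replace the weights of the already-loaded adapter in place; the compiled graph is reused.
unet.load_lora_adapter("lora1/pytorch_lora_weights.safetensors", adapter_name="default", hotswap=True)
# ... run inference again; hotswapping should not trigger a recompile ...
```

The compiled-model tests above run this flow under `torch._dynamo.config.patch(error_on_recompile=True)`, so any recompilation shows up as a hard error rather than a silent slowdown.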