Commit bd1da66

Split model and pipeline tests

Also increase test coverage by targeting conv2d layers (support for which was recently added in the PEFT PR).
1 parent bc157e6 commit bd1da66
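
For orientation, a minimal sketch of the user-facing flow that the new TestLoraHotSwappingForModel tests exercise (not part of the commit). It assumes `unet` is an already-loaded UNet2DConditionModel with two LoRA checkpoints saved beforehand via unet.save_lora_adapter(...); the paths and the target rank are placeholders, and the actual test additionally passes the adapters' LoraConfig objects to prepare_model_for_compiled_hotswap via config=.

import torch
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

# Load the first adapter as usual.
unet.load_lora_adapter("adapter0/pytorch_lora_weights.safetensors", adapter_name="adapter0")

# Pad the LoRA weights to a common rank so that swapping in an adapter of a different
# rank later cannot change tensor shapes and force torch.compile to recompile.
prepare_model_for_compiled_hotswap(unet, target_rank=13)  # placeholder rank
unet = torch.compile(unet, mode="reduce-overhead")

# Swap the weights of "adapter0" in place; the compiled graph is reused instead of recompiled.
unet.load_lora_adapter("adapter1/pytorch_lora_weights.safetensors", adapter_name="adapter0", hotswap=True)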

File tree

2 files changed: +238, -142 lines

tests/models/test_modeling_common.py

Lines changed: 189 additions & 0 deletions
@@ -56,15 +56,19 @@
 from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    backend_empty_cache,
+    floats_tensor,
     get_python_version,
     is_torch_compile,
     numpy_cosine_similarity_distance,
+    require_peft_backend,
     require_torch_2,
     require_torch_accelerator,
     require_torch_accelerator_with_training,
     require_torch_gpu,
     require_torch_multi_gpu,
     run_test_in_subprocess,
+    slow,
     torch_all_close,
     torch_device,
 )
@@ -1519,3 +1523,188 @@ def test_push_to_hub_library_name(self):
 
         # Reset repo
         delete_repo(self.repo_id, token=TOKEN)
+
+
+class TestLoraHotSwappingForModel(unittest.TestCase):
+    """Test that hotswapping does not result in recompilation on the model directly.
+
+    We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
+    tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require
+    recompilation.
+
+    See
+    https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
+    for the analogous PEFT test.
+
+    """
+
+    def tearDown(self):
+        # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
+        # there will be recompilation errors, as torch caches the model when run in the same process.
+        super().tearDown()
+        torch._dynamo.reset()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def get_small_unet(self):
+        # from diffusers UNet2DConditionModelTests
+        torch.manual_seed(0)
+        init_dict = {
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 4,
+            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
+            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
+            "cross_attention_dim": 8,
+            "attention_head_dim": 2,
+            "out_channels": 4,
+            "in_channels": 4,
+            "layers_per_block": 1,
+            "sample_size": 16,
+        }
+        model = UNet2DConditionModel(**init_dict)
+        return model.to(torch_device)
+
+    def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
+        # from diffusers test_models_unet_2d_condition.py
+        from peft import LoraConfig
+
+        unet_lora_config = LoraConfig(
+            r=lora_rank,
+            lora_alpha=lora_alpha,
+            target_modules=target_modules,
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        return unet_lora_config
+
+    def get_dummy_input(self):
+        # from UNet2DConditionModelTests
+        batch_size = 4
+        num_channels = 4
+        sizes = (16, 16)
+
+        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+        time_step = torch.tensor([10]).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
+
+        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+
+    def check_model_hotswap(self, do_compile, rank0, rank1, target_modules):
+        """
+        Check that hotswapping works on a small unet.
+
+        Steps:
+        - create 2 LoRA adapters and save them
+        - load the first adapter
+        - hotswap the second adapter
+        - check that the outputs are correct
+        - optionally compile the model
+
+        Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
+        fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
+        fine.
+
+        """
+        from peft.utils.hotswap import prepare_model_for_compiled_hotswap
+
+        dummy_input = self.get_dummy_input()
+        alpha0, alpha1 = rank0, rank1
+        max_rank = max([rank0, rank1])
+        lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules)
+        lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules)
+
+        unet = self.get_small_unet()
+        unet.add_adapter(lora_config0, adapter_name="adapter0")
+        with torch.inference_mode():
+            output0_before = unet(**dummy_input)["sample"]
+
+        unet.add_adapter(lora_config1, adapter_name="adapter1")
+        unet.set_adapter("adapter1")
+        with torch.inference_mode():
+            output1_before = unet(**dummy_input)["sample"]
+
+        # sanity checks:
+        tol = 5e-3
+        assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol)
+        assert not (output0_before == 0).all()
+        assert not (output1_before == 0).all()
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
+            unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
+            del unet
+
+            unet = self.get_small_unet()
+            file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
+            file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
+            unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0")
+
+            if do_compile or (rank0 != rank1):
+                # no need to prepare if the model is not compiled or if the ranks are identical
+                prepare_model_for_compiled_hotswap(
+                    unet,
+                    config={"adapter0": lora_config0, "adapter1": lora_config1},
+                    target_rank=max_rank,
+                )
+            if do_compile:
+                unet = torch.compile(unet, mode="reduce-overhead")
+
+            with torch.inference_mode():
+                output0_after = unet(**dummy_input)["sample"]
+            assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+
+            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True)
+
+            # we need to call forward to potentially trigger recompilation
+            with torch.inference_mode():
+                output1_after = unet(**dummy_input)["sample"]
+            assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+
+            # check error when not passing valid adapter name
+            name = "does-not-exist"
+            msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
+            with self.assertRaisesRegex(ValueError, msg):
+                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True)
+
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    @slow
+    @require_torch_2
+    @require_torch_accelerator
+    @require_peft_backend
+    def test_hotswapping_model(self, rank0, rank1):
+        self.check_model_hotswap(
+            do_compile=False, rank0=rank0, rank1=rank1, target_modules=["to_q", "to_k", "to_v", "to_out.0"]
+        )
+
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    @slow
+    @require_torch_2
+    @require_torch_accelerator
+    @require_peft_backend
+    def test_hotswapping_compiled_model_linear(self, rank0, rank1):
+        # It's important to add this context to raise an error on recompilation
+        target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules)
+
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    @slow
+    @require_torch_2
+    @require_torch_accelerator
+    @require_peft_backend
+    def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
+        # It's important to add this context to raise an error on recompilation
+        target_modules = ["conv", "conv1", "conv2"]
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules)
+
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    @slow
+    @require_torch_2
+    @require_torch_accelerator
+    @require_peft_backend
+    def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
+        # It's important to add this context to raise an error on recompilation
+        target_modules = ["to_q", "conv"]
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules)
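
How to run the new tests (assuming the usual diffusers testing conventions): they are gated by @slow, @require_torch_2, @require_torch_accelerator, and @require_peft_backend, so they need PEFT installed, PyTorch 2.x, an available accelerator, and the RUN_SLOW=1 environment variable, e.g.:

RUN_SLOW=1 pytest tests/models/test_modeling_common.py -k hotswapping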
