
Commit a331838
refactor hotswap tester.
1 parent 8819cda

File tree: 3 files changed, +86 −74 lines changed

tests/models/test_modeling_common.py

Lines changed: 80 additions & 70 deletions
@@ -59,7 +59,6 @@
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
-    floats_tensor,
     get_python_version,
     is_torch_compile,
     numpy_cosine_similarity_distance,
@@ -1720,7 +1719,7 @@ def test_push_to_hub_library_name(self):
 @require_peft_backend
 @require_peft_version_greater("0.14.0")
 @is_torch_compile
-class TestLoraHotSwappingForModel(unittest.TestCase):
+class LoraHotSwappingForModelTesterMixin:
     """Test that hotswapping does not result in recompilation on the model directly.
 
     We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
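Because the class is now a mixin rather than a standalone `unittest.TestCase`, the hotswap tests only run where a concrete model test class lists it as a base and supplies `model_class` plus `prepare_init_args_and_inputs_for_common()` (exactly what the Flux and UNet test files later in this commit do). A minimal sketch of such a hook, reusing the tiny UNet configuration and dummy-input shapes from the helpers this commit deletes in the next hunk; the class name is hypothetical and `torch.randn` stands in for `floats_tensor`:

import unittest

import torch

from diffusers import UNet2DConditionModel


class ExampleUNetHotSwapTests(unittest.TestCase):  # would also inherit LoraHotSwappingForModelTesterMixin
    model_class = UNet2DConditionModel
    main_input_name = "sample"

    def prepare_init_args_and_inputs_for_common(self):
        # Same tiny UNet config that the deleted get_small_unet() hard-coded.
        init_dict = {
            "block_out_channels": (4, 8),
            "norm_num_groups": 4,
            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
            "cross_attention_dim": 8,
            "attention_head_dim": 2,
            "out_channels": 4,
            "in_channels": 4,
            "layers_per_block": 1,
            "sample_size": 16,
        }
        # Same shapes the deleted get_dummy_input() produced.
        inputs_dict = {
            "sample": torch.randn(4, 4, 16, 16),
            "timestep": torch.tensor([10]),
            "encoder_hidden_states": torch.randn(4, 4, 8),
        }
        return init_dict, inputs_dict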
@@ -1741,48 +1740,24 @@ def tearDown(self):
         gc.collect()
         backend_empty_cache(torch_device)
 
-    def get_small_unet(self):
-        # from diffusers UNet2DConditionModelTests
-        torch.manual_seed(0)
-        init_dict = {
-            "block_out_channels": (4, 8),
-            "norm_num_groups": 4,
-            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
-            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
-            "cross_attention_dim": 8,
-            "attention_head_dim": 2,
-            "out_channels": 4,
-            "in_channels": 4,
-            "layers_per_block": 1,
-            "sample_size": 16,
-        }
-        model = UNet2DConditionModel(**init_dict)
-        return model.to(torch_device)
-
-    def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
+    def get_lora_config(self, lora_rank, lora_alpha, target_modules):
         # from diffusers test_models_unet_2d_condition.py
         from peft import LoraConfig
 
-        unet_lora_config = LoraConfig(
+        lora_config = LoraConfig(
             r=lora_rank,
             lora_alpha=lora_alpha,
             target_modules=target_modules,
             init_lora_weights=False,
             use_dora=False,
         )
-        return unet_lora_config
-
-    def get_dummy_input(self):
-        # from UNet2DConditionModelTests
-        batch_size = 4
-        num_channels = 4
-        sizes = (16, 16)
-
-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
+        return lora_config
 
-        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+    def get_linear_module_name_other_than_attn(self, model):
+        linear_names = [
+            name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
+        ]
+        return linear_names[0]
 
     def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
         """
@@ -1800,23 +1775,26 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
         fine.
         """
         # create 2 adapters with different ranks and alphas
-        dummy_input = self.get_dummy_input()
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
         alpha0, alpha1 = rank0, rank1
         max_rank = max([rank0, rank1])
         if target_modules1 is None:
             target_modules1 = target_modules0[:]
-        lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
-        lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
+        lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0)
+        lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1)
 
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config0, adapter_name="adapter0")
+        model.add_adapter(lora_config0, adapter_name="adapter0")
         with torch.inference_mode():
-            output0_before = unet(**dummy_input)["sample"]
+            torch.manual_seed(0)
+            output0_before = model(**inputs_dict)["sample"]
 
-        unet.add_adapter(lora_config1, adapter_name="adapter1")
-        unet.set_adapter("adapter1")
+        model.add_adapter(lora_config1, adapter_name="adapter1")
+        model.set_adapter("adapter1")
         with torch.inference_mode():
-            output1_before = unet(**dummy_input)["sample"]
+            torch.manual_seed(0)
+            output1_before = model(**inputs_dict)["sample"]
 
         # sanity checks:
         tol = 5e-3
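Note the `torch.manual_seed(0)` calls added immediately before each forward pass. A plausible reason (not stated in the commit) is that the mixin now runs against arbitrary `model_class` implementations, some of which may involve randomness in `forward`; resetting the RNG keeps the before/after outputs comparable. A generic illustration of the pattern:

import torch

layer = torch.nn.Dropout(p=0.1)  # stand-in for any stochastic operation inside a model's forward
x = torch.randn(2, 4)

torch.manual_seed(0)
out_a = layer(x)
torch.manual_seed(0)
out_b = layer(x)

assert torch.allclose(out_a, out_b)  # identical because the RNG state was reset before each call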
@@ -1826,40 +1804,44 @@ def check_model_hotswap(do_compile, rank0, rank1, target_modules0, target_
 
         with tempfile.TemporaryDirectory() as tmp_dirname:
             # save the adapter checkpoints
-            unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
-            unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
-            del unet
+            model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
+            model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
+            del model
 
             # load the first adapter
-            unet = self.get_small_unet()
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict).to(torch_device)
+
             if do_compile or (rank0 != rank1):
                 # no need to prepare if the model is not compiled or if the ranks are identical
-                unet.enable_lora_hotswap(target_rank=max_rank)
+                model.enable_lora_hotswap(target_rank=max_rank)
 
             file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
             file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
-            unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
+            model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
 
             if do_compile:
-                unet = torch.compile(unet, mode="reduce-overhead")
+                model = torch.compile(model, mode="reduce-overhead")
 
             with torch.inference_mode():
-                output0_after = unet(**dummy_input)["sample"]
+                torch.manual_seed(0)
+                output0_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
 
             # hotswap the 2nd adapter
-            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
+            model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
 
             # we need to call forward to potentially trigger recompilation
             with torch.inference_mode():
-                output1_after = unet(**dummy_input)["sample"]
+                torch.manual_seed(0)
+                output1_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
 
             # check error when not passing valid adapter name
             name = "does-not-exist"
             msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
             with self.assertRaisesRegex(ValueError, msg):
-                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
+                model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_model(self, rank0, rank1):
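Taken together, `check_model_hotswap` walks through the full hotswap flow: attach and save two adapters, rebuild the model, reserve LoRA buffers for the larger rank with `enable_lora_hotswap`, compile once, then swap the second adapter's weights in place. A standalone sketch of that flow outside the test harness, reusing the tiny UNet configuration from above (a sketch under those assumptions, not a verbatim excerpt from the test):

import os
import tempfile

import torch
from peft import LoraConfig

from diffusers import UNet2DConditionModel

init_dict = {
    "block_out_channels": (4, 8),
    "norm_num_groups": 4,
    "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
    "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
    "cross_attention_dim": 8,
    "attention_head_dim": 2,
    "out_channels": 4,
    "in_channels": 4,
    "layers_per_block": 1,
    "sample_size": 16,
}
inputs = {
    "sample": torch.randn(1, 4, 16, 16),
    "timestep": torch.tensor([10]),
    "encoder_hidden_states": torch.randn(1, 4, 8),
}

rank0, rank1 = 7, 13
model = UNet2DConditionModel(**init_dict)
model.add_adapter(LoraConfig(r=rank0, lora_alpha=rank0, target_modules=["to_q"], init_lora_weights=False), adapter_name="adapter0")
model.add_adapter(LoraConfig(r=rank1, lora_alpha=rank1, target_modules=["to_q"], init_lora_weights=False), adapter_name="adapter1")

with tempfile.TemporaryDirectory() as tmp:
    # Save both adapter checkpoints, then start over from a fresh model, as the test does.
    model.save_lora_adapter(os.path.join(tmp, "0"), safe_serialization=True, adapter_name="adapter0")
    model.save_lora_adapter(os.path.join(tmp, "1"), safe_serialization=True, adapter_name="adapter1")
    del model

    model = UNet2DConditionModel(**init_dict)
    # Pad the LoRA buffers up to the larger rank so swapping 7 <-> 13 keeps tensor shapes fixed.
    model.enable_lora_hotswap(target_rank=max(rank0, rank1))
    model.load_lora_adapter(os.path.join(tmp, "0", "pytorch_lora_weights.safetensors"), adapter_name="adapter0", prefix=None)

    model = torch.compile(model, mode="reduce-overhead")
    with torch._dynamo.config.patch(error_on_recompile=True):
        with torch.inference_mode():
            _ = model(**inputs)["sample"]

        # Hotswap: replace the weights of "adapter0" with the second checkpoint in place,
        # then run forward again; a recompilation here would raise under error_on_recompile.
        model.load_lora_adapter(os.path.join(tmp, "1", "pytorch_lora_weights.safetensors"), adapter_name="adapter0", hotswap=True, prefix=None)
        with torch.inference_mode():
            _ = model(**inputs)["sample"]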
@@ -1877,58 +1859,86 @@ def test_hotswapping_compiled_model_linear(self, rank0, rank1):
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         target_modules = ["conv", "conv1", "conv2"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         target_modules = ["to_q", "conv"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
+        # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
+        # with `torch.compile()` for models that have both linear and conv layers. In this test, we check
+        # if we can target a linear layer from the transformer blocks and another linear layer from non-attention
+        # block.
+        # It's important to add this context to raise an error on recompilation
+        target_modules = ["to_q"]
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        target_modules.append(self.get_linear_module_name_other_than_attn(model))
+        del model
+
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
     def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
+
         msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
         with self.assertRaisesRegex(RuntimeError, msg):
-            unet.enable_lora_hotswap(target_rank=32)
+            model.enable_lora_hotswap(target_rank=32)
 
     def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
         from diffusers.loaders.peft import logger
 
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         msg = (
             "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
         )
         with self.assertLogs(logger=logger, level="WARNING") as cm:
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
         assert any(msg in log for log in cm.output)
 
     def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
         # check possibility to ignore the error/warning
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")  # Capture all warnings
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
         self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
 
     def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
         # check that wrong argument value raises an error
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
         with self.assertRaisesRegex(ValueError, msg):
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
 
     def test_hotswap_second_adapter_targets_more_layers_raises(self):
         # check the error and log
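All the compiled variants wrap the run in `torch._dynamo.config.patch(error_on_recompile=True)`, which turns any recompilation into a hard failure instead of a silent slow path; that is how the tests detect whether the hotswap disturbed the compiled graph. A tiny self-contained illustration of the flag (the function here is made up):

import torch


def double(x):
    return x * 2


compiled = torch.compile(double)
compiled(torch.randn(4))  # first call traces and compiles

with torch._dynamo.config.patch(error_on_recompile=True):
    compiled(torch.randn(4))  # same input structure: cached graph is reused, no error
    try:
        compiled(torch.randn(8, 8))  # different shape/rank forces a retrace, which now raises
    except Exception as err:
        print(f"recompilation detected: {type(err).__name__}")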

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 from diffusers.models.embeddings import ImageProjection
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device
 
-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin
 
 
 enable_full_determinism()
@@ -78,7 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
     return ip_state_dict
 
 
-class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
+class FluxTransformerTests(ModelTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 4 additions & 2 deletions
@@ -53,7 +53,7 @@
     torch_device,
 )
 
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin
 
 
 if is_peft_available():
@@ -350,7 +350,9 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
     return custom_diffusion_attn_procs
 
 
-class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class UNet2DConditionModelTests(
+    ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
+):
     model_class = UNet2DConditionModel
     main_input_name = "sample"
     # We override the items here because the unet under consideration is small.
