Merged

22 commits
a331838
refactor hotswap tester.
sayakpaul Apr 15, 2025
cf2ea33
fix seeds..
sayakpaul Apr 15, 2025
4b11ab2
add to nightly ci.
sayakpaul Apr 15, 2025
de312da
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul Apr 15, 2025
dbc78a4
move comment.
sayakpaul Apr 15, 2025
fed0ee1
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul Apr 15, 2025
714c458
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul Apr 16, 2025
24a0374
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul Apr 16, 2025
d2e6c9c
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul Apr 18, 2025
580e7ae
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul Apr 21, 2025
56333b9
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul Apr 22, 2025
e2cd241
move to nightly
sayakpaul Apr 22, 2025
c04b0d2
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul Apr 24, 2025
95c0b52
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul Apr 25, 2025
5ad508f
fix conflicts.
sayakpaul Apr 28, 2025
c062b08
fix conflicts.
sayakpaul Apr 28, 2025
6f7011a
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul Apr 29, 2025
c9f443d
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul May 2, 2025
4e8dffe
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul May 3, 2025
b295b69
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul May 7, 2025
ef31f3e
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul May 9, 2025
dcedbc0
Merge branch 'main' into enable-hotswap-testing-ci
sayakpaul May 9, 2025
150 changes: 80 additions & 70 deletions tests/models/test_modeling_common.py
@@ -59,7 +59,6 @@
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
-    floats_tensor,
     get_python_version,
     is_torch_compile,
     numpy_cosine_similarity_distance,
@@ -1720,7 +1719,7 @@ def test_push_to_hub_library_name(self):
 @require_peft_backend
 @require_peft_version_greater("0.14.0")
 @is_torch_compile
-class TestLoraHotSwappingForModel(unittest.TestCase):
+class LoraHotSwappingForModelTesterMixin:
     """Test that hotswapping does not result in recompilation on the model directly.
 
     We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
@@ -1741,48 +1740,24 @@ def tearDown(self):
         gc.collect()
         backend_empty_cache(torch_device)
 
-    def get_small_unet(self):
-        # from diffusers UNet2DConditionModelTests
-        torch.manual_seed(0)
-        init_dict = {
-            "block_out_channels": (4, 8),
-            "norm_num_groups": 4,
-            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
-            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
-            "cross_attention_dim": 8,
-            "attention_head_dim": 2,
-            "out_channels": 4,
-            "in_channels": 4,
-            "layers_per_block": 1,
-            "sample_size": 16,
-        }
-        model = UNet2DConditionModel(**init_dict)
-        return model.to(torch_device)
-
-    def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
+    def get_lora_config(self, lora_rank, lora_alpha, target_modules):
         # from diffusers test_models_unet_2d_condition.py
         from peft import LoraConfig
 
-        unet_lora_config = LoraConfig(
+        lora_config = LoraConfig(
             r=lora_rank,
             lora_alpha=lora_alpha,
             target_modules=target_modules,
             init_lora_weights=False,
             use_dora=False,
         )
-        return unet_lora_config
-
-    def get_dummy_input(self):
-        # from UNet2DConditionModelTests
-        batch_size = 4
-        num_channels = 4
-        sizes = (16, 16)
-
-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
+        return lora_config
 
-        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+    def get_linear_module_name_other_than_attn(self, model):
+        linear_names = [
+            name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
+        ]
+        return linear_names[0]
 
     def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
         """
@@ -1800,23 +1775,27 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
         fine.
         """
         # create 2 adapters with different ranks and alphas
-        dummy_input = self.get_dummy_input()
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
         alpha0, alpha1 = rank0, rank1
         max_rank = max([rank0, rank1])
         if target_modules1 is None:
             target_modules1 = target_modules0[:]
-        lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
-        lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
+        lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0)
+        lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1)
 
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config0, adapter_name="adapter0")
+        model.add_adapter(lora_config0, adapter_name="adapter0")
         with torch.inference_mode():
-            output0_before = unet(**dummy_input)["sample"]
+            torch.manual_seed(0)
+            output0_before = model(**inputs_dict)["sample"]
 
-        unet.add_adapter(lora_config1, adapter_name="adapter1")
-        unet.set_adapter("adapter1")
+        model.add_adapter(lora_config1, adapter_name="adapter1")
+        model.set_adapter("adapter1")
         with torch.inference_mode():
-            output1_before = unet(**dummy_input)["sample"]
+            torch.manual_seed(0)
+            output1_before = model(**inputs_dict)["sample"]
 
         # sanity checks:
         tol = 5e-3
@@ -1826,40 +1805,43 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
 
         with tempfile.TemporaryDirectory() as tmp_dirname:
             # save the adapter checkpoints
-            unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
-            unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
-            del unet
+            model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
+            model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
+            del model
 
             # load the first adapter
-            unet = self.get_small_unet()
+            torch.manual_seed(0)
+            init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict).to(torch_device)
+
             if do_compile or (rank0 != rank1):
                 # no need to prepare if the model is not compiled or if the ranks are identical
-                unet.enable_lora_hotswap(target_rank=max_rank)
+                model.enable_lora_hotswap(target_rank=max_rank)
 
             file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
             file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
-            unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
+            model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
 
             if do_compile:
-                unet = torch.compile(unet, mode="reduce-overhead")
+                model = torch.compile(model, mode="reduce-overhead")
 
             with torch.inference_mode():
-                output0_after = unet(**dummy_input)["sample"]
+                output0_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
 
             # hotswap the 2nd adapter
-            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
+            model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
 
             # we need to call forward to potentially trigger recompilation
             with torch.inference_mode():
-                output1_after = unet(**dummy_input)["sample"]
+                output1_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
 
             # check error when not passing valid adapter name
             name = "does-not-exist"
             msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
             with self.assertRaisesRegex(ValueError, msg):
-                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
+                model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_model(self, rank0, rank1):
@@ -1877,58 +1859,86 @@ def test_hotswapping_compiled_model_linear(self, rank0, rank1):
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         target_modules = ["conv", "conv1", "conv2"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
         target_modules = ["to_q", "conv"]
         with torch._dynamo.config.patch(error_on_recompile=True):
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
+        # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
+        # with `torch.compile()` for models that have both linear and conv layers. In this test, we check
+        # if we can target a linear layer from the transformer blocks and another linear layer from non-attention
+        # block.
+        # It's important to add this context to raise an error on recompilation
+        target_modules = ["to_q"]
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        target_modules.append(self.get_linear_module_name_other_than_attn(model))
+        del model
+
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
     def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
 
         msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
         with self.assertRaisesRegex(RuntimeError, msg):
-            unet.enable_lora_hotswap(target_rank=32)
+            model.enable_lora_hotswap(target_rank=32)
 
     def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
         # ensure that enable_lora_hotswap is called before loading the first adapter
         from diffusers.loaders.peft import logger
 
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         msg = (
             "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
         )
         with self.assertLogs(logger=logger, level="WARNING") as cm:
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
         assert any(msg in log for log in cm.output)
 
     def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
         # check possibility to ignore the error/warning
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")  # Capture all warnings
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
         self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
 
     def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
         # check that wrong argument value raises an error
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)
         msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
         with self.assertRaisesRegex(ValueError, msg):
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
 
     def test_hotswap_second_adapter_targets_more_layers_raises(self):
         # check the error and log
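For context on what the new `LoraHotSwappingForModelTesterMixin` exercises, here is a minimal, self-contained sketch of the same flow outside the test harness: save two LoRA adapters of different ranks, reload the base model, pre-pad the LoRA layers with `enable_lora_hotswap`, compile once, then hotswap the second adapter while `torch._dynamo.config.patch(error_on_recompile=True)` turns any recompilation into an error. The tiny UNet config and input shapes are taken from the helpers removed in this diff; the ranks (8 and 16), batch size, and the assumption of a CUDA device are illustrative and not part of the PR.

```python
# Minimal sketch of the flow exercised by `check_model_hotswap` above, outside the
# test harness. Assumes a CUDA device and the PEFT backend.
import os
import tempfile

import torch
from peft import LoraConfig

from diffusers import UNet2DConditionModel

device = "cuda"
# Tiny UNet config and input shapes from the removed `get_small_unet` / `get_dummy_input`.
init_dict = {
    "block_out_channels": (4, 8),
    "norm_num_groups": 4,
    "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
    "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
    "cross_attention_dim": 8,
    "attention_head_dim": 2,
    "out_channels": 4,
    "in_channels": 4,
    "layers_per_block": 1,
    "sample_size": 16,
}
inputs = {
    "sample": torch.randn(1, 4, 16, 16, device=device),
    "timestep": torch.tensor([10], device=device),
    "encoder_hidden_states": torch.randn(1, 4, 8, device=device),
}


def get_lora_config(rank):
    return LoraConfig(r=rank, lora_alpha=rank, target_modules=["to_q"], init_lora_weights=False)


# Create two adapters with different ranks and save them to disk.
model = UNet2DConditionModel(**init_dict).to(device)
model.add_adapter(get_lora_config(8), adapter_name="adapter0")
model.add_adapter(get_lora_config(16), adapter_name="adapter1")

with tempfile.TemporaryDirectory() as tmp:
    model.save_lora_adapter(os.path.join(tmp, "0"), safe_serialization=True, adapter_name="adapter0")
    model.save_lora_adapter(os.path.join(tmp, "1"), safe_serialization=True, adapter_name="adapter1")
    del model

    # Reload the base model and pad all LoRA layers to the largest rank up front,
    # so adapters of different ranks can later be swapped in without resizing.
    model = UNet2DConditionModel(**init_dict).to(device)
    model.enable_lora_hotswap(target_rank=16)
    model.load_lora_adapter(
        os.path.join(tmp, "0", "pytorch_lora_weights.safetensors"), adapter_name="adapter0", prefix=None
    )
    model = torch.compile(model, mode="reduce-overhead")

    with torch._dynamo.config.patch(error_on_recompile=True):
        with torch.inference_mode():
            _ = model(**inputs)["sample"]

        # Hotswap the second adapter into the slot of the first; this must not recompile.
        model.load_lora_adapter(
            os.path.join(tmp, "1", "pytorch_lora_weights.safetensors"),
            adapter_name="adapter0",
            hotswap=True,
            prefix=None,
        )

        with torch.inference_mode():
            _ = model(**inputs)["sample"]
```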
4 changes: 2 additions & 2 deletions tests/models/transformers/test_models_transformer_flux.py
@@ -22,7 +22,7 @@
 from diffusers.models.embeddings import ImageProjection
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device
 
-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin
 
 
 enable_full_determinism()
@@ -78,7 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
     return ip_state_dict
 
 
-class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
+class FluxTransformerTests(ModelTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.
6 changes: 4 additions & 2 deletions tests/models/unets/test_models_unet_2d_condition.py
@@ -53,7 +53,7 @@
     torch_device,
 )
 
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin
 
 
 if is_peft_available():
@@ -350,7 +350,9 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
     return custom_diffusion_attn_procs
 
 
-class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class UNet2DConditionModelTests(
+    ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
+):
     model_class = UNet2DConditionModel
     main_input_name = "sample"
     # We override the items here because the unet under consideration is small.
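As a usage note, any other model tester can opt into these hotswap tests the same way the two classes above do: mix in `LoraHotSwappingForModelTesterMixin`, which relies only on `model_class`, `prepare_init_args_and_inputs_for_common()`, and a forward pass whose output exposes `["sample"]`. The skeleton below illustrates the pattern with a hypothetical `MyTransformer2DModel` and made-up config values; it is not part of this PR.

```python
# Hypothetical adoption of the new mixin by another model test class.
# `MyTransformer2DModel` and its init/input values are placeholders; the structure
# mirrors FluxTransformerTests / UNet2DConditionModelTests from the diff above.
import unittest

import torch

from diffusers import MyTransformer2DModel  # placeholder model class
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin


enable_full_determinism()


class MyTransformerTests(ModelTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    model_class = MyTransformer2DModel
    main_input_name = "hidden_states"

    def prepare_init_args_and_inputs_for_common(self):
        # Keep the model tiny so the compiled hotswap tests stay cheap on the nightly CI.
        init_dict = {"num_attention_heads": 2, "attention_head_dim": 4, "num_layers": 1}
        inputs_dict = {
            "hidden_states": torch.randn(1, 4, 8).to(torch_device),
            "timestep": torch.tensor([1]).to(torch_device),
        }
        return init_dict, inputs_dict
```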