Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import unittest
import unittest.mock as mock
import uuid
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -2373,14 +2372,15 @@ def test_enable_lora_hotswap_called_after_adapter_added_warning(self):

def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
# check possibility to ignore the error/warning
from diffusers.loaders.peft import logger

lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # Capture all warnings
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
# note: assertNoLogs requires Python 3.10+
with self.assertNoLogs(logger, level="WARNING"):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")

def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
Expand Down