|
25 | 25 | import unittest |
26 | 26 | import unittest.mock as mock |
27 | 27 | import uuid |
28 | | -import warnings |
29 | 28 | from collections import defaultdict |
30 | 29 | from typing import Dict, List, Optional, Tuple, Union |
31 | 30 |
|
@@ -2373,14 +2372,15 @@ def test_enable_lora_hotswap_called_after_adapter_added_warning(self): |
2373 | 2372 |
|
2374 | 2373 | def test_enable_lora_hotswap_called_after_adapter_added_ignore(self): |
2375 | 2374 | # check possibility to ignore the error/warning |
| 2375 | + from diffusers.loaders.peft import logger |
| 2376 | + |
2376 | 2377 | lora_config = self.get_lora_config(8, 8, target_modules=["to_q"]) |
2377 | 2378 | init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
2378 | 2379 | model = self.model_class(**init_dict).to(torch_device) |
2379 | 2380 | model.add_adapter(lora_config) |
2380 | | - with warnings.catch_warnings(record=True) as w: |
2381 | | - warnings.simplefilter("always") # Capture all warnings |
2382 | | - model.enable_lora_hotswap(target_rank=32, check_compiled="warn") |
2383 | | - self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}") |
| 2381 | + # note: assertNoLogs requires Python 3.10+ |
| 2382 | + with self.assertNoLogs(logger, level="WARNING"): |
| 2383 | + model.enable_lora_hotswap(target_rank=32, check_compiled="ignore") |
2384 | 2384 |
|
2385 | 2385 | def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): |
2386 | 2386 | # check that wrong argument value raises an error |
|
0 commit comments