Skip to content

Commit 628306b

Browse files
committed
up
1 parent 4db1bfc commit 628306b

File tree

1 file changed

+0
-67
lines changed

1 file changed

+0
-67
lines changed

tests/lora/utils.py

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
import copy
1615
import inspect
1716
import os
1817
import re
@@ -292,20 +291,6 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
292291

293292
return modules_to_save
294293

295-
def _get_exclude_modules(self, pipe):
296-
from diffusers.utils.peft_utils import _derive_exclude_modules
297-
298-
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
299-
denoiser = "unet" if self.unet_kwargs is not None else "transformer"
300-
modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
301-
denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
302-
pipe.unload_lora_weights()
303-
denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
304-
exclude_modules = _derive_exclude_modules(
305-
denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
306-
)
307-
return exclude_modules
308-
309294
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
310295
if text_lora_config is not None:
311296
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -2342,58 +2327,6 @@ def test_lora_unload_add_adapter(self):
23422327
)
23432328
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
23442329

2345-
@require_peft_version_greater("0.13.2")
2346-
def test_lora_exclude_modules(self):
2347-
"""
2348-
Test to check if `exclude_modules` works or not. It works in the following way:
2349-
we first create a pipeline and insert LoRA config into it. We then derive a `set`
2350-
of modules to exclude by investigating its denoiser state dict and denoiser LoRA
2351-
state dict.
2352-
2353-
We then create a new LoRA config to include the `exclude_modules` and perform tests.
2354-
"""
2355-
scheduler_cls = self.scheduler_classes[0]
2356-
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
2357-
pipe = self.pipeline_class(**components).to(torch_device)
2358-
_, _, inputs = self.get_dummy_inputs(with_generator=False)
2359-
2360-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
2361-
self.assertTrue(output_no_lora.shape == self.output_shape)
2362-
2363-
# only supported for `denoiser` now
2364-
pipe_cp = copy.deepcopy(pipe)
2365-
pipe_cp, _ = self.add_adapters_to_pipeline(
2366-
pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
2367-
)
2368-
denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
2369-
pipe_cp.to("cpu")
2370-
del pipe_cp
2371-
2372-
denoiser_lora_config.exclude_modules = denoiser_exclude_modules
2373-
pipe, _ = self.add_adapters_to_pipeline(
2374-
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
2375-
)
2376-
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
2377-
2378-
with tempfile.TemporaryDirectory() as tmpdir:
2379-
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
2380-
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
2381-
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
2382-
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
2383-
pipe.unload_lora_weights()
2384-
pipe.load_lora_weights(tmpdir)
2385-
2386-
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
2387-
2388-
self.assertTrue(
2389-
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
2390-
"LoRA should change outputs.",
2391-
)
2392-
self.assertTrue(
2393-
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
2394-
"Lora outputs should match.",
2395-
)
2396-
23972330
def test_inference_load_delete_load_adapters(self):
23982331
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
23992332
for scheduler_cls in self.scheduler_classes:

0 commit comments

Comments
 (0)