Skip to content

Commit 3ae62f9

Browse files
Warlord-Khameerabbasi
authored andcommitted
Add Tests
1 parent 0293703 commit 3ae62f9

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

tests/lora/test_lora_layers_af.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import sys
16+
import unittest
17+
18+
from diffusers import (
19+
FlowMatchEulerDiscreteScheduler,
20+
AuraFlowPipeline,
21+
)
22+
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
23+
24+
25+
if is_peft_available():
26+
pass
27+
28+
sys.path.append(".")
29+
30+
from utils import PeftLoraLoaderMixinTests # noqa: E402
31+
32+
33+
@require_peft_backend
34+
class AFLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
35+
pipeline_class = AuraFlowPipeline
36+
scheduler_cls = FlowMatchEulerDiscreteScheduler()
37+
scheduler_kwargs = {}
38+
transformer_kwargs = {
39+
"sample_size": 64,
40+
"patch_size": 2,
41+
"in_channels": 4,
42+
"num_mmdit_layers": 4,
43+
"num_single_dit_layers": 32,
44+
"attention_head_dim": 256,
45+
"num_attention_heads": 12,
46+
"joint_attention_dim": 2048,
47+
"caption_projection_dim": 3072,
48+
"out_channels": 4,
49+
"pos_embed_max_size": 1024,
50+
}
51+
vae_kwargs = {
52+
"sample_size": 1024,
53+
"in_channels": 3,
54+
"out_channels": 3,
55+
"block_out_channels": [
56+
128,
57+
256,
58+
512,
59+
512
60+
],
61+
"layers_per_block": 2,
62+
"latent_channels": 4,
63+
"norm_num_groups": 32,
64+
"use_quant_conv": True,
65+
"use_post_quant_conv": True,
66+
"shift_factor": None,
67+
"scaling_factor": 0.13025,
68+
}
69+
has_three_text_encoders = False
70+
71+
@require_torch_gpu
72+
def test_af_lora(self):
73+
"""
74+
Test loading the loras that are saved with the diffusers and peft formats.
75+
Related PR: https://github.com/huggingface/diffusers/pull/8584
76+
"""
77+
components = self.get_dummy_components()
78+
79+
pipe = self.pipeline_class(**components)
80+
pipe = pipe.to(torch_device)
81+
pipe.set_progress_bar_config(disable=None)
82+
83+
lora_model_id = "Warlord-K/gorkem-auraflow-lora"
84+
85+
lora_filename = "pytorch_lora_weights.safetensors"
86+
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
87+
pipe.unload_lora_weights()
88+
89+
lora_filename = "lora_peft_format.safetensors"
90+
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)

0 commit comments

Comments
 (0)