Commit 0dc1cdd

update the test case file.
1 parent d6e4a1a commit 0dc1cdd

File tree

1 file changed: +46 -0 lines changed


examples/dreambooth/test_dreambooth_lora_hidream.py

Lines changed: 46 additions & 0 deletions
@@ -13,13 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 import os
 import sys
 import tempfile
 
 import safetensors
 
+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
 
 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402

@@ -175,6 +178,49 @@ def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit(self):
                 {"checkpoint-4", "checkpoint-6"},
             )
 
+    def test_dreambooth_lora_with_metadata(self):
+        # Use a `lora_alpha` that is different from `rank`.
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)
+
+
     def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
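
For reference, the check that the new test performs can be reproduced standalone on any LoRA checkpoint produced by the training script. The sketch below is a minimal, hedged version of that inspection: the file path is a placeholder, and the metadata keys ("transformer.lora_alpha", "transformer.r") are the ones asserted in the test above; it uses only the APIs that appear in this diff (safetensors' safe_open and diffusers' LORA_ADAPTER_METADATA_KEY).

    import json

    from safetensors import safe_open

    from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

    # Placeholder path: point this at a LoRA file saved by the DreamBooth script.
    state_dict_file = "pytorch_lora_weights.safetensors"

    # Read only the header metadata; the tensors themselves are not loaded.
    with safe_open(state_dict_file, framework="pt", device="cpu") as f:
        metadata = f.metadata() or {}

    # "format" is written by safetensors itself and is not part of the adapter config.
    metadata.pop("format", None)

    raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
    adapter_config = json.loads(raw) if raw else {}

    print(adapter_config.get("transformer.lora_alpha"), adapter_config.get("transformer.r"))

The test itself can presumably be run in isolation with pytest's -k filter, e.g. `pytest examples/dreambooth/test_dreambooth_lora_hidream.py -k test_dreambooth_lora_with_metadata`, though the exact invocation depends on the local examples test setup.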
