
Commit 76e2727

sayakpaul and a-r-r-o-w authored
[SANA LoRA] sana lora training tests and misc. (#10296)
* sana lora training tests and misc.
* remove push to hub
* Update examples/dreambooth/train_dreambooth_lora_sana.py

Co-authored-by: Aryan <[email protected]>
1 parent 02c777c commit 76e2727

File tree

4 files changed: +231 -24 lines changed

examples/dreambooth/test_dreambooth_lora_sana.py

Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRASANA(ExamplesTestsAccelerate):
    instance_data_dir = "docs/source/en/imgs"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe"
    script_path = "examples/dreambooth/train_dreambooth_lora_sana.py"
    transformer_layer_type = "transformer_blocks.0.attn1.to_k"

    def test_dreambooth_lora_sana(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --resolution 32
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --max_sequence_length 16
                """.split()

            test_args.extend(["--instance_prompt", ""])
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_latent_caching(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --resolution 32
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --max_sequence_length 16
                """.split()

            test_args.extend(["--instance_prompt", ""])
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_layers(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --resolution 32
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lora_layers {self.transformer_layer_type}
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --max_sequence_length 16
                """.split()

            test_args.extend(["--instance_prompt", ""])
            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names. In this test, only the params of
            # `self.transformer_layer_type` should be in the state dict.
            starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --resolution=32
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
                --max_sequence_length 16
                """.split()

            test_args.extend(["--instance_prompt", ""])
            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --resolution=32
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                --max_sequence_length 166
                """.split()

            test_args.extend(["--instance_prompt", ""])
            run_command(self._launch_args + test_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --resolution=32
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --checkpoints_total_limit=2
                --max_sequence_length 16
                """.split()

            resume_run_args.extend(["--instance_prompt", ""])
            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
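
The tests above only assert on the saved state dict. As a quick sanity check outside the test suite, the produced `pytorch_lora_weights.safetensors` could be loaded back into a SANA pipeline; the sketch below is not part of the commit, assumes SANA LoRA loading support in diffusers, and uses an illustrative output path.

# Hedged sketch (not from this commit): reload the LoRA produced by the tests.
# "path/to/output_dir" is illustrative; the model ID matches the tiny test pipeline above.
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained("hf-internal-testing/tiny-sana-pipe", torch_dtype=torch.float32)
pipe.load_lora_weights("path/to/output_dir", weight_name="pytorch_lora_weights.safetensors")

# Tiny inference pass, mirroring the small resolution and sequence length used in the tests.
image = pipe(
    "a photo of a dog",
    num_inference_steps=2,
    max_sequence_length=16,
    height=32,
    width=32,
).images[0]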

examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 12 additions & 11 deletions
@@ -943,7 +943,7 @@ def main(args):
 
     # Load scheduler and models
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="scheduler"
+        args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
     text_encoder = Gemma2Model.from_pretrained(
@@ -964,15 +964,6 @@ def main(args):
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
 
-    # Initialize a text encoding pipeline and keep it to CPU for now.
-    text_encoding_pipeline = SanaPipeline.from_pretrained(
-        args.pretrained_model_name_or_path,
-        vae=None,
-        transformer=None,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-    )
-
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
@@ -993,6 +984,15 @@ def main(args):
     # because Gemma2 is particularly suited for bfloat16.
     text_encoder.to(dtype=torch.bfloat16)
 
+    # Initialize a text encoding pipeline and keep it to CPU for now.
+    text_encoding_pipeline = SanaPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        vae=None,
+        transformer=None,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+    )
+
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
 
@@ -1182,6 +1182,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
         )
         if args.offload:
             text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        prompt_embeds = prompt_embeds.to(transformer.dtype)
         return prompt_embeds, prompt_attention_mask
 
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
@@ -1216,7 +1217,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
     vae_config_scaling_factor = vae.config.scaling_factor
     if args.cache_latents:
         latents_cache = []
-        vae = vae.to("cuda")
+        vae = vae.to(accelerator.device)
         for batch in tqdm(train_dataloader, desc="Caching latents"):
             with torch.no_grad():
                 batch["pixel_values"] = batch["pixel_values"].to(

tests/lora/test_lora_layers_sana.py

Lines changed: 10 additions & 10 deletions
@@ -16,7 +16,7 @@
 import unittest
 
 import torch
-from transformers import Gemma2ForCausalLM, GemmaTokenizer
+from transformers import Gemma2Model, GemmaTokenizer
 
 from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
 from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
@@ -73,7 +73,7 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     }
     vae_cls = AutoencoderDC
     tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
-    text_encoder_cls, text_encoder_id = Gemma2ForCausalLM, "hf-internal-testing/dummy-gemma-for-diffusers"
+    text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"
 
     @property
     def output_shape(self):
@@ -105,34 +105,34 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
-    @unittest.skip("Not supported in Sana.")
+    @unittest.skip("Not supported in SANA.")
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Not supported in Mochi.")
+    @unittest.skip("Not supported in SANA.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in Mochi.")
+    @unittest.skip("Not supported in SANA.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass

tests/pipelines/sana/test_sana.py

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@
 
 import numpy as np
 import torch
-from transformers import Gemma2Config, Gemma2ForCausalLM, GemmaTokenizer
+from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
 
 from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
 from diffusers.utils.testing_utils import (
@@ -101,7 +101,7 @@ def get_dummy_components(self):
         torch.manual_seed(0)
         text_encoder_config = Gemma2Config(
             head_dim=16,
-            hidden_size=32,
+            hidden_size=8,
             initializer_range=0.02,
             intermediate_size=64,
             max_position_embeddings=8192,
@@ -112,7 +112,7 @@ def get_dummy_components(self):
             vocab_size=8,
             attn_implementation="eager",
         )
-        text_encoder = Gemma2ForCausalLM(text_encoder_config)
+        text_encoder = Gemma2Model(text_encoder_config)
         tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
 
         components = {
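
The swap from `Gemma2ForCausalLM` to `Gemma2Model` here and in the LoRA tests above drops the unused LM head: the pipeline only consumes the encoder's hidden states. Below is a hedged sketch with a tiny config; the values shown in the diff are reused, while `num_attention_heads`, `num_hidden_layers`, and `num_key_value_heads` are assumed small values for illustration.

# Hedged sketch, not part of the commit: the bare Gemma2Model already returns
# last_hidden_state, which is all the SANA text-encoding path needs.
import torch
from transformers import Gemma2Config, Gemma2Model

config = Gemma2Config(
    head_dim=16,
    hidden_size=8,
    intermediate_size=64,
    num_attention_heads=2,
    num_hidden_layers=1,
    num_key_value_heads=2,
    vocab_size=8,
    attn_implementation="eager",
)
text_encoder = Gemma2Model(config)
input_ids = torch.randint(0, config.vocab_size, (1, 16))
attention_mask = torch.ones_like(input_ids)
hidden_states = text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state
print(hidden_states.shape)  # torch.Size([1, 16, 8])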
