Skip to content

Commit fb18269

Browse files
committed
more reorg.
1 parent fc88ac4 commit fb18269

File tree

1 file changed

+155
-145
lines changed

1 file changed

+155
-145
lines changed

tests/lora/utils.py

Lines changed: 155 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -129,151 +129,6 @@ class PeftLoraLoaderMixinTests:
129129
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
130130
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
131131

132-
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
133-
if self.unet_kwargs and self.transformer_kwargs:
134-
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
135-
if self.has_two_text_encoders and self.has_three_text_encoders:
136-
raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")
137-
138-
scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
139-
rank = 4
140-
lora_alpha = rank if lora_alpha is None else lora_alpha
141-
142-
torch.manual_seed(0)
143-
if self.unet_kwargs is not None:
144-
unet = UNet2DConditionModel(**self.unet_kwargs)
145-
else:
146-
transformer = self.transformer_cls(**self.transformer_kwargs)
147-
148-
scheduler = scheduler_cls(**self.scheduler_kwargs)
149-
150-
torch.manual_seed(0)
151-
vae = self.vae_cls(**self.vae_kwargs)
152-
153-
text_encoder = self.text_encoder_cls.from_pretrained(
154-
self.text_encoder_id, subfolder=self.text_encoder_subfolder
155-
)
156-
tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder)
157-
158-
if self.text_encoder_2_cls is not None:
159-
text_encoder_2 = self.text_encoder_2_cls.from_pretrained(
160-
self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder
161-
)
162-
tokenizer_2 = self.tokenizer_2_cls.from_pretrained(
163-
self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder
164-
)
165-
166-
if self.text_encoder_3_cls is not None:
167-
text_encoder_3 = self.text_encoder_3_cls.from_pretrained(
168-
self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder
169-
)
170-
tokenizer_3 = self.tokenizer_3_cls.from_pretrained(
171-
self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder
172-
)
173-
174-
text_lora_config = LoraConfig(
175-
r=rank,
176-
lora_alpha=lora_alpha,
177-
target_modules=self.text_encoder_target_modules,
178-
init_lora_weights=False,
179-
use_dora=use_dora,
180-
)
181-
182-
denoiser_lora_config = LoraConfig(
183-
r=rank,
184-
lora_alpha=lora_alpha,
185-
target_modules=self.denoiser_target_modules,
186-
init_lora_weights=False,
187-
use_dora=use_dora,
188-
)
189-
190-
pipeline_components = {
191-
"scheduler": scheduler,
192-
"vae": vae,
193-
"text_encoder": text_encoder,
194-
"tokenizer": tokenizer,
195-
}
196-
# Denoiser
197-
if self.unet_kwargs is not None:
198-
pipeline_components.update({"unet": unet})
199-
elif self.transformer_kwargs is not None:
200-
pipeline_components.update({"transformer": transformer})
201-
202-
# Remaining text encoders.
203-
if self.text_encoder_2_cls is not None:
204-
pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2})
205-
if self.text_encoder_3_cls is not None:
206-
pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3})
207-
208-
# Remaining stuff
209-
init_params = inspect.signature(self.pipeline_class.__init__).parameters
210-
if "safety_checker" in init_params:
211-
pipeline_components.update({"safety_checker": None})
212-
if "feature_extractor" in init_params:
213-
pipeline_components.update({"feature_extractor": None})
214-
if "image_encoder" in init_params:
215-
pipeline_components.update({"image_encoder": None})
216-
217-
return pipeline_components, text_lora_config, denoiser_lora_config
218-
219-
@property
220-
def output_shape(self):
221-
raise NotImplementedError
222-
223-
def get_dummy_inputs(self, with_generator=True):
224-
batch_size = 1
225-
sequence_length = 10
226-
num_channels = 4
227-
sizes = (32, 32)
228-
229-
generator = torch.manual_seed(0)
230-
noise = floats_tensor((batch_size, num_channels) + sizes)
231-
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
232-
233-
pipeline_inputs = {
234-
"prompt": "A painting of a squirrel eating a burger",
235-
"num_inference_steps": 5,
236-
"guidance_scale": 6.0,
237-
"output_type": "np",
238-
}
239-
if with_generator:
240-
pipeline_inputs.update({"generator": generator})
241-
242-
return noise, input_ids, pipeline_inputs
243-
244-
# Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
245-
def get_dummy_tokens(self):
246-
max_seq_length = 77
247-
248-
inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
249-
250-
prepared_inputs = {}
251-
prepared_inputs["input_ids"] = inputs
252-
return prepared_inputs
253-
254-
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
255-
if text_lora_config is not None:
256-
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
257-
pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
258-
self.assertTrue(
259-
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
260-
)
261-
262-
if denoiser_lora_config is not None:
263-
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
264-
denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
265-
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
266-
else:
267-
denoiser = None
268-
269-
if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders:
270-
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
271-
pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
272-
self.assertTrue(
273-
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
274-
)
275-
return pipe, denoiser
276-
277132
@require_peft_version_greater("0.13.1")
278133
def test_low_cpu_mem_usage_with_injection(self):
279134
"""Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
@@ -2354,6 +2209,161 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
23542209
return
23552210
self._test_group_offloading_inference_denoiser(offload_type, use_stream)
23562211

2212+
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
2213+
if self.unet_kwargs and self.transformer_kwargs:
2214+
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
2215+
if self.has_two_text_encoders and self.has_three_text_encoders:
2216+
raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")
2217+
2218+
scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
2219+
rank = 4
2220+
lora_alpha = rank if lora_alpha is None else lora_alpha
2221+
2222+
torch.manual_seed(0)
2223+
if self.unet_kwargs is not None:
2224+
unet = UNet2DConditionModel(**self.unet_kwargs)
2225+
else:
2226+
transformer = self.transformer_cls(**self.transformer_kwargs)
2227+
2228+
scheduler = scheduler_cls(**self.scheduler_kwargs)
2229+
2230+
torch.manual_seed(0)
2231+
vae = self.vae_cls(**self.vae_kwargs)
2232+
2233+
text_encoder = self.text_encoder_cls.from_pretrained(
2234+
self.text_encoder_id, subfolder=self.text_encoder_subfolder
2235+
)
2236+
tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder)
2237+
2238+
if self.text_encoder_2_cls is not None:
2239+
text_encoder_2 = self.text_encoder_2_cls.from_pretrained(
2240+
self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder
2241+
)
2242+
tokenizer_2 = self.tokenizer_2_cls.from_pretrained(
2243+
self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder
2244+
)
2245+
2246+
if self.text_encoder_3_cls is not None:
2247+
text_encoder_3 = self.text_encoder_3_cls.from_pretrained(
2248+
self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder
2249+
)
2250+
tokenizer_3 = self.tokenizer_3_cls.from_pretrained(
2251+
self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder
2252+
)
2253+
2254+
text_lora_config = LoraConfig(
2255+
r=rank,
2256+
lora_alpha=lora_alpha,
2257+
target_modules=self.text_encoder_target_modules,
2258+
init_lora_weights=False,
2259+
use_dora=use_dora,
2260+
)
2261+
2262+
denoiser_lora_config = LoraConfig(
2263+
r=rank,
2264+
lora_alpha=lora_alpha,
2265+
target_modules=self.denoiser_target_modules,
2266+
init_lora_weights=False,
2267+
use_dora=use_dora,
2268+
)
2269+
2270+
pipeline_components = {
2271+
"scheduler": scheduler,
2272+
"vae": vae,
2273+
"text_encoder": text_encoder,
2274+
"tokenizer": tokenizer,
2275+
}
2276+
# Denoiser
2277+
if self.unet_kwargs is not None:
2278+
pipeline_components.update({"unet": unet})
2279+
elif self.transformer_kwargs is not None:
2280+
pipeline_components.update({"transformer": transformer})
2281+
2282+
# Remaining text encoders.
2283+
if self.text_encoder_2_cls is not None:
2284+
pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2})
2285+
if self.text_encoder_3_cls is not None:
2286+
pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3})
2287+
2288+
# Remaining stuff
2289+
init_params = inspect.signature(self.pipeline_class.__init__).parameters
2290+
if "safety_checker" in init_params:
2291+
pipeline_components.update({"safety_checker": None})
2292+
if "feature_extractor" in init_params:
2293+
pipeline_components.update({"feature_extractor": None})
2294+
if "image_encoder" in init_params:
2295+
pipeline_components.update({"image_encoder": None})
2296+
2297+
return pipeline_components, text_lora_config, denoiser_lora_config
2298+
2299+
@property
2300+
def output_shape(self):
2301+
raise NotImplementedError
2302+
2303+
def get_dummy_inputs(self, with_generator=True):
2304+
batch_size = 1
2305+
sequence_length = 10
2306+
num_channels = 4
2307+
sizes = (32, 32)
2308+
2309+
generator = torch.manual_seed(0)
2310+
noise = floats_tensor((batch_size, num_channels) + sizes)
2311+
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
2312+
2313+
pipeline_inputs = {
2314+
"prompt": "A painting of a squirrel eating a burger",
2315+
"num_inference_steps": 5,
2316+
"guidance_scale": 6.0,
2317+
"output_type": "np",
2318+
}
2319+
if with_generator:
2320+
pipeline_inputs.update({"generator": generator})
2321+
2322+
return noise, input_ids, pipeline_inputs
2323+
2324+
# Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
2325+
def get_dummy_tokens(self):
2326+
max_seq_length = 77
2327+
2328+
inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
2329+
2330+
prepared_inputs = {}
2331+
prepared_inputs["input_ids"] = inputs
2332+
return prepared_inputs
2333+
2334+
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
2335+
if text_lora_config is not None:
2336+
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
2337+
pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
2338+
self.assertTrue(
2339+
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
2340+
)
2341+
2342+
if denoiser_lora_config is not None:
2343+
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
2344+
denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
2345+
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
2346+
else:
2347+
denoiser = None
2348+
2349+
if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders:
2350+
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
2351+
pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
2352+
self.assertTrue(
2353+
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
2354+
)
2355+
return pipe, denoiser
2356+
2357+
def _setup_pipeline_and_get_base_output(self, scheduler_cls):
2358+
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
2359+
pipe = self.pipeline_class(**components)
2360+
pipe = pipe.to(torch_device)
2361+
pipe.set_progress_bar_config(disable=None)
2362+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
2363+
2364+
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
2365+
return pipe, inputs, output_no_lora, text_lora_config, denoiser_lora_config
2366+
23572367
def _get_lora_state_dicts(self, modules_to_save):
23582368
state_dicts = {}
23592369
for module_name, module in modules_to_save.items():

0 commit comments

Comments
 (0)