@@ -129,151 +129,6 @@ class PeftLoraLoaderMixinTests:
129129 text_encoder_target_modules = ["q_proj" , "k_proj" , "v_proj" , "out_proj" ]
130130 denoiser_target_modules = ["to_q" , "to_k" , "to_v" , "to_out.0" ]
131131
132- def get_dummy_components (self , scheduler_cls = None , use_dora = False , lora_alpha = None ):
133- if self .unet_kwargs and self .transformer_kwargs :
134- raise ValueError ("Both `unet_kwargs` and `transformer_kwargs` cannot be specified." )
135- if self .has_two_text_encoders and self .has_three_text_encoders :
136- raise ValueError ("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True." )
137-
138- scheduler_cls = self .scheduler_cls if scheduler_cls is None else scheduler_cls
139- rank = 4
140- lora_alpha = rank if lora_alpha is None else lora_alpha
141-
142- torch .manual_seed (0 )
143- if self .unet_kwargs is not None :
144- unet = UNet2DConditionModel (** self .unet_kwargs )
145- else :
146- transformer = self .transformer_cls (** self .transformer_kwargs )
147-
148- scheduler = scheduler_cls (** self .scheduler_kwargs )
149-
150- torch .manual_seed (0 )
151- vae = self .vae_cls (** self .vae_kwargs )
152-
153- text_encoder = self .text_encoder_cls .from_pretrained (
154- self .text_encoder_id , subfolder = self .text_encoder_subfolder
155- )
156- tokenizer = self .tokenizer_cls .from_pretrained (self .tokenizer_id , subfolder = self .tokenizer_subfolder )
157-
158- if self .text_encoder_2_cls is not None :
159- text_encoder_2 = self .text_encoder_2_cls .from_pretrained (
160- self .text_encoder_2_id , subfolder = self .text_encoder_2_subfolder
161- )
162- tokenizer_2 = self .tokenizer_2_cls .from_pretrained (
163- self .tokenizer_2_id , subfolder = self .tokenizer_2_subfolder
164- )
165-
166- if self .text_encoder_3_cls is not None :
167- text_encoder_3 = self .text_encoder_3_cls .from_pretrained (
168- self .text_encoder_3_id , subfolder = self .text_encoder_3_subfolder
169- )
170- tokenizer_3 = self .tokenizer_3_cls .from_pretrained (
171- self .tokenizer_3_id , subfolder = self .tokenizer_3_subfolder
172- )
173-
174- text_lora_config = LoraConfig (
175- r = rank ,
176- lora_alpha = lora_alpha ,
177- target_modules = self .text_encoder_target_modules ,
178- init_lora_weights = False ,
179- use_dora = use_dora ,
180- )
181-
182- denoiser_lora_config = LoraConfig (
183- r = rank ,
184- lora_alpha = lora_alpha ,
185- target_modules = self .denoiser_target_modules ,
186- init_lora_weights = False ,
187- use_dora = use_dora ,
188- )
189-
190- pipeline_components = {
191- "scheduler" : scheduler ,
192- "vae" : vae ,
193- "text_encoder" : text_encoder ,
194- "tokenizer" : tokenizer ,
195- }
196- # Denoiser
197- if self .unet_kwargs is not None :
198- pipeline_components .update ({"unet" : unet })
199- elif self .transformer_kwargs is not None :
200- pipeline_components .update ({"transformer" : transformer })
201-
202- # Remaining text encoders.
203- if self .text_encoder_2_cls is not None :
204- pipeline_components .update ({"tokenizer_2" : tokenizer_2 , "text_encoder_2" : text_encoder_2 })
205- if self .text_encoder_3_cls is not None :
206- pipeline_components .update ({"tokenizer_3" : tokenizer_3 , "text_encoder_3" : text_encoder_3 })
207-
208- # Remaining stuff
209- init_params = inspect .signature (self .pipeline_class .__init__ ).parameters
210- if "safety_checker" in init_params :
211- pipeline_components .update ({"safety_checker" : None })
212- if "feature_extractor" in init_params :
213- pipeline_components .update ({"feature_extractor" : None })
214- if "image_encoder" in init_params :
215- pipeline_components .update ({"image_encoder" : None })
216-
217- return pipeline_components , text_lora_config , denoiser_lora_config
218-
219- @property
220- def output_shape (self ):
221- raise NotImplementedError
222-
223- def get_dummy_inputs (self , with_generator = True ):
224- batch_size = 1
225- sequence_length = 10
226- num_channels = 4
227- sizes = (32 , 32 )
228-
229- generator = torch .manual_seed (0 )
230- noise = floats_tensor ((batch_size , num_channels ) + sizes )
231- input_ids = torch .randint (1 , sequence_length , size = (batch_size , sequence_length ), generator = generator )
232-
233- pipeline_inputs = {
234- "prompt" : "A painting of a squirrel eating a burger" ,
235- "num_inference_steps" : 5 ,
236- "guidance_scale" : 6.0 ,
237- "output_type" : "np" ,
238- }
239- if with_generator :
240- pipeline_inputs .update ({"generator" : generator })
241-
242- return noise , input_ids , pipeline_inputs
243-
244- # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
245- def get_dummy_tokens (self ):
246- max_seq_length = 77
247-
248- inputs = torch .randint (2 , 56 , size = (1 , max_seq_length ), generator = torch .manual_seed (0 ))
249-
250- prepared_inputs = {}
251- prepared_inputs ["input_ids" ] = inputs
252- return prepared_inputs
253-
254- def add_adapters_to_pipeline (self , pipe , text_lora_config = None , denoiser_lora_config = None , adapter_name = "default" ):
255- if text_lora_config is not None :
256- if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
257- pipe .text_encoder .add_adapter (text_lora_config , adapter_name = adapter_name )
258- self .assertTrue (
259- check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder"
260- )
261-
262- if denoiser_lora_config is not None :
263- denoiser = pipe .transformer if self .unet_kwargs is None else pipe .unet
264- denoiser .add_adapter (denoiser_lora_config , adapter_name = adapter_name )
265- self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser." )
266- else :
267- denoiser = None
268-
269- if text_lora_config is not None and self .has_two_text_encoders or self .has_three_text_encoders :
270- if "text_encoder_2" in self .pipeline_class ._lora_loadable_modules :
271- pipe .text_encoder_2 .add_adapter (text_lora_config , adapter_name = adapter_name )
272- self .assertTrue (
273- check_if_lora_correctly_set (pipe .text_encoder_2 ), "Lora not correctly set in text encoder 2"
274- )
275- return pipe , denoiser
276-
277132 @require_peft_version_greater ("0.13.1" )
278133 def test_low_cpu_mem_usage_with_injection (self ):
279134 """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
@@ -2354,6 +2209,161 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
23542209 return
23552210 self ._test_group_offloading_inference_denoiser (offload_type , use_stream )
23562211
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
    """Build the dummy pipeline components and the LoRA configs used by the tests.

    Args:
        scheduler_cls: scheduler class to instantiate; falls back to ``self.scheduler_cls``.
        use_dora: whether the generated ``LoraConfig`` objects enable DoRA.
        lora_alpha: LoRA alpha; defaults to the rank (4) when not given.

    Returns:
        ``(pipeline_components, text_lora_config, denoiser_lora_config)`` where
        ``pipeline_components`` is a kwargs dict for ``self.pipeline_class``.
    """
    if self.unet_kwargs and self.transformer_kwargs:
        raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
    if self.has_two_text_encoders and self.has_three_text_encoders:
        raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

    if scheduler_cls is None:
        scheduler_cls = self.scheduler_cls
    rank = 4
    if lora_alpha is None:
        lora_alpha = rank

    # Seed before constructing the denoiser so its random init is reproducible.
    torch.manual_seed(0)
    if self.unet_kwargs is not None:
        denoiser_model = UNet2DConditionModel(**self.unet_kwargs)
    else:
        denoiser_model = self.transformer_cls(**self.transformer_kwargs)

    scheduler = scheduler_cls(**self.scheduler_kwargs)

    # Re-seed so the VAE init does not depend on how many denoiser params were drawn.
    torch.manual_seed(0)
    vae = self.vae_cls(**self.vae_kwargs)

    text_encoder = self.text_encoder_cls.from_pretrained(
        self.text_encoder_id, subfolder=self.text_encoder_subfolder
    )
    tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder)

    pipeline_components = {
        "scheduler": scheduler,
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
    }

    # Denoiser goes in under the key the pipeline expects.
    if self.unet_kwargs is not None:
        pipeline_components["unet"] = denoiser_model
    elif self.transformer_kwargs is not None:
        pipeline_components["transformer"] = denoiser_model

    # Optional second/third text encoder + tokenizer pairs.
    if self.text_encoder_2_cls is not None:
        text_encoder_2 = self.text_encoder_2_cls.from_pretrained(
            self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder
        )
        tokenizer_2 = self.tokenizer_2_cls.from_pretrained(
            self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder
        )
        pipeline_components["tokenizer_2"] = tokenizer_2
        pipeline_components["text_encoder_2"] = text_encoder_2

    if self.text_encoder_3_cls is not None:
        text_encoder_3 = self.text_encoder_3_cls.from_pretrained(
            self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder
        )
        tokenizer_3 = self.tokenizer_3_cls.from_pretrained(
            self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder
        )
        pipeline_components["tokenizer_3"] = tokenizer_3
        pipeline_components["text_encoder_3"] = text_encoder_3

    # Both configs share everything except the target modules.
    shared_lora_kwargs = {
        "r": rank,
        "lora_alpha": lora_alpha,
        "init_lora_weights": False,
        "use_dora": use_dora,
    }
    text_lora_config = LoraConfig(target_modules=self.text_encoder_target_modules, **shared_lora_kwargs)
    denoiser_lora_config = LoraConfig(target_modules=self.denoiser_target_modules, **shared_lora_kwargs)

    # Fill in optional pipeline kwargs with None only if the pipeline accepts them.
    init_params = inspect.signature(self.pipeline_class.__init__).parameters
    for optional_component in ("safety_checker", "feature_extractor", "image_encoder"):
        if optional_component in init_params:
            pipeline_components[optional_component] = None

    return pipeline_components, text_lora_config, denoiser_lora_config
2298+
@property
def output_shape(self):
    """Expected shape of the pipeline output; concrete test classes must override this."""
    raise NotImplementedError
2302+
def get_dummy_inputs(self, with_generator=True):
    """Create deterministic dummy inputs for a pipeline call.

    Args:
        with_generator: when True, the seeded generator is included in the
            returned pipeline kwargs.

    Returns:
        ``(noise, input_ids, pipeline_inputs)`` — a random latent tensor, random
        token ids, and the kwargs dict to pass to the pipeline.
    """
    batch_size, sequence_length = 1, 10
    num_channels, sizes = 4, (32, 32)

    generator = torch.manual_seed(0)
    noise = floats_tensor((batch_size, num_channels) + sizes)
    input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

    pipeline_inputs = dict(
        prompt="A painting of a squirrel eating a burger",
        num_inference_steps=5,
        guidance_scale=6.0,
        output_type="np",
    )
    if with_generator:
        pipeline_inputs["generator"] = generator

    return noise, input_ids, pipeline_inputs
2323+
# Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
def get_dummy_tokens(self):
    """Return deterministic dummy tokenizer output: ``{"input_ids": LongTensor(1, 77)}``."""
    max_seq_length = 77
    token_ids = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
    return {"input_ids": token_ids}
2333+
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
    """Attach LoRA adapters to the pipeline's text encoder(s) and denoiser.

    Args:
        pipe: the pipeline whose submodules receive the adapters.
        text_lora_config: ``LoraConfig`` for the text encoder(s); skipped when None.
        denoiser_lora_config: ``LoraConfig`` for the UNet/transformer; skipped when None.
        adapter_name: name under which the adapter is registered.

    Returns:
        ``(pipe, denoiser)`` — ``denoiser`` is None when no denoiser adapter was added.
    """
    if text_lora_config is not None:
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
            self.assertTrue(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )

    if denoiser_lora_config is not None:
        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
    else:
        denoiser = None

    # BUGFIX: the original condition `text_lora_config is not None and
    # self.has_two_text_encoders or self.has_three_text_encoders` parses as
    # `(A and B) or C` under Python precedence, so with three text encoders and
    # `text_lora_config=None` it would call `add_adapter(None, ...)`. The
    # intended grouping is `A and (B or C)`.
    if text_lora_config is not None and (self.has_two_text_encoders or self.has_three_text_encoders):
        if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
            self.assertTrue(
                check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
            )
    return pipe, denoiser
2356+
def _setup_pipeline_and_get_base_output(self, scheduler_cls):
    """Build the pipeline, run it once *without* LoRA, and return everything a test needs.

    Returns:
        ``(pipe, inputs, output_no_lora, text_lora_config, denoiser_lora_config)``.
    """
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
    pipe = self.pipeline_class(**components).to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    _, _, inputs = self.get_dummy_inputs(with_generator=False)
    # Fixed seed so the no-LoRA baseline is comparable with later LoRA runs.
    output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
    return pipe, inputs, output_no_lora, text_lora_config, denoiser_lora_config
2366+
23572367 def _get_lora_state_dicts (self , modules_to_save ):
23582368 state_dicts = {}
23592369 for module_name , module in modules_to_save .items ():
0 commit comments