@@ -216,63 +216,59 @@ def _load_on_pipeline(pipeline: diffusers.DiffusionPipeline,
216216                # this is tricky because there is stupidly a positional argument named 'token' 
217217                # as well as an accepted kwargs value with the key 'token' 
218218
219-                 old_token  =  os .environ .get ('HF_TOKEN' , None )
220-                 if  use_auth_token  is  not None :
221-                     os .environ ['HF_TOKEN' ] =  use_auth_token 
222- 
223-                 try :
224-                     is_sdxl  =  pipeline .__class__ .__name__ .startswith ('StableDiffusionXL' )
225-                     is_flux  =  pipeline .__class__ .__name__ .startswith ('Flux' )
226- 
227-                     if  is_sdxl  or  is_flux :
228-                         filename , dicts  =  _load_textual_inversion_state_dict (
229-                             model_path ,
230-                             revision = textual_inversion_uri .revision ,
231-                             subfolder = textual_inversion_uri .subfolder ,
232-                             weight_name = textual_inversion_uri .weight_name ,
233-                             local_files_only = local_files_only 
234-                         )
235- 
236-                         if  is_sdxl :
237-                             if  'clip_l'  not  in dicts  or  'clip_g'  not  in dicts :
238-                                 raise  RuntimeError (
239-                                     'clip_l or clip_g not found in SDXL textual ' 
240-                                     f'inversion model "{ textual_inversion_uri .model }  
241-                                     'unsupported model format.' )
242-                         else :
243-                             if  'clip_l'  not  in dicts :
244-                                 raise  RuntimeError (
245-                                     'clip_l not found in Flux textual ' 
246-                                     f'inversion model "{ textual_inversion_uri .model }  
247-                                     'unsupported model format.' )
248- 
249-                         # token is the file name (no extension) with spaces 
250-                         # replaced by underscores when the user does not provide 
251-                         # a prompt token 
252-                         token  =  os .path .splitext (
253-                             os .path .basename (filename ))[0 ].replace (' ' , '_' ) \
254-                             if  textual_inversion_uri .token  is  None  else  textual_inversion_uri .token 
255- 
256-                         pipeline .load_textual_inversion (dicts ['clip_l' ],
257-                                                         token = token ,
258-                                                         text_encoder = pipeline .text_encoder ,
259-                                                         tokenizer = pipeline .tokenizer )
260- 
261-                         if  is_sdxl :
262-                             pipeline .load_textual_inversion (dicts ['clip_g' ],
263-                                                             token = token ,
264-                                                             text_encoder = pipeline .text_encoder_2 ,
265-                                                             tokenizer = pipeline .tokenizer_2 )
219+                 is_sdxl  =  pipeline .__class__ .__name__ .startswith ('StableDiffusionXL' )
220+                 is_flux  =  pipeline .__class__ .__name__ .startswith ('Flux' )
221+ 
222+                 if  is_sdxl  or  is_flux :
223+                     filename , dicts  =  _load_textual_inversion_state_dict (
224+                         model_path ,
225+                         revision = textual_inversion_uri .revision ,
226+                         subfolder = textual_inversion_uri .subfolder ,
227+                         weight_name = textual_inversion_uri .weight_name ,
228+                         local_files_only = local_files_only ,
229+                         token = use_auth_token 
230+                     )
231+ 
232+                     if  is_sdxl :
233+                         if  'clip_l'  not  in dicts  or  'clip_g'  not  in dicts :
234+                             raise  RuntimeError (
235+                                 'clip_l or clip_g not found in SDXL textual ' 
236+                                 f'inversion model "{ textual_inversion_uri .model }  
237+                                 'unsupported model format.' )
266238                    else :
267-                         pipeline .load_textual_inversion (model_path ,
268-                                                         token = textual_inversion_uri .token ,
269-                                                         revision = textual_inversion_uri .revision ,
270-                                                         subfolder = textual_inversion_uri .subfolder ,
271-                                                         weight_name = textual_inversion_uri .weight_name ,
272-                                                         local_files_only = local_files_only )
273-                 finally :
274-                     if  old_token  is  not None :
275-                         os .environ ['HF_TOKEN' ] =  old_token 
239+                         if  'clip_l'  not  in dicts :
240+                             raise  RuntimeError (
241+                                 'clip_l not found in Flux textual ' 
242+                                 f'inversion model "{ textual_inversion_uri .model }  
243+                                 'unsupported model format.' )
244+ 
245+                     # token is the file name (no extension) with spaces 
246+                     # replaced by underscores when the user does not provide 
247+                     # a prompt token 
248+                     token  =  os .path .splitext (
249+                         os .path .basename (filename ))[0 ].replace (' ' , '_' ) \
250+                         if  textual_inversion_uri .token  is  None  else  textual_inversion_uri .token 
251+ 
252+                     pipeline .load_textual_inversion (dicts ['clip_l' ],
253+                                                     token = token ,
254+                                                     text_encoder = pipeline .text_encoder ,
255+                                                     tokenizer = pipeline .tokenizer ,
256+                                                     hf_token = use_auth_token )
257+ 
258+                     if  is_sdxl :
259+                         pipeline .load_textual_inversion (dicts ['clip_g' ],
260+                                                         token = token ,
261+                                                         text_encoder = pipeline .text_encoder_2 ,
262+                                                         tokenizer = pipeline .tokenizer_2 ,
263+                                                         hf_token = use_auth_token )
264+                 else :
265+                     pipeline .load_textual_inversion (model_path ,
266+                                                     token = textual_inversion_uri .token ,
267+                                                     revision = textual_inversion_uri .revision ,
268+                                                     subfolder = textual_inversion_uri .subfolder ,
269+                                                     weight_name = textual_inversion_uri .weight_name ,
270+                                                     local_files_only = local_files_only ,
271+                                                     hf_token = use_auth_token )
276272
277273                _messages .debug_log (f'Added Textual Inversion: "{ textual_inversion_uri }  
278274                                    f'to pipeline: "{ pipeline .__class__ .__name__ }  )
0 commit comments