Skip to content

Commit 67f80fd

Browse files
committed
textual inversion HF_TOKEN hack no longer needed
huggingface/diffusers#10546
1 parent dc06111 commit 67f80fd

File tree

1 file changed

+52
-56
lines changed

1 file changed

+52
-56
lines changed

dgenerate/pipelinewrapper/uris/textualinversionuri.py

Lines changed: 52 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -216,63 +216,59 @@ def _load_on_pipeline(pipeline: diffusers.DiffusionPipeline,
216216
# this is tricky because there is stupidly a positional argument named 'token'
217217
# as well as an accepted kwargs value with the key 'token'
218218

219-
old_token = os.environ.get('HF_TOKEN', None)
220-
if use_auth_token is not None:
221-
os.environ['HF_TOKEN'] = use_auth_token
222-
223-
try:
224-
is_sdxl = pipeline.__class__.__name__.startswith('StableDiffusionXL')
225-
is_flux = pipeline.__class__.__name__.startswith('Flux')
226-
227-
if is_sdxl or is_flux:
228-
filename, dicts = _load_textual_inversion_state_dict(
229-
model_path,
230-
revision=textual_inversion_uri.revision,
231-
subfolder=textual_inversion_uri.subfolder,
232-
weight_name=textual_inversion_uri.weight_name,
233-
local_files_only=local_files_only
234-
)
235-
236-
if is_sdxl:
237-
if 'clip_l' not in dicts or 'clip_g' not in dicts:
238-
raise RuntimeError(
239-
'clip_l or clip_g not found in SDXL textual '
240-
f'inversion model "{textual_inversion_uri.model}" state dict, '
241-
'unsupported model format.')
242-
else:
243-
if 'clip_l' not in dicts:
244-
raise RuntimeError(
245-
'clip_l not found in Flux textual '
246-
f'inversion model "{textual_inversion_uri.model}" state dict, '
247-
'unsupported model format.')
248-
249-
# token is the file name (no extension) with spaces
250-
# replaced by underscores when the user does not provide
251-
# a prompt token
252-
token = os.path.splitext(
253-
os.path.basename(filename))[0].replace(' ', '_') \
254-
if textual_inversion_uri.token is None else textual_inversion_uri.token
255-
256-
pipeline.load_textual_inversion(dicts['clip_l'],
257-
token=token,
258-
text_encoder=pipeline.text_encoder,
259-
tokenizer=pipeline.tokenizer)
260-
261-
if is_sdxl:
262-
pipeline.load_textual_inversion(dicts['clip_g'],
263-
token=token,
264-
text_encoder=pipeline.text_encoder_2,
265-
tokenizer=pipeline.tokenizer_2)
219+
is_sdxl = pipeline.__class__.__name__.startswith('StableDiffusionXL')
220+
is_flux = pipeline.__class__.__name__.startswith('Flux')
221+
222+
if is_sdxl or is_flux:
223+
filename, dicts = _load_textual_inversion_state_dict(
224+
model_path,
225+
revision=textual_inversion_uri.revision,
226+
subfolder=textual_inversion_uri.subfolder,
227+
weight_name=textual_inversion_uri.weight_name,
228+
local_files_only=local_files_only,
229+
token=use_auth_token
230+
)
231+
232+
if is_sdxl:
233+
if 'clip_l' not in dicts or 'clip_g' not in dicts:
234+
raise RuntimeError(
235+
'clip_l or clip_g not found in SDXL textual '
236+
f'inversion model "{textual_inversion_uri.model}" state dict, '
237+
'unsupported model format.')
266238
else:
267-
pipeline.load_textual_inversion(model_path,
268-
token=textual_inversion_uri.token,
269-
revision=textual_inversion_uri.revision,
270-
subfolder=textual_inversion_uri.subfolder,
271-
weight_name=textual_inversion_uri.weight_name,
272-
local_files_only=local_files_only)
273-
finally:
274-
if old_token is not None:
275-
os.environ['HF_TOKEN'] = old_token
239+
if 'clip_l' not in dicts:
240+
raise RuntimeError(
241+
'clip_l not found in Flux textual '
242+
f'inversion model "{textual_inversion_uri.model}" state dict, '
243+
'unsupported model format.')
244+
245+
# token is the file name (no extension) with spaces
246+
# replaced by underscores when the user does not provide
247+
# a prompt token
248+
token = os.path.splitext(
249+
os.path.basename(filename))[0].replace(' ', '_') \
250+
if textual_inversion_uri.token is None else textual_inversion_uri.token
251+
252+
pipeline.load_textual_inversion(dicts['clip_l'],
253+
token=token,
254+
text_encoder=pipeline.text_encoder,
255+
tokenizer=pipeline.tokenizer,
256+
hf_token=use_auth_token)
257+
258+
if is_sdxl:
259+
pipeline.load_textual_inversion(dicts['clip_g'],
260+
token=token,
261+
text_encoder=pipeline.text_encoder_2,
262+
tokenizer=pipeline.tokenizer_2,
263+
hf_token=use_auth_token)
264+
else:
265+
pipeline.load_textual_inversion(model_path,
266+
token=textual_inversion_uri.token,
267+
revision=textual_inversion_uri.revision,
268+
subfolder=textual_inversion_uri.subfolder,
269+
weight_name=textual_inversion_uri.weight_name,
270+
local_files_only=local_files_only,
271+
hf_token=use_auth_token)
276272

277273
_messages.debug_log(f'Added Textual Inversion: "{textual_inversion_uri}" '
278274
f'to pipeline: "{pipeline.__class__.__name__}"')

0 commit comments

Comments
 (0)