Skip to content

Commit 4e658bb

Browse files
authored
feat: optimize performance lora filtering in metadata (#3048)
* feat: add remove_performance_lora method * feat: use class PerformanceLoRA instead of strings in config * refactor: cleanup flags, use __member__ to check if enums contains key * feat: only filter lora of selected performance instead of all performance LoRAs * fix: disable intermediate results for all restricted performances too fast for Gradio, which becomes a bottleneck * refactor: rename parse_json to to_json, rename parse_string to to_string * feat: use speed steps as default instead of hardcoded 30 * feat: add method to_steps to Performance * refactor: remove method ordinal_suffix, not needed anymore * feat: only filter lora of selected performance instead of all performance LoRAs both metadata and history log * feat: do not filter LoRAs in metadata parser but rather in metadata load action
1 parent 3ef663c commit 4e658bb

File tree

8 files changed

+144
-53
lines changed

8 files changed

+144
-53
lines changed

modules/async_worker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,10 @@ def handler(async_task):
462462

463463
progressbar(async_task, 2, 'Loading models ...')
464464

465-
loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
465+
lora_filenames = modules.util.remove_performance_lora(modules.config.lora_filenames, performance_selection)
466+
loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number, lora_filenames=lora_filenames)
466467
loras += performance_loras
468+
467469
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
468470
loras=loras, base_model_additional_loras=base_model_additional_loras,
469471
use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)

modules/config.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -548,25 +548,9 @@ def add_ratio(x):
548548

549549
model_filenames = []
550550
lora_filenames = []
551-
lora_filenames_no_special = []
552551
vae_filenames = []
553552
wildcard_filenames = []
554553

555-
sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
556-
sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors'
557-
sdxl_hyper_sd_lora = 'sdxl_hyper_sd_4step_lora.safetensors'
558-
loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora, sdxl_hyper_sd_lora]
559-
560-
561-
def remove_special_loras(lora_filenames):
562-
global loras_metadata_remove
563-
564-
loras_no_special = lora_filenames.copy()
565-
for lora_to_remove in loras_metadata_remove:
566-
if lora_to_remove in loras_no_special:
567-
loras_no_special.remove(lora_to_remove)
568-
return loras_no_special
569-
570554

571555
def get_model_filenames(folder_paths, extensions=None, name_filter=None):
572556
if extensions is None:
@@ -582,10 +566,9 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None):
582566

583567

584568
def update_files():
585-
global model_filenames, lora_filenames, lora_filenames_no_special, vae_filenames, wildcard_filenames, available_presets
569+
global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets
586570
model_filenames = get_model_filenames(paths_checkpoints)
587571
lora_filenames = get_model_filenames(paths_loras)
588-
lora_filenames_no_special = remove_special_loras(lora_filenames)
589572
vae_filenames = get_model_filenames(path_vae)
590573
wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt'])
591574
available_presets = get_presets()
@@ -634,26 +617,27 @@ def downloading_sdxl_lcm_lora():
634617
load_file_from_url(
635618
url='https://huggingface.co/lllyasviel/misc/resolve/main/sdxl_lcm_lora.safetensors',
636619
model_dir=paths_loras[0],
637-
file_name=sdxl_lcm_lora
620+
file_name=modules.flags.PerformanceLoRA.EXTREME_SPEED.value
638621
)
639-
return sdxl_lcm_lora
622+
return modules.flags.PerformanceLoRA.EXTREME_SPEED.value
623+
640624

641625
def downloading_sdxl_lightning_lora():
642626
load_file_from_url(
643627
url='https://huggingface.co/mashb1t/misc/resolve/main/sdxl_lightning_4step_lora.safetensors',
644628
model_dir=paths_loras[0],
645-
file_name=sdxl_lightning_lora
629+
file_name=modules.flags.PerformanceLoRA.LIGHTNING.value
646630
)
647-
return sdxl_lightning_lora
631+
return modules.flags.PerformanceLoRA.LIGHTNING.value
648632

649633

650634
def downloading_sdxl_hyper_sd_lora():
651635
load_file_from_url(
652636
url='https://huggingface.co/mashb1t/misc/resolve/main/sdxl_hyper_sd_4step_lora.safetensors',
653637
model_dir=paths_loras[0],
654-
file_name=sdxl_hyper_sd_lora
638+
file_name=modules.flags.PerformanceLoRA.HYPER_SD.value
655639
)
656-
return sdxl_hyper_sd_lora
640+
return modules.flags.PerformanceLoRA.HYPER_SD.value
657641

658642

659643
def downloading_controlnet_canny():

modules/flags.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@
4848

4949
KSAMPLER_NAMES = list(KSAMPLER.keys())
5050

51-
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "align_your_steps", "tcd"]
51+
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo",
52+
"align_your_steps", "tcd"]
5253
SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
5354

5455
sampler_list = SAMPLER_NAMES
@@ -91,6 +92,7 @@
9192
'1664*576', '1728*576'
9293
]
9394

95+
9496
class MetadataScheme(Enum):
9597
FOOOCUS = 'fooocus'
9698
A1111 = 'a1111'
@@ -115,6 +117,14 @@ def list(cls) -> list:
115117
return list(map(lambda c: c.value, cls))
116118

117119

120+
class PerformanceLoRA(Enum):
121+
QUALITY = None
122+
SPEED = None
123+
EXTREME_SPEED = 'sdxl_lcm_lora.safetensors'
124+
LIGHTNING = 'sdxl_lightning_4step_lora.safetensors'
125+
HYPER_SD = 'sdxl_hyper_sd_4step_lora.safetensors'
126+
127+
118128
class Steps(IntEnum):
119129
QUALITY = 60
120130
SPEED = 30
@@ -142,14 +152,21 @@ class Performance(Enum):
142152
def list(cls) -> list:
143153
return list(map(lambda c: c.value, cls))
144154

155+
@classmethod
156+
def by_steps(cls, steps: int | str):
157+
return cls[Steps(int(steps)).name]
158+
145159
@classmethod
146160
def has_restricted_features(cls, x) -> bool:
147161
if isinstance(x, Performance):
148162
x = x.value
149163
return x in [cls.EXTREME_SPEED.value, cls.LIGHTNING.value, cls.HYPER_SD.value]
150164

151165
def steps(self) -> int | None:
152-
return Steps[self.name].value if Steps[self.name] else None
166+
return Steps[self.name].value if self.name in Steps.__members__ else None
153167

154168
def steps_uov(self) -> int | None:
155-
return StepsUOV[self.name].value if Steps[self.name] else None
169+
return StepsUOV[self.name].value if self.name in StepsUOV.__members__ else None
170+
171+
def lora_filename(self) -> str | None:
172+
return PerformanceLoRA[self.name].value if self.name in PerformanceLoRA.__members__ else None

modules/meta_parser.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
3232
get_str('prompt', 'Prompt', loaded_parameter_dict, results)
3333
get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results)
3434
get_list('styles', 'Styles', loaded_parameter_dict, results)
35-
get_str('performance', 'Performance', loaded_parameter_dict, results)
35+
performance = get_str('performance', 'Performance', loaded_parameter_dict, results)
3636
get_steps('steps', 'Steps', loaded_parameter_dict, results)
3737
get_number('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results)
3838
get_resolution('resolution', 'Resolution', loaded_parameter_dict, results)
@@ -59,19 +59,27 @@ def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
5959

6060
get_freeu('freeu', 'FreeU', loaded_parameter_dict, results)
6161

62+
# prevent performance LoRAs to be added twice, by performance and by lora
63+
performance_filename = None
64+
if performance is not None and performance in Performance.list():
65+
performance = Performance(performance)
66+
performance_filename = performance.lora_filename()
67+
6268
for i in range(modules.config.default_max_lora_number):
63-
get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results)
69+
get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results, performance_filename)
6470

6571
return results
6672

6773

68-
def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
74+
def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None) -> str | None:
6975
try:
7076
h = source_dict.get(key, source_dict.get(fallback, default))
7177
assert isinstance(h, str)
7278
results.append(h)
79+
return h
7380
except:
7481
results.append(gr.update())
82+
return None
7583

7684

7785
def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
@@ -181,7 +189,7 @@ def get_freeu(key: str, fallback: str | None, source_dict: dict, results: list,
181189
results.append(gr.update())
182190

183191

184-
def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
192+
def get_lora(key: str, fallback: str | None, source_dict: dict, results: list, performance_filename: str | None):
185193
try:
186194
split_data = source_dict.get(key, source_dict.get(fallback)).split(' : ')
187195
enabled = True
@@ -193,6 +201,9 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
193201
name = split_data[1]
194202
weight = split_data[2]
195203

204+
if name == performance_filename:
205+
raise Exception
206+
196207
weight = float(weight)
197208
results.append(enabled)
198209
results.append(name)
@@ -248,7 +259,7 @@ def __init__(self):
248259
self.full_prompt: str = ''
249260
self.raw_negative_prompt: str = ''
250261
self.full_negative_prompt: str = ''
251-
self.steps: int = 30
262+
self.steps: int = Steps.SPEED.value
252263
self.base_model_name: str = ''
253264
self.base_model_hash: str = ''
254265
self.refiner_model_name: str = ''
@@ -261,11 +272,11 @@ def get_scheme(self) -> MetadataScheme:
261272
raise NotImplementedError
262273

263274
@abstractmethod
264-
def parse_json(self, metadata: dict | str) -> dict:
275+
def to_json(self, metadata: dict | str) -> dict:
265276
raise NotImplementedError
266277

267278
@abstractmethod
268-
def parse_string(self, metadata: dict) -> str:
279+
def to_string(self, metadata: dict) -> str:
269280
raise NotImplementedError
270281

271282
def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name,
@@ -328,7 +339,7 @@ def get_scheme(self) -> MetadataScheme:
328339
'version': 'Version'
329340
}
330341

331-
def parse_json(self, metadata: str) -> dict:
342+
def to_json(self, metadata: str) -> dict:
332343
metadata_prompt = ''
333344
metadata_negative_prompt = ''
334345

@@ -382,9 +393,9 @@ def parse_json(self, metadata: str) -> dict:
382393
data['styles'] = str(found_styles)
383394

384395
# try to load performance based on steps, fallback for direct A1111 imports
385-
if 'steps' in data and 'performance' not in data:
396+
if 'steps' in data and 'performance' in data is None:
386397
try:
387-
data['performance'] = Performance[Steps(int(data['steps'])).name].value
398+
data['performance'] = Performance.by_steps(data['steps']).value
388399
except ValueError | KeyError:
389400
pass
390401

@@ -414,15 +425,15 @@ def parse_json(self, metadata: str) -> dict:
414425
lora_split = lora.split(': ')
415426
lora_name = lora_split[0]
416427
lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1]
417-
for filename in modules.config.lora_filenames_no_special:
428+
for filename in modules.config.lora_filenames:
418429
path = Path(filename)
419430
if lora_name == path.stem:
420431
data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
421432
break
422433

423434
return data
424435

425-
def parse_string(self, metadata: dict) -> str:
436+
def to_string(self, metadata: dict) -> str:
426437
data = {k: v for _, k, v in metadata}
427438

428439
width, height = eval(data['resolution'])
@@ -502,22 +513,22 @@ class FooocusMetadataParser(MetadataParser):
502513
def get_scheme(self) -> MetadataScheme:
503514
return MetadataScheme.FOOOCUS
504515

505-
def parse_json(self, metadata: dict) -> dict:
516+
def to_json(self, metadata: dict) -> dict:
506517
for key, value in metadata.items():
507518
if value in ['', 'None']:
508519
continue
509520
if key in ['base_model', 'refiner_model']:
510521
metadata[key] = self.replace_value_with_filename(key, value, modules.config.model_filenames)
511522
elif key.startswith('lora_combined_'):
512-
metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames_no_special)
523+
metadata[key] = self.replace_value_with_filename(key, value, modules.config.lora_filenames)
513524
elif key == 'vae':
514525
metadata[key] = self.replace_value_with_filename(key, value, modules.config.vae_filenames)
515526
else:
516527
continue
517528

518529
return metadata
519530

520-
def parse_string(self, metadata: list) -> str:
531+
def to_string(self, metadata: list) -> str:
521532
for li, (label, key, value) in enumerate(metadata):
522533
# remove model folder paths from metadata
523534
if key.startswith('lora_combined_'):
@@ -557,6 +568,8 @@ def replace_value_with_filename(key, value, filenames):
557568
elif value == path.stem:
558569
return filename
559570

571+
return None
572+
560573

561574
def get_metadata_parser(metadata_scheme: MetadataScheme) -> MetadataParser:
562575
match metadata_scheme:

modules/private_logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def log(img, metadata, metadata_parser: MetadataParser | None = None, output_for
2727
date_string, local_temp_filename, only_name = generate_temp_filename(folder=path_outputs, extension=output_format)
2828
os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True)
2929

30-
parsed_parameters = metadata_parser.parse_string(metadata.copy()) if metadata_parser is not None else ''
30+
parsed_parameters = metadata_parser.to_string(metadata.copy()) if metadata_parser is not None else ''
3131
image = Image.fromarray(img)
3232

3333
if output_format == OutputFormat.PNG.value:

modules/util.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import modules.config
1818
import modules.sdxl_styles
19+
from modules.flags import Performance
1920

2021
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
2122

@@ -381,9 +382,6 @@ def get_file_from_folder_list(name, folders):
381382

382383
return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))
383384

384-
def ordinal_suffix(number: int) -> str:
385-
return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
386-
387385

388386
def makedirs_with_log(path):
389387
try:
@@ -397,10 +395,15 @@ def get_enabled_loras(loras: list, remove_none=True) -> list:
397395

398396

399397
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
400-
skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
398+
skip_file_check=False, prompt_cleanup=True, deduplicate_loras=True,
399+
lora_filenames=None) -> tuple[List[Tuple[AnyStr, float]], str]:
400+
if lora_filenames is None:
401+
lora_filenames = []
402+
401403
found_loras = []
402404
prompt_without_loras = ''
403405
cleaned_prompt = ''
406+
404407
for token in prompt.split(','):
405408
matches = LORAS_PROMPT_PATTERN.findall(token)
406409

@@ -410,7 +413,7 @@ def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, flo
410413
for match in matches:
411414
lora_name = match[1] + '.safetensors'
412415
if not skip_file_check:
413-
lora_name = get_filname_by_stem(match[1], modules.config.lora_filenames_no_special)
416+
lora_name = get_filname_by_stem(match[1], lora_filenames)
414417
if lora_name is not None:
415418
found_loras.append((lora_name, float(match[2])))
416419
token = token.replace(match[0], '')
@@ -440,6 +443,22 @@ def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, flo
440443
return updated_loras[:loras_limit], cleaned_prompt
441444

442445

446+
def remove_performance_lora(filenames: list, performance: Performance | None):
447+
loras_without_performance = filenames.copy()
448+
449+
if performance is None:
450+
return loras_without_performance
451+
452+
performance_lora = performance.lora_filename()
453+
454+
for filename in filenames:
455+
path = Path(filename)
456+
if performance_lora == path.name:
457+
loras_without_performance.remove(filename)
458+
459+
return loras_without_performance
460+
461+
443462
def cleanup_prompt(prompt):
444463
prompt = re.sub(' +', ' ', prompt)
445464
prompt = re.sub(',+', ',', prompt)

0 commit comments

Comments
 (0)