Skip to content

Commit 523ef5c

Browse files
authored
fix: add Civitai compatibility for LoRAs in a1111 metadata scheme by switching schema (#2615)
* feat: update sha256 generation functions https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/29be1da7cf2b5dccfc70fbdd33eb35c56a31ffb7/modules/hashes.py * feat: add compatibility for LoRAs in a1111 metadata scheme * feat: add backwards compatibility * refactor: extract remove_special_loras * fix: correctly apply LoRA weight for legacy schema
1 parent 9aaa400 commit 523ef5c

File tree

3 files changed

+66
-20
lines changed

3 files changed

+66
-20
lines changed

modules/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,7 @@ def add_ratio(x):
539539

540540
sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors'
541541
sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors'
542+
loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora]
542543

543544

544545
def get_model_filenames(folder_paths, extensions=None, name_filter=None):

modules/meta_parser.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json
2-
import os
32
import re
43
from abc import ABC, abstractmethod
54
from pathlib import Path
@@ -12,7 +11,7 @@
1211
import modules.sdxl_styles
1312
from modules.flags import MetadataScheme, Performance, Steps
1413
from modules.flags import SAMPLERS, CIVITAI_NO_KARRAS
15-
from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, calculate_sha256
14+
from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, sha256
1615

1716
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
1817
re_param = re.compile(re_param_code)
@@ -110,7 +109,8 @@ def get_steps(key: str, fallback: str | None, source_dict: dict, results: list,
110109
assert h is not None
111110
h = int(h)
112111
# if not in steps or in steps and performance is not the same
113-
if h not in iter(Steps) or Steps(h).name.casefold() != source_dict.get('performance', '').replace(' ', '_').casefold():
112+
if h not in iter(Steps) or Steps(h).name.casefold() != source_dict.get('performance', '').replace(' ',
113+
'_').casefold():
114114
results.append(h)
115115
return
116116
results.append(-1)
@@ -204,7 +204,8 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list):
204204
def get_sha256(filepath):
205205
global hash_cache
206206
if filepath not in hash_cache:
207-
hash_cache[filepath] = calculate_sha256(filepath)
207+
# is_safetensors = os.path.splitext(filepath)[1].lower() == '.safetensors'
208+
hash_cache[filepath] = sha256(filepath)
208209

209210
return hash_cache[filepath]
210211

@@ -231,8 +232,9 @@ def parse_meta_from_preset(preset_content):
231232
height = height[:height.index(" ")]
232233
preset_prepared[meta_key] = (width, height)
233234
else:
234-
preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[settings_key] is not None else getattr(modules.config, settings_key)
235-
235+
preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[
236+
settings_key] is not None else getattr(modules.config, settings_key)
237+
236238
if settings_key == "default_styles" or settings_key == "default_aspect_ratio":
237239
preset_prepared[meta_key] = str(preset_prepared[meta_key])
238240

@@ -288,6 +290,12 @@ def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_p
288290
lora_hash = get_sha256(lora_path)
289291
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash))
290292

293+
@staticmethod
294+
def remove_special_loras(lora_filenames):
295+
for lora_to_remove in modules.config.loras_metadata_remove:
296+
if lora_to_remove in lora_filenames:
297+
lora_filenames.remove(lora_to_remove)
298+
291299

292300
class A1111MetadataParser(MetadataParser):
293301
def get_scheme(self) -> MetadataScheme:
@@ -397,12 +405,19 @@ def parse_json(self, metadata: str) -> dict:
397405
data[key] = filename
398406
break
399407

400-
if 'lora_hashes' in data and data['lora_hashes'] != '':
408+
lora_data = ''
409+
if 'lora_weights' in data and data['lora_weights'] != '':
410+
lora_data = data['lora_weights']
411+
elif 'lora_hashes' in data and data['lora_hashes'] != '' and data['lora_hashes'].split(', ')[0].count(':') == 2:
412+
lora_data = data['lora_hashes']
413+
414+
if lora_data != '':
401415
lora_filenames = modules.config.lora_filenames.copy()
402-
if modules.config.sdxl_lcm_lora in lora_filenames:
403-
lora_filenames.remove(modules.config.sdxl_lcm_lora)
404-
for li, lora in enumerate(data['lora_hashes'].split(', ')):
405-
lora_name, lora_hash, lora_weight = lora.split(': ')
416+
self.remove_special_loras(lora_filenames)
417+
for li, lora in enumerate(lora_data.split(', ')):
418+
lora_split = lora.split(': ')
419+
lora_name = lora_split[0]
420+
lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1]
406421
for filename in lora_filenames:
407422
path = Path(filename)
408423
if lora_name == path.stem:
@@ -453,11 +468,15 @@ def parse_string(self, metadata: dict) -> str:
453468

454469
if len(self.loras) > 0:
455470
lora_hashes = []
471+
lora_weights = []
456472
for index, (lora_name, lora_weight, lora_hash) in enumerate(self.loras):
457473
# workaround for Fooocus not knowing LoRA name in LoRA metadata
458-
lora_hashes.append(f'{lora_name}: {lora_hash}: {lora_weight}')
474+
lora_hashes.append(f'{lora_name}: {lora_hash}')
475+
lora_weights.append(f'{lora_name}: {lora_weight}')
459476
lora_hashes_string = ', '.join(lora_hashes)
477+
lora_weights_string = ', '.join(lora_weights)
460478
generation_params[self.fooocus_to_a1111['lora_hashes']] = lora_hashes_string
479+
generation_params[self.fooocus_to_a1111['lora_weights']] = lora_weights_string
461480

462481
generation_params[self.fooocus_to_a1111['version']] = data['version']
463482

@@ -480,9 +499,7 @@ def get_scheme(self) -> MetadataScheme:
480499
def parse_json(self, metadata: dict) -> dict:
481500
model_filenames = modules.config.model_filenames.copy()
482501
lora_filenames = modules.config.lora_filenames.copy()
483-
if modules.config.sdxl_lcm_lora in lora_filenames:
484-
lora_filenames.remove(modules.config.sdxl_lcm_lora)
485-
502+
self.remove_special_loras(lora_filenames)
486503
for key, value in metadata.items():
487504
if value in ['', 'None']:
488505
continue

modules/util.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import os
88
import cv2
99
import json
10+
import hashlib
1011

1112
from PIL import Image
12-
from hashlib import sha256
1313

1414
import modules.sdxl_styles
1515

@@ -182,16 +182,44 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None):
182182
return filenames
183183

184184

185-
def calculate_sha256(filename, length=HASH_SHA256_LENGTH) -> str:
186-
hash_sha256 = sha256()
185+
def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH):
186+
print(f"Calculating sha256 for {filename}: ", end='')
187+
if use_addnet_hash:
188+
with open(filename, "rb") as file:
189+
sha256_value = addnet_hash_safetensors(file)
190+
else:
191+
sha256_value = calculate_sha256(filename)
192+
print(f"{sha256_value}")
193+
194+
return sha256_value[:length] if length is not None else sha256_value
195+
196+
197+
def addnet_hash_safetensors(b):
198+
"""kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
199+
hash_sha256 = hashlib.sha256()
200+
blksize = 1024 * 1024
201+
202+
b.seek(0)
203+
header = b.read(8)
204+
n = int.from_bytes(header, "little")
205+
206+
offset = n + 8
207+
b.seek(offset)
208+
for chunk in iter(lambda: b.read(blksize), b""):
209+
hash_sha256.update(chunk)
210+
211+
return hash_sha256.hexdigest()
212+
213+
214+
def calculate_sha256(filename) -> str:
215+
hash_sha256 = hashlib.sha256()
187216
blksize = 1024 * 1024
188217

189218
with open(filename, "rb") as f:
190219
for chunk in iter(lambda: f.read(blksize), b""):
191220
hash_sha256.update(chunk)
192221

193-
res = hash_sha256.hexdigest()
194-
return res[:length] if length else res
222+
return hash_sha256.hexdigest()
195223

196224

197225
def quote(text):

0 commit comments

Comments
 (0)