Skip to content

Commit dac9b6f

Browse files
committed
add safetensors support for model merging #4869
1 parent 6074175 commit dac9b6f

File tree

3 files changed

+35
-24
lines changed

3 files changed

+35
-24
lines changed

modules/extras.py

Lines changed: 14 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -20,6 +20,7 @@
2020
import piexif
2121
import piexif.helper
2222
import gradio as gr
23+
import safetensors.torch
2324

2425

2526
class LruCache(OrderedDict):
@@ -249,7 +250,7 @@ def run_pnginfo(image):
249250
return '', geninfo, info
250251

251252

252-
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
253+
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
253254
def weighted_sum(theta0, theta1, alpha):
254255
return ((1 - alpha) * theta0) + (alpha * theta1)
255256

@@ -264,19 +265,15 @@ def add_difference(theta0, theta1_2_diff, alpha):
264265
teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
265266

266267
print(f"Loading {primary_model_info.filename}...")
267-
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
268-
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
268+
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
269269

270270
print(f"Loading {secondary_model_info.filename}...")
271-
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
272-
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
271+
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
273272

274273
if teritary_model_info is not None:
275274
print(f"Loading {teritary_model_info.filename}...")
276-
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
277-
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
275+
theta_2 = sd_models.read_state_dict(teritary_model_info.filename, map_location='cpu')
278276
else:
279-
teritary_model = None
280277
theta_2 = None
281278

282279
theta_funcs = {
@@ -295,7 +292,7 @@ def add_difference(theta0, theta1_2_diff, alpha):
295292
theta_1[key] = theta_func1(theta_1[key], t2)
296293
else:
297294
theta_1[key] = torch.zeros_like(theta_1[key])
298-
del theta_2, teritary_model
295+
del theta_2
299296

300297
for key in tqdm.tqdm(theta_0.keys()):
301298
if 'model' in key and key in theta_1:
@@ -314,12 +311,17 @@ def add_difference(theta0, theta1_2_diff, alpha):
314311

315312
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
316313

317-
filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
318-
filename = filename if custom_name == '' else (custom_name + '.ckpt')
314+
filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.' + checkpoint_format
315+
filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
319316
output_modelname = os.path.join(ckpt_dir, filename)
320317

321318
print(f"Saving to {output_modelname}...")
322-
torch.save(primary_model, output_modelname)
319+
320+
_, extension = os.path.splitext(output_modelname)
321+
if extension.lower() == ".safetensors":
322+
safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
323+
else:
324+
torch.save(theta_0, output_modelname)
323325

324326
sd_models.list_models()
325327

modules/sd_models.py

Lines changed: 15 additions & 11 deletions
Original file line number · Diff line number · Diff line change
@@ -160,6 +160,20 @@ def get_state_dict_from_checkpoint(pl_sd):
160160
return pl_sd
161161

162162

163+
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
164+
_, extension = os.path.splitext(checkpoint_file)
165+
if extension.lower() == ".safetensors":
166+
pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
167+
else:
168+
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
169+
170+
if print_global_state and "global_step" in pl_sd:
171+
print(f"Global Step: {pl_sd['global_step']}")
172+
173+
sd = get_state_dict_from_checkpoint(pl_sd)
174+
return sd
175+
176+
163177
def load_model_weights(model, checkpoint_info, vae_file="auto"):
164178
checkpoint_file = checkpoint_info.filename
165179
sd_model_hash = checkpoint_info.hash
@@ -174,17 +188,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
174188
# load from file
175189
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
176190

177-
_, extension = os.path.splitext(checkpoint_file)
178-
if extension.lower() == ".safetensors":
179-
pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location)
180-
else:
181-
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
182-
183-
if "global_step" in pl_sd:
184-
print(f"Global Step: {pl_sd['global_step']}")
185-
186-
sd = get_state_dict_from_checkpoint(pl_sd)
187-
del pl_sd
191+
sd = read_state_dict(checkpoint_file)
188192
model.load_state_dict(sd, strict=False)
189193
del sd
190194

modules/ui.py

Lines changed: 6 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1164,7 +1164,11 @@ def create_ui(wrap_gradio_gpu_call):
11641164
custom_name = gr.Textbox(label="Custom Name (Optional)")
11651165
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
11661166
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
1167-
save_as_half = gr.Checkbox(value=False, label="Save as float16")
1167+
1168+
with gr.Row():
1169+
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format")
1170+
save_as_half = gr.Checkbox(value=False, label="Save as float16")
1171+
11681172
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
11691173

11701174
with gr.Column(variant='panel'):
@@ -1692,6 +1696,7 @@ def modelmerger(*args):
16921696
interp_amount,
16931697
save_as_half,
16941698
custom_name,
1699+
checkpoint_format,
16951700
],
16961701
outputs=[
16971702
submit_result,

0 commit comments

Comments (0)