Commit 6378156
Generalize SD torch load/save to implement safetensor merging compat
1 parent ac7ecd2

3 files changed: +1840 −1826 lines

modules/extras.py

Lines changed: 8 additions & 7 deletions
@@ -249,7 +249,7 @@ def run_pnginfo(image):
     return '', geninfo, info
 
 
-def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
+def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, save_as_safetensors, custom_name):
     def weighted_sum(theta0, theta1, alpha):
         return ((1 - alpha) * theta0) + (alpha * theta1)
 
@@ -264,16 +264,16 @@ def add_difference(theta0, theta1_2_diff, alpha):
     teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
 
     print(f"Loading {primary_model_info.filename}...")
-    primary_model = torch.load(primary_model_info.filename, map_location='cpu')
+    primary_model = sd_models.torch_load(primary_model_info.filename, primary_model_info, map_override='cpu')
     theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
 
     print(f"Loading {secondary_model_info.filename}...")
-    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
+    secondary_model = sd_models.torch_load(secondary_model_info.filename, primary_model_info, map_override='cpu')
     theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
 
     if teritary_model_info is not None:
         print(f"Loading {teritary_model_info.filename}...")
-        teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
+        teritary_model = sd_models.torch_load(teritary_model_info.filename, teritary_model_info, map_override='cpu')
         theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
     else:
         teritary_model = None
@@ -314,12 +314,13 @@ def add_difference(theta0, theta1_2_diff, alpha):
 
     ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
 
-    filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
-    filename = filename if custom_name == '' else (custom_name + '.ckpt')
+    output_exttype = '.safetensors' if save_as_safetensors else '.ckpt'
+    filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged' + output_exttype
+    filename = filename if custom_name == '' else (custom_name + output_exttype)
     output_modelname = os.path.join(ckpt_dir, filename)
 
     print(f"Saving to {output_modelname}...")
-    torch.save(primary_model, output_modelname)
+    sd_models.torch_save(primary_model, output_modelname)
 
     sd_models.list_models()

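For a concrete sense of the new naming logic, the sketch below evaluates the filename expression with hypothetical inputs; only the expression itself comes from the diff above, the model names and settings are made up for illustration:

# Hypothetical inputs; only the filename expression mirrors the diff above.
multiplier = 0.3
interp_method = "Weighted sum"
save_as_safetensors = True
primary_name, secondary_name = "modelA", "modelB"

output_exttype = '.safetensors' if save_as_safetensors else '.ckpt'
filename = (primary_name + '_' + str(round(1 - multiplier, 2)) + '-' +
            secondary_name + '_' + str(round(multiplier, 2)) + '-' +
            interp_method.replace(" ", "_") + '-merged' + output_exttype)
print(filename)  # modelA_0.7-modelB_0.3-Weighted_sum-merged.safetensors
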
modules/sd_models.py

Lines changed: 18 additions & 7 deletions
@@ -4,7 +4,7 @@
 import gc
 from collections import namedtuple
 import torch
-from safetensors.torch import load_file
+from safetensors.torch import load_file, save_file
 import re
 from omegaconf import OmegaConf
 
@@ -143,6 +143,22 @@ def transform_checkpoint_dict_key(k):
 
     return k
 
+def torch_load(model_filename, model_info, map_override=None):
+    map_override=shared.weight_load_location if not map_override else map_override
+    if(checkpoint_types[model_info.exttype] == 'safetensors'):
+        # safely load weights
+        # TODO: safetensors supports zero copy fast load to gpu, see issue #684
+        return load_file(model_filename, device=map_override)
+    else:
+        return torch.load(model_filename, map_location=map_override)
+
+def torch_save(model, output_filename):
+    basename, exttype = os.path.splitext(output_filename)
+    if(checkpoint_types[exttype] == 'safetensors'):
+        # [===== >] Reticulating brines...
+        save_file(model, output_filename, metadata={"format": "pt"})
+    else:
+        torch.save(model, output_filename)
 
 def get_state_dict_from_checkpoint(pl_sd):
     if "state_dict" in pl_sd:
@@ -175,12 +191,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         # load from file
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
-        if(checkpoint_types[checkpoint_info.exttype] == 'safetensors'):
-            # safely load weights
-            # TODO: safetensors supports zero copy fast load to gpu, see issue #684
-            pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
-        else:
-            pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+        pl_sd = torch_load(checkpoint_file, checkpoint_info)
 
         if "global_step" in pl_sd:
             print(f"Global Step: {pl_sd['global_step']}")
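The new helpers depend on the module's pre-existing checkpoint_types lookup and its checkpoint info objects, so they are not directly portable. Below is a self-contained sketch of the same extension-based dispatch; the function names and the checkpoint_types mapping shown here are assumptions for illustration, not the module's actual definitions:

import os
import torch
from safetensors.torch import load_file, save_file

# Assumed shape of the lookup; the real table is defined elsewhere in sd_models.py.
checkpoint_types = {'.ckpt': 'torch', '.safetensors': 'safetensors'}

def load_state_dict(filename, device='cpu'):
    _, exttype = os.path.splitext(filename)
    if checkpoint_types[exttype] == 'safetensors':
        # load_file deserializes tensors only, so no arbitrary pickled code can run
        return load_file(filename, device=device)
    return torch.load(filename, map_location=device)

def save_state_dict(state_dict, filename):
    _, exttype = os.path.splitext(filename)
    if checkpoint_types[exttype] == 'safetensors':
        # save_file expects a flat {name: tensor} dict; metadata tags torch provenance
        save_file(state_dict, filename, metadata={"format": "pt"})
    else:
        torch.save(state_dict, filename)

# The output format follows the extension, mirroring torch_save above, e.g.:
# save_state_dict(theta_0, 'merged.safetensors')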
