@@ -20,6 +20,7 @@
 import piexif
 import piexif.helper
 import gradio as gr
+import safetensors.torch
 
 
 class LruCache(OrderedDict):
@@ -249,7 +250,7 @@ def run_pnginfo(image):
     return '', geninfo, info
 
 
-def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
+def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
     def weighted_sum(theta0, theta1, alpha):
         return ((1 - alpha) * theta0) + (alpha * theta1)
 
@@ -264,19 +265,15 @@ def add_difference(theta0, theta1_2_diff, alpha):
     teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
 
     print(f"Loading {primary_model_info.filename}...")
-    primary_model = torch.load(primary_model_info.filename, map_location='cpu')
-    theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
+    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
 
     print(f"Loading {secondary_model_info.filename}...")
-    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
-    theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
+    theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
 
     if teritary_model_info is not None:
         print(f"Loading {teritary_model_info.filename}...")
-        teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
-        theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
+        theta_2 = sd_models.read_state_dict(teritary_model_info.filename, map_location='cpu')
     else:
-        teritary_model = None
         theta_2 = None
 
     theta_funcs = {
@@ -295,7 +292,7 @@ def add_difference(theta0, theta1_2_diff, alpha):
                     theta_1[key] = theta_func1(theta_1[key], t2)
                 else:
                     theta_1[key] = torch.zeros_like(theta_1[key])
-    del theta_2, teritary_model
+    del theta_2
 
     for key in tqdm.tqdm(theta_0.keys()):
         if 'model' in key and key in theta_1:
@@ -314,12 +311,17 @@ def add_difference(theta0, theta1_2_diff, alpha):
 
     ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
 
-    filename = primary_model_info.model_name + '_' + str(round(1 - multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
-    filename = filename if custom_name == '' else (custom_name + '.ckpt')
+    filename = primary_model_info.model_name + '_' + str(round(1 - multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.' + checkpoint_format
+    filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
 
     output_modelname = os.path.join(ckpt_dir, filename)
 
     print(f"Saving to {output_modelname}...")
-    torch.save(primary_model, output_modelname)
+
+    _, extension = os.path.splitext(output_modelname)
+    if extension.lower() == ".safetensors":
+        safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
+    else:
+        torch.save(theta_0, output_modelname)
 
     sd_models.list_models()
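
For reference, a minimal standalone sketch (not part of the diff) of the save path this change introduces: `.safetensors` outputs go through `safetensors.torch.save_file`, anything else falls back to `torch.save`. The state dict and filename below are made up for illustration.

```python
import torch
import safetensors.torch

# Stand-in for the merged state dict (theta_0 in the diff); the key name is hypothetical.
theta_0 = {"model.diffusion_model.dummy.weight": torch.zeros(4, 4)}

# Mirror of the new branch: pick the writer based on the output extension.
output_modelname = "example-merged.safetensors"
if output_modelname.lower().endswith(".safetensors"):
    safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
else:
    torch.save(theta_0, output_modelname)

# Load the file back to confirm the tensors round-trip.
loaded = safetensors.torch.load_file(output_modelname, device="cpu")
assert torch.equal(loaded["model.diffusion_model.dummy.weight"],
                   theta_0["model.diffusion_model.dummy.weight"])
```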