@@ -249,7 +249,7 @@ def run_pnginfo(image):
249
249
return '' , geninfo , info
250
250
251
251
252
- def run_modelmerger (primary_model_name , secondary_model_name , teritary_model_name , interp_method , multiplier , save_as_half , custom_name ):
252
+ def run_modelmerger (primary_model_name , secondary_model_name , teritary_model_name , interp_method , multiplier , save_as_half , save_as_safetensors , custom_name ):
253
253
def weighted_sum (theta0 , theta1 , alpha ):
254
254
return ((1 - alpha ) * theta0 ) + (alpha * theta1 )
255
255
@@ -264,16 +264,16 @@ def add_difference(theta0, theta1_2_diff, alpha):
264
264
teritary_model_info = sd_models .checkpoints_list .get (teritary_model_name , None )
265
265
266
266
print (f"Loading { primary_model_info .filename } ..." )
267
- primary_model = torch . load (primary_model_info .filename , map_location = 'cpu' )
267
+ primary_model = sd_models . torch_load (primary_model_info .filename , primary_model_info , map_override = 'cpu' )
268
268
theta_0 = sd_models .get_state_dict_from_checkpoint (primary_model )
269
269
270
270
print (f"Loading { secondary_model_info .filename } ..." )
271
- secondary_model = torch . load (secondary_model_info .filename , map_location = 'cpu' )
271
+ secondary_model = sd_models . torch_load (secondary_model_info .filename , primary_model_info , map_override = 'cpu' )
272
272
theta_1 = sd_models .get_state_dict_from_checkpoint (secondary_model )
273
273
274
274
if teritary_model_info is not None :
275
275
print (f"Loading { teritary_model_info .filename } ..." )
276
- teritary_model = torch . load (teritary_model_info .filename , map_location = 'cpu' )
276
+ teritary_model = sd_models . torch_load (teritary_model_info .filename , teritary_model_info , map_override = 'cpu' )
277
277
theta_2 = sd_models .get_state_dict_from_checkpoint (teritary_model )
278
278
else :
279
279
teritary_model = None
@@ -314,12 +314,13 @@ def add_difference(theta0, theta1_2_diff, alpha):
314
314
315
315
ckpt_dir = shared .cmd_opts .ckpt_dir or sd_models .model_path
316
316
317
- filename = primary_model_info .model_name + '_' + str (round (1 - multiplier , 2 )) + '-' + secondary_model_info .model_name + '_' + str (round (multiplier , 2 )) + '-' + interp_method .replace (" " , "_" ) + '-merged.ckpt'
318
- filename = filename if custom_name == '' else (custom_name + '.ckpt' )
317
+ output_exttype = '.safetensors' if save_as_safetensors else '.ckpt'
318
+ filename = primary_model_info .model_name + '_' + str (round (1 - multiplier , 2 )) + '-' + secondary_model_info .model_name + '_' + str (round (multiplier , 2 )) + '-' + interp_method .replace (" " , "_" ) + '-merged' + output_exttype
319
+ filename = filename if custom_name == '' else (custom_name + output_exttype )
319
320
output_modelname = os .path .join (ckpt_dir , filename )
320
321
321
322
print (f"Saving to { output_modelname } ..." )
322
- torch . save (primary_model , output_modelname )
323
+ sd_models . torch_save (primary_model , output_modelname )
323
324
324
325
sd_models .list_models ()
325
326
0 commit comments