@@ -1154,3 +1154,36 @@ def post_cfg_function(args):
11541154 if sigmas [i + 1 ] > 0 :
11551155 x = x + noise_sampler (sigmas [i ], sigmas [i + 1 ]) * s_noise * sigma_up
11561156 return x
1157+
1158+ @torch .no_grad ()
1159+ def sample_dpmpp_2m_cfg_pp (model , x , sigmas , extra_args = None , callback = None , disable = None ):
1160+ """DPM-Solver++(2M)."""
1161+ extra_args = {} if extra_args is None else extra_args
1162+ s_in = x .new_ones ([x .shape [0 ]])
1163+ t_fn = lambda sigma : sigma .log ().neg ()
1164+
1165+ old_uncond_denoised = None
1166+ uncond_denoised = None
1167+ def post_cfg_function (args ):
1168+ nonlocal uncond_denoised
1169+ uncond_denoised = args ["uncond_denoised" ]
1170+ return args ["denoised" ]
1171+
1172+ model_options = extra_args .get ("model_options" , {}).copy ()
1173+ extra_args ["model_options" ] = comfy .model_patcher .set_model_options_post_cfg_function (model_options , post_cfg_function , disable_cfg1_optimization = True )
1174+
1175+ for i in trange (len (sigmas ) - 1 , disable = disable ):
1176+ denoised = model (x , sigmas [i ] * s_in , ** extra_args )
1177+ if callback is not None :
1178+ callback ({'x' : x , 'i' : i , 'sigma' : sigmas [i ], 'sigma_hat' : sigmas [i ], 'denoised' : denoised })
1179+ t , t_next = t_fn (sigmas [i ]), t_fn (sigmas [i + 1 ])
1180+ h = t_next - t
1181+ if old_uncond_denoised is None or sigmas [i + 1 ] == 0 :
1182+ denoised_mix = - torch .exp (- h ) * uncond_denoised
1183+ else :
1184+ h_last = t - t_fn (sigmas [i - 1 ])
1185+ r = h_last / h
1186+ denoised_mix = - torch .exp (- h ) * uncond_denoised - torch .expm1 (- h ) * (1 / (2 * r )) * (denoised - old_uncond_denoised )
1187+ x = denoised + denoised_mix + torch .exp (- h ) * x
1188+ old_uncond_denoised = uncond_denoised
1189+ return x
0 commit comments