@@ -1366,3 +1366,59 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
         x = x + d_bar * dt
         old_d = d
     return x
+
+@torch.no_grad()
+def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
+    """
+    Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
+    Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
+    """
+    extra_args = {} if extra_args is None else extra_args
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+
+    def default_noise_scaler(sigma):
+        return sigma * ((sigma ** 0.3).exp() + 10.0)
+    noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
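+    # Fixed grid of points used to numerically approximate the integral terms in the stage-2/3 corrections below.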
+    num_integration_points = 200.0
+    point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
+
+    old_denoised = None
+    old_denoised_d = None
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
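+        # Early steps fall back to lower-order updates until enough denoised history is available.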
+        stage_used = min(max_stage, i + 1)
+        if sigmas[i + 1] == 0:
+            x = denoised
+        elif stage_used == 1:
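+            # Stage 1: first-order update toward the denoised prediction, scaled by the noise_scaler ratio r.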
+            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
+            x = r * x + (1 - r) * denoised
+        else:
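+            # Stages 2/3: start from the first-order update, then add multistep corrections built from previous denoised outputs.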
+            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
+            x = r * x + (1 - r) * denoised
+
+            dt = sigmas[i + 1] - sigmas[i]
+            sigma_step_size = -dt / num_integration_points
+            sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
+            scaled_pos = noise_scaler(sigma_pos)
+
+            # Stage 2
+            s = torch.sum(1 / scaled_pos) * sigma_step_size
+            denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
+            x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
+
+            if stage_used >= 3:
+                # Stage 3
+                s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
+                denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
+                x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
+            old_denoised_d = denoised_d
+
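+        # Inject fresh noise so the marginal standard deviation matches sigmas[i + 1] after the deterministic contraction by r.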
+        if s_noise != 0 and sigmas[i + 1] > 0:
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt()
+        old_denoised = denoised
+    return x
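
For orientation, a minimal usage sketch of how a sampler with this signature is typically driven. This is not part of the commit: toy_denoiser is a hypothetical stand-in for a real wrapped diffusion model, get_sigmas_karras is the Karras schedule helper defined elsewhere in this module, and the schedule endpoints are arbitrary example values.

import torch

# Karras-style schedule ending in 0, as these samplers expect.
sigmas = get_sigmas_karras(n=25, sigma_min=0.03, sigma_max=14.6, device="cpu")
x = torch.randn(1, 4, 64, 64) * sigmas[0]  # start from pure noise at sigma_max (VE convention)

def toy_denoiser(x, sigma, **extra):
    # Stand-in for a real denoiser wrapper: takes (x, sigma) and returns a denoised estimate of x.
    return x / (1.0 + sigma.reshape(-1, 1, 1, 1) ** 2)

out = sample_er_sde(toy_denoiser, x, sigmas, s_noise=1.0, max_stage=3)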