@@ -853,6 +853,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
853853 return x
854854
855855
856+ @torch .no_grad ()
857+ def sample_dpmpp_2m_sde_heun (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None , solver_type = 'heun' ):
858+ return sample_dpmpp_2m_sde (model , x , sigmas , extra_args = extra_args , callback = callback , disable = disable , eta = eta , s_noise = s_noise , noise_sampler = noise_sampler , solver_type = solver_type )
859+
860+
856861@torch .no_grad ()
857862def sample_dpmpp_3m_sde (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None ):
858863 """DPM-Solver++(3M) SDE."""
@@ -925,6 +930,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
925930 return sample_dpmpp_3m_sde (model , x , sigmas , extra_args = extra_args , callback = callback , disable = disable , eta = eta , s_noise = s_noise , noise_sampler = noise_sampler )
926931
927932
933+ @torch .no_grad ()
934+ def sample_dpmpp_2m_sde_heun_gpu (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None , solver_type = 'heun' ):
935+ if len (sigmas ) <= 1 :
936+ return x
937+ extra_args = {} if extra_args is None else extra_args
938+ sigma_min , sigma_max = sigmas [sigmas > 0 ].min (), sigmas .max ()
939+ noise_sampler = BrownianTreeNoiseSampler (x , sigma_min , sigma_max , seed = extra_args .get ("seed" , None ), cpu = False ) if noise_sampler is None else noise_sampler
940+ return sample_dpmpp_2m_sde_heun (model , x , sigmas , extra_args = extra_args , callback = callback , disable = disable , eta = eta , s_noise = s_noise , noise_sampler = noise_sampler , solver_type = solver_type )
941+
942+
928943@torch .no_grad ()
929944def sample_dpmpp_2m_sde_gpu (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None , solver_type = 'midpoint' ):
930945 if len (sigmas ) <= 1 :
0 commit comments