1+ import math
12from typing import List , Optional , Union
23
3- import math
44import numpy as np
55import torch
66
7- from ...configuration_utils import ConfigMixin , register_to_config
87from ..sigmas .beta_sigmas import BetaSigmas
98from ..sigmas .exponential_sigmas import ExponentialSigmas
109from ..sigmas .karras_sigmas import KarrasSigmas
1110
12- class FlowMatchSD3 :
13-
14- def _sigma_to_t (self , sigma ):
15- return sigma * self .num_train_timesteps
1611
17- def __call__ (self , num_inference_steps : int , num_train_timesteps : int , shift : float , use_dynamic_shifting : bool = False , ** kwargs ):
18- self .num_train_timesteps = num_train_timesteps
12+ class FlowMatchSD3 :
13+ def __call__ (
14+ self ,
15+ num_inference_steps : int ,
16+ num_train_timesteps : int ,
17+ shift : float ,
18+ use_dynamic_shifting : bool = False ,
19+ ** kwargs ,
20+ ) -> np .ndarray :
21+ """
22+ This is different to others that directly calculate `sigmas`.
23+ It needs `sigma_min` and `sigma_max` after shift
24+ https://github.com/huggingface/diffusers/blob/0ed09a17bbab784a78fb163b557b4827467b0468/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L89-L95
25+ Then we calculate `sigmas` from that `sigma_min` and `sigma_max`.
26+ https://github.com/huggingface/diffusers/blob/0ed09a17bbab784a78fb163b557b4827467b0468/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L238-L240
27+ Shifting happens again after (outside of this).
28+ https://github.com/huggingface/diffusers/blob/0ed09a17bbab784a78fb163b557b4827467b0468/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L248-L251
29+ """
1930 timesteps = np .linspace (1 , num_train_timesteps , num_train_timesteps , dtype = np .float32 )[::- 1 ].copy ()
2031 timesteps = torch .from_numpy (timesteps ).to (dtype = torch .float32 )
2132
@@ -25,18 +36,20 @@ def __call__(self, num_inference_steps: int, num_train_timesteps: int, shift: fl
2536 sigmas = shift * sigmas / (1 + (shift - 1 ) * sigmas )
2637 sigma_min = sigmas [- 1 ].item ()
2738 sigma_max = sigmas [0 ].item ()
28- timesteps = np .linspace (
29- self ._sigma_to_t (sigma_max ), self ._sigma_to_t (sigma_min ), num_inference_steps
30- )
39+ timesteps = np .linspace (sigma_max * num_train_timesteps , sigma_min * num_train_timesteps , num_inference_steps )
3140 sigmas = timesteps / num_train_timesteps
3241 return sigmas
3342
43+
3444class FlowMatchFlux :
35- def __call__ (self , num_inference_steps : int , ** kwargs ):
45+ def __call__ (self , num_inference_steps : int , ** kwargs ) -> np . ndarray :
3646 return np .linspace (1.0 , 1 / num_inference_steps , num_inference_steps )
3747
48+
3849class FlowMatchLinearQuadratic :
39- def __call__ (self , num_inference_steps : int , threshold_noise : float = 0.25 , linear_steps : Optional [int ] = None , ** kwargs ):
50+ def __call__ (
51+ self , num_inference_steps : int , threshold_noise : float = 0.25 , linear_steps : Optional [int ] = None , ** kwargs
52+ ) -> np .ndarray :
4053 if linear_steps is None :
4154 linear_steps = num_inference_steps // 2
4255 linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range (linear_steps )]
@@ -49,22 +62,33 @@ def __call__(self, num_inference_steps: int, threshold_noise: float = 0.25, line
4962 quadratic_coef * (i ** 2 ) + linear_coef * i + const for i in range (linear_steps , num_inference_steps )
5063 ]
5164 sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
52- sigma_schedule = [1.0 - x for x in sigma_schedule ]
65+ sigma_schedule = np . array ( [1.0 - x for x in sigma_schedule ]). astype ( np . float32 )
5366 return sigma_schedule
5467
68+
5569class FlowMatchHunyuanVideo :
56- def __call__ (self , num_inference_steps : int , ** kwargs ):
70+ def __call__ (self , num_inference_steps : int , ** kwargs ) -> np . ndarray :
5771 return np .linspace (1.0 , 0.0 , num_inference_steps + 1 )[:- 1 ].copy ()
5872
73+
74+ class FlowMatchSANA :
75+ def __call__ (self , num_inference_steps : int , num_train_timesteps : int , shift : float , ** kwargs ) -> np .ndarray :
76+ alphas = np .linspace (1 , 1 / num_train_timesteps , num_inference_steps + 1 )
77+ sigmas = 1.0 - alphas
78+ sigmas = np .flip (shift * sigmas / (1 + (shift - 1 ) * sigmas ))[:- 1 ].copy ()
79+ return sigmas
80+
81+
5982BASE_SCHEDULE_MAP = {
6083 "FlowMatchHunyuanVideo" : FlowMatchHunyuanVideo ,
6184 "FlowMatchLinearQuadratic" : FlowMatchLinearQuadratic ,
6285 "FlowMatchFlux" : FlowMatchFlux ,
6386 "FlowMatchSD3" : FlowMatchSD3 ,
87+ "FlowMatchSANA" : FlowMatchSANA ,
6488}
6589
66- class FlowMatchSchedule :
6790
91+ class FlowMatchSchedule :
6892 scale_model_input = False
6993
7094 base_schedules = BASE_SCHEDULE_MAP
@@ -145,7 +169,7 @@ def __call__(
145169 ):
146170 shift = shift or self .shift
147171 if self .use_dynamic_shifting and mu is None :
148- raise ValueError (" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" )
172+ raise ValueError ("You have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" )
149173
150174 if sigmas is None :
151175 sigmas = self .base_schedule (
@@ -155,9 +179,8 @@ def __call__(
155179 use_dynamic_shifting = self .use_dynamic_shifting ,
156180 )
157181 else :
182+ # NOTE: current usage is **without** `sigma_last` - different than BetaSchedule
158183 sigmas = np .array (sigmas ).astype (np .float32 )
159- num_inference_steps = len (sigmas )
160- self .num_inference_steps = num_inference_steps
161184
162185 if self .use_dynamic_shifting :
163186 sigmas = self .time_shift (mu , 1.0 , sigmas )
0 commit comments