diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 6ddd9ac23009..c7474d56c708 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -99,10 +99,19 @@ def __init__( self._step_index = None self._begin_index = None + self._shift = shift + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigma_min = self.sigmas[-1].item() self.sigma_max = self.sigmas[0].item() + @property + def shift(self): + """ + The value used for shifting. + """ + return self._shift + @property def step_index(self): """ @@ -128,6 +137,9 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index + def set_shift(self, shift: float): + self._shift = shift + def scale_noise( self, sample: torch.FloatTensor, @@ -236,7 +248,7 @@ def set_timesteps( if self.config.use_dynamic_shifting: sigmas = self.time_shift(mu, 1.0, sigmas) else: - sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) if self.config.shift_terminal: sigmas = self.stretch_shift_to_terminal(sigmas)