diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 937cae2e47f5..6fadf5c8180e 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math +import contextvars, math from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -82,8 +82,8 @@ def __init__( self.timesteps = sigmas * num_train_timesteps - self._step_index = None - self._begin_index = None + self._step_index = contextvars.ContextVar("step_index") + self._begin_index = contextvars.ContextVar("begin_index") self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigma_min = self.sigmas[-1].item() @@ -94,14 +94,22 @@ def step_index(self): """ The index counter for current timestep. It will increase 1 after each scheduler step. """ - return self._step_index + return self._step_index.get() + + @step_index.setter + def step_index(self, step_index): + self._step_index.set(step_index) @property def begin_index(self): """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. """ - return self._begin_index + return self._begin_index.get() + + @begin_index.setter + def begin_index(self, begin_index): + self._begin_index.set(begin_index) # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index def set_begin_index(self, begin_index: int = 0): @@ -112,7 +120,7 @@ def set_begin_index(self, begin_index: int = 0): begin_index (`int`): The begin index for the scheduler. """ - self._begin_index = begin_index + self.begin_index = begin_index def scale_noise( self, @@ -144,10 +152,10 @@ def scale_noise( schedule_timesteps = self.timesteps.to(sample.device) timestep = timestep.to(sample.device) - # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index - if self.begin_index is None: + # self._begin_index is an unset contextvar when scheduler is used for training, or pipeline does not implement set_begin_index + if self._begin_index not in contextvars.copy_context(): step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] - elif self.step_index is not None: + elif self._step_index in contextvars.copy_context(): # add_noise is called after first denoising step (for inpainting) step_indices = [self.step_index] * timestep.shape[0] else: @@ -207,8 +215,8 @@ def set_timesteps( self.timesteps = timesteps.to(device=device) self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) - self._step_index = None - self._begin_index = None + self._step_index = contextvars.ContextVar("step_index") + self._begin_index = contextvars.ContextVar("begin_index") def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: @@ -225,12 +233,12 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return indices[pos].item() def _init_step_index(self, timestep): - if self.begin_index is None: + if self._begin_index not in contextvars.copy_context(): if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) - self._step_index = self.index_for_timestep(timestep) + self.step_index = self.index_for_timestep(timestep) else: - self._step_index = self._begin_index + self.step_index = self.begin_index def step( self, @@ -285,7 +293,7 @@ def step( ), ) - if self.step_index is None: + if self._step_index not in contextvars.copy_context(): self._init_step_index(timestep) # Upcast to avoid precision issues when computing prev_sample @@ -300,7 +308,7 @@ def step( prev_sample = prev_sample.to(model_output.dtype) # upon completion increase step index by one - self._step_index += 1 + self.step_index += 1 if not return_dict: return (prev_sample,)