40 changes: 24 additions & 16 deletions src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import contextvars, math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

@@ -82,8 +82,8 @@ def __init__(

self.timesteps = sigmas * num_train_timesteps

self._step_index = None
self._begin_index = None
self._step_index = contextvars.ContextVar("step_index")
self._begin_index = contextvars.ContextVar("begin_index")

self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
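Background on the mechanism (not part of the diff): a `contextvars.ContextVar` holds one value per execution context, so a single scheduler instance can keep independent counters across threads or asyncio tasks. A minimal standalone sketch of the semantics the new fields rely on:

import contextvars

step_index = contextvars.ContextVar("step_index")  # no default, i.e. "unset"

# Unset: get() needs a fallback, and the var is absent from the current context.
assert step_index.get(None) is None
assert step_index not in contextvars.copy_context()

# Set: get() returns the value and the membership test flips.
step_index.set(3)
assert step_index.get() == 3
assert step_index in contextvars.copy_context()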
@@ -94,14 +94,22 @@ def step_index(self):
"""
The index counter for the current timestep. It increases by 1 after each scheduler step.
"""
return self._step_index
return self._step_index.get(None)  # fall back to None so an unset var behaves like the old attribute

@step_index.setter
def step_index(self, step_index):
self._step_index.set(step_index)

@property
def begin_index(self):
"""
The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
"""
return self._begin_index
return self._begin_index.get(None)  # fall back to None so an unset var behaves like the old attribute

@begin_index.setter
def begin_index(self, begin_index):
self._begin_index.set(begin_index)

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
@@ -112,7 +120,7 @@ def set_begin_index(self, begin_index: int = 0):
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
self.begin_index = begin_index

def scale_noise(
self,
@@ -144,10 +152,10 @@ def scale_noise(
schedule_timesteps = self.timesteps.to(sample.device)
timestep = timestep.to(sample.device)

# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
# self._begin_index is an unset ContextVar when the scheduler is used for training, or the pipeline does not implement set_begin_index
if self._begin_index not in contextvars.copy_context():
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
elif self.step_index is not None:
elif self._step_index in contextvars.copy_context():
# add_noise is called after the first denoising step (for inpainting)
step_indices = [self.step_index] * timestep.shape[0]
else:
@@ -207,8 +215,8 @@ def set_timesteps(
self.timesteps = timesteps.to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

self._step_index = None
self._begin_index = None
self._step_index = contextvars.ContextVar("step_index")
self._begin_index = contextvars.ContextVar("begin_index")

def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
@@ -225,12 +233,12 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()

def _init_step_index(self, timestep):
if self.begin_index is None:
if self._begin_index not in contextvars.copy_context():
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
self.step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
self.step_index = self.begin_index
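One subtlety these membership tests encode (illustrative snippet, not from the diff): a ContextVar that has been explicitly set to None still counts as set, so `var in contextvars.copy_context()` is stricter than the old `is None` comparison:

import contextvars

v = contextvars.ContextVar("v")
assert v not in contextvars.copy_context()  # unset: old and new tests agree
v.set(None)
assert v in contextvars.copy_context()      # new test: "set"
assert v.get() is None                      # old test would still say "unset"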

def step(
self,
@@ -285,7 +293,7 @@ def step(
),
)

if self.step_index is None:
if self._step_index not in contextvars.copy_context():
self._init_step_index(timestep)

# Upcast to avoid precision issues when computing prev_sample
@@ -300,7 +308,7 @@ def step(
prev_sample = prev_sample.to(model_output.dtype)

# upon completion increase step index by one
self._step_index += 1
self.step_index += 1

if not return_dict:
return (prev_sample,)
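Putting the pieces together (hypothetical SchedulerLike stand-in, not diffusers code): each Python thread starts from a fresh context, so two threads sharing one scheduler object keep independent step counters, which appears to be the behavior this patch targets:

import contextvars
import threading

class SchedulerLike:
    """Hypothetical stand-in mirroring the patched counter handling."""

    def __init__(self):
        self._step_index = contextvars.ContextVar("step_index")

    @property
    def step_index(self):
        return self._step_index.get(None)

    @step_index.setter
    def step_index(self, value):
        self._step_index.set(value)

    def step(self):
        if self._step_index not in contextvars.copy_context():
            self.step_index = 0  # analogue of _init_step_index
        self.step_index += 1

scheduler = SchedulerLike()
results = {}

def run(n_steps, key):
    for _ in range(n_steps):
        scheduler.step()
    results[key] = scheduler.step_index  # each thread sees only its own counter

t1 = threading.Thread(target=run, args=(3, "a"))
t2 = threading.Thread(target=run, args=(5, "b"))
t1.start()
t2.start()
t1.join()
t2.join()
print(results)  # {'a': 3, 'b': 5}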