Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

Expand All @@ -190,30 +190,36 @@ def retrieve_timesteps(
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
st = scheduler.set_timesteps
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
if not _accepts_kw(st, "timesteps"):
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
st(timesteps=timesteps, device=device, **kwargs)
return scheduler.timesteps, len(scheduler.timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
if not _accepts_kw(st, "sigmas"):
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
st(sigmas=sigmas, device=device, **kwargs)
return scheduler.timesteps, len(scheduler.timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
st(num_inference_steps, device=device, **kwargs)
return scheduler.timesteps, num_inference_steps


def _accepts_kw(func, name):
"""Fast check whether func's arguments accept parameter name."""
# most methods have .__code__.co_varnames, it's fastest
try:
return name in func.__code__.co_varnames
except AttributeError:
# fallback for edge cases
return name in inspect.signature(func).parameters


class StableDiffusionXLInpaintPipeline(
Expand Down