Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

Expand All @@ -118,29 +118,40 @@ def retrieve_timesteps(
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
# Use cached reflection of set_timesteps signature for speed.
params = _get_set_timesteps_params(scheduler)
if "timesteps" not in params:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
timesteps_out = scheduler.timesteps
num_inference_steps_out = len(timesteps_out)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
# Use cached reflection of set_timesteps signature for speed.
params = _get_set_timesteps_params(scheduler)
if "sigmas" not in params:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
timesteps_out = scheduler.timesteps
num_inference_steps_out = len(timesteps_out)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
timesteps_out = scheduler.timesteps
num_inference_steps_out = num_inference_steps
return timesteps_out, num_inference_steps_out

def _get_set_timesteps_params(scheduler):
cls = type(scheduler)
if cls not in _scheduler_set_timesteps_param_cache:
# Only do the expensive inspect once per scheduler class
params = set(inspect.signature(cls.set_timesteps).parameters)
_scheduler_set_timesteps_param_cache[cls] = params
return _scheduler_set_timesteps_param_cache[cls]


class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
Expand Down Expand Up @@ -1010,3 +1021,5 @@ def __call__(
return (image,)

return HiDreamImagePipelineOutput(images=image)

_scheduler_set_timesteps_param_cache = {}