Skip to content

Commit 90c240b

Browse files
committed
apply review sugestions
1 parent 4c37ef0 commit 90c240b

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
245245

246246
def set_timesteps(
247247
self,
248-
num_inference_steps: int = None,
248+
num_inference_steps: Optional[int] = None,
249249
device: Union[str, torch.device] = None,
250250
sigmas: Optional[List[float]] = None,
251251
mu: Optional[float] = None,
@@ -255,7 +255,7 @@ def set_timesteps(
255255
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
256256
257257
Args:
258-
num_inference_steps (`int`):
258+
num_inference_steps (`int`, *optional*):
259259
The number of diffusion steps used when generating samples with a pre-trained model.
260260
device (`str` or `torch.device`, *optional*):
261261
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
@@ -270,22 +270,40 @@ def set_timesteps(
270270
automatically.
271271
"""
272272
if self.config.use_dynamic_shifting and mu is None:
273-
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
273+
raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
274+
275+
if sigmas is not None and timesteps is not None:
276+
if len(sigmas) != len(timesteps):
277+
raise ValueError("`sigmas` and `timesteps` should have the same length")
278+
279+
if num_inference_steps is not None:
280+
if (sigmas is not None and len(sigmas) != num_inference_steps) or (
281+
timesteps is not None and len(timesteps) != num_inference_steps
282+
):
283+
raise ValueError(
284+
"`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
285+
)
286+
else:
287+
num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
274288

275289
self.num_inference_steps = num_inference_steps
276290

277291
# 1. Prepare default sigmas
278292
is_timesteps_provided = timesteps is not None
293+
294+
if is_timesteps_provided:
295+
timesteps = np.array(timesteps).astype(np.float32)
296+
279297
if sigmas is None:
280298
if timesteps is None:
281299
timesteps = np.linspace(
282300
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
283301
)
284-
else:
285-
timesteps = np.array(timesteps).astype(np.float32)
286302
sigmas = timesteps / self.config.num_train_timesteps
303+
num_inference_steps = len(sigmas)
287304
else:
288305
sigmas = np.array(sigmas).astype(np.float32)
306+
num_inference_steps = len(sigmas)
289307

290308
# 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
291309
# "exponential" or "linear" type is applied

0 commit comments

Comments
 (0)