Skip to content

Commit a4453ed

Browse files
committed
deis
1 parent 647658b commit a4453ed

File tree

6 files changed

+160
-366
lines changed

6 files changed

+160
-366
lines changed

src/diffusers/configuration_utils.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,9 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
252252
use_karras_sigmas = config.pop("use_karras_sigmas", None)
253253
use_exponential_sigmas = config.pop("use_exponential_sigmas", None)
254254
use_beta_sigmas = config.pop("use_beta_sigmas", None)
255+
use_flow_sigmas = config.pop("use_flow_sigmas", None)
255256
prediction_type = config.pop("prediction_type", None)
257+
schedule_config = {}
256258
if use_karras_sigmas:
257259
sigma_schedule_config = {"class_name": "KarrasSigmas"}
258260
elif use_exponential_sigmas:
@@ -263,16 +265,32 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
263265
sigma_schedule_config = {}
264266
if "beta_schedule" in config:
265267
config.update({"class_name": "BetaSchedule"})
266-
elif "shift" in config:
268+
from .schedulers.schedules.beta_schedule import BetaSchedule
269+
270+
expected_kwargs = list(inspect.signature(BetaSchedule.__init__).parameters)[1:-1]
271+
for expected_kwarg in expected_kwargs:
272+
if expected_kwarg in config:
273+
schedule_config[expected_kwarg] = config.pop(expected_kwarg)
274+
elif "shift" in config or use_flow_sigmas:
267275
config.update({"class_name": "FlowMatchSchedule"})
276+
from .schedulers.schedules.flow_schedule import FlowMatchSchedule
277+
278+
expected_kwargs = list(inspect.signature(FlowMatchSchedule.__init__).parameters)[1:-1]
279+
for expected_kwarg in expected_kwargs:
280+
if expected_kwarg in config:
281+
schedule_config[expected_kwarg] = config.pop(expected_kwarg)
282+
if prediction_type == "flow_prediction":
283+
prediction_type = "epsilon"
268284
if prediction_type:
269285
config.update({"prediction_type": prediction_type})
270-
config = {
271-
"_class_name": _class_name,
272-
"_diffusers_version": _diffusers_version,
273-
"schedule_config": config,
274-
"sigma_schedule_config": sigma_schedule_config,
275-
}
286+
config.update(
287+
{
288+
"_class_name": _class_name,
289+
"_diffusers_version": _diffusers_version,
290+
"schedule_config": schedule_config,
291+
"sigma_schedule_config": sigma_schedule_config,
292+
}
293+
)
276294

277295
init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
278296

src/diffusers/schedulers/schedules/beta_schedule.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ def __init__(
133133
# FP16 smallest positive subnormal works well here
134134
self.alphas_cumprod[-1] = 2**-24
135135

136+
self.alpha_t = torch.sqrt(self.alphas_cumprod)
137+
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
138+
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
139+
136140
self.num_train_timesteps = num_train_timesteps
137141
self.beta_start = beta_start
138142
self.beta_end = beta_end

0 commit comments

Comments
 (0)