Skip to content

Commit c5e62ae

Browse files
authored
Merge branch 'main' into quantization-config
2 parents 81bb48a + 9d06161 commit c5e62ae

File tree

5 files changed

+23
-26
lines changed

5 files changed

+23
-26
lines changed

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -333,14 +333,13 @@ def step(
333333

334334
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
335335

336-
noise = randn_tensor(
337-
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
338-
)
339-
340-
eps = noise * s_noise
341336
sigma_hat = sigma * (gamma + 1)
342337

343338
if gamma > 0:
339+
noise = randn_tensor(
340+
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
341+
)
342+
eps = noise * s_noise
344343
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
345344

346345
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -638,14 +638,13 @@ def step(
638638

639639
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
640640

641-
noise = randn_tensor(
642-
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
643-
)
644-
645-
eps = noise * s_noise
646641
sigma_hat = sigma * (gamma + 1)
647642

648643
if gamma > 0:
644+
noise = randn_tensor(
645+
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
646+
)
647+
eps = noise * s_noise
649648
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
650649

651650
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise

src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,14 +266,13 @@ def step(
266266

267267
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
268268

269-
noise = randn_tensor(
270-
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
271-
)
272-
273-
eps = noise * s_noise
274269
sigma_hat = sigma * (gamma + 1)
275270

276271
if gamma > 0:
272+
noise = randn_tensor(
273+
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
274+
)
275+
eps = noise * s_noise
277276
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
278277

279278
if self.state_in_first_order:

src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -524,9 +524,6 @@ def step(
524524
gamma = 0
525525
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
526526

527-
device = model_output.device
528-
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
529-
530527
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
531528
if self.config.prediction_type == "epsilon":
532529
sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
@@ -564,6 +561,9 @@ def step(
564561
self.sample = None
565562

566563
prev_sample = sample + derivative * dt
564+
noise = randn_tensor(
565+
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
566+
)
567567
prev_sample = prev_sample + noise * sigma_up
568568

569569
# upon completion increase step index by one

tests/schedulers/test_scheduler_kdpm2_ancestral.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def test_full_loop_no_noise(self):
5959
result_sum = torch.sum(torch.abs(sample))
6060
result_mean = torch.mean(torch.abs(sample))
6161

62-
assert abs(result_sum.item() - 13849.3877) < 1e-2
63-
assert abs(result_mean.item() - 18.0331) < 5e-3
62+
assert abs(result_sum.item() - 13979.9433) < 1e-2
63+
assert abs(result_mean.item() - 18.2030) < 5e-3
6464

6565
def test_prediction_type(self):
6666
for prediction_type in ["epsilon", "v_prediction"]:
@@ -92,8 +92,8 @@ def test_full_loop_with_v_prediction(self):
9292
result_sum = torch.sum(torch.abs(sample))
9393
result_mean = torch.mean(torch.abs(sample))
9494

95-
assert abs(result_sum.item() - 328.9970) < 1e-2
96-
assert abs(result_mean.item() - 0.4284) < 1e-3
95+
assert abs(result_sum.item() - 331.8133) < 1e-2
96+
assert abs(result_mean.item() - 0.4320) < 1e-3
9797

9898
def test_full_loop_device(self):
9999
if torch_device == "mps":
@@ -119,8 +119,8 @@ def test_full_loop_device(self):
119119
result_sum = torch.sum(torch.abs(sample))
120120
result_mean = torch.mean(torch.abs(sample))
121121

122-
assert abs(result_sum.item() - 13849.3818) < 1e-1
123-
assert abs(result_mean.item() - 18.0331) < 1e-3
122+
assert abs(result_sum.item() - 13979.9433) < 1e-1
123+
assert abs(result_mean.item() - 18.2030) < 1e-3
124124

125125
def test_full_loop_with_noise(self):
126126
if torch_device == "mps":
@@ -154,5 +154,5 @@ def test_full_loop_with_noise(self):
154154
result_sum = torch.sum(torch.abs(sample))
155155
result_mean = torch.mean(torch.abs(sample))
156156

157-
assert abs(result_sum.item() - 93087.0312) < 1e-2, f" expected result sum 93087.0312, but get {result_sum}"
158-
assert abs(result_mean.item() - 121.2071) < 5e-3, f" expected result mean 121.2071, but get {result_mean}"
157+
assert abs(result_sum.item() - 93087.3437) < 1e-2, f" expected result sum 93087.3437, but get {result_sum}"
158+
assert abs(result_mean.item() - 121.2074) < 5e-3, f" expected result mean 121.2074, but get {result_mean}"

0 commit comments

Comments
 (0)