Skip to content

Commit c4f8758

Browse files
committed
Update KDPM2AncestralDiscreteSchedulerTest
1 parent bcd1fc1 commit c4f8758

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/schedulers/test_scheduler_kdpm2_ancestral.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def test_full_loop_no_noise(self):
5959
result_sum = torch.sum(torch.abs(sample))
6060
result_mean = torch.mean(torch.abs(sample))
6161

62-
assert abs(result_sum.item() - 13849.3877) < 1e-2
63-
assert abs(result_mean.item() - 18.0331) < 5e-3
62+
assert abs(result_sum.item() - 13979.9433) < 1e-2
63+
assert abs(result_mean.item() - 18.2030) < 5e-3
6464

6565
def test_prediction_type(self):
6666
for prediction_type in ["epsilon", "v_prediction"]:
@@ -92,8 +92,8 @@ def test_full_loop_with_v_prediction(self):
9292
result_sum = torch.sum(torch.abs(sample))
9393
result_mean = torch.mean(torch.abs(sample))
9494

95-
assert abs(result_sum.item() - 328.9970) < 1e-2
96-
assert abs(result_mean.item() - 0.4284) < 1e-3
95+
assert abs(result_sum.item() - 331.8133) < 1e-2
96+
assert abs(result_mean.item() - 0.4320) < 1e-3
9797

9898
def test_full_loop_device(self):
9999
if torch_device == "mps":
@@ -119,8 +119,8 @@ def test_full_loop_device(self):
119119
result_sum = torch.sum(torch.abs(sample))
120120
result_mean = torch.mean(torch.abs(sample))
121121

122-
assert abs(result_sum.item() - 13849.3818) < 1e-1
123-
assert abs(result_mean.item() - 18.0331) < 1e-3
122+
assert abs(result_sum.item() - 13979.9433) < 1e-1
123+
assert abs(result_mean.item() - 18.2030) < 1e-3
124124

125125
def test_full_loop_with_noise(self):
126126
if torch_device == "mps":
@@ -154,5 +154,5 @@ def test_full_loop_with_noise(self):
154154
result_sum = torch.sum(torch.abs(sample))
155155
result_mean = torch.mean(torch.abs(sample))
156156

157-
assert abs(result_sum.item() - 93087.0312) < 1e-2, f" expected result sum 93087.0312, but get {result_sum}"
158-
assert abs(result_mean.item() - 121.2071) < 5e-3, f" expected result mean 121.2071, but get {result_mean}"
157+
assert abs(result_sum.item() - 93087.3437) < 1e-2, f" expected result sum 93087.3437, but get {result_sum}"
158+
assert abs(result_mean.item() - 121.2074) < 5e-3, f" expected result mean 121.2074, but get {result_mean}"

0 commit comments

Comments
 (0)