-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Expand file tree
/
Copy pathtest_scheduler_flow_match_euler_discrete.py
More file actions
46 lines (34 loc) · 1.82 KB
/
test_scheduler_flow_match_euler_discrete.py
File metadata and controls
46 lines (34 loc) · 1.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
class TestFlowMatchEulerDiscreteSchedulerSigmaConsistency:
    """Regression tests for https://github.com/huggingface/diffusers/issues/13243

    These tests check that ``set_timesteps`` does not re-apply the ``shift``
    transform to sigmas that ``__init__`` already shifted.
    """

    @staticmethod
    def _assert_full_schedule_matches_init(shift, num_train_timesteps=1000):
        """Assert ``set_timesteps(num_train_timesteps)`` reproduces the ``__init__`` sigmas.

        Args:
            shift: The ``shift`` value passed to the scheduler constructor.
            num_train_timesteps: Training timesteps. The same value is passed to
                ``set_timesteps`` so the two schedules should line up one-to-one.

        ``set_timesteps`` appends a terminal zero, so only the first
        ``num_train_timesteps`` values are compared.
        """
        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_train_timesteps, shift=shift)
        # Clone so the later set_timesteps call cannot alias/overwrite the reference.
        sigmas_init = scheduler.sigmas.clone()
        scheduler.set_timesteps(num_train_timesteps)
        sigmas_after = scheduler.sigmas[:-1]  # drop appended terminal zero
        torch.testing.assert_close(
            sigmas_init,
            sigmas_after,
            atol=1e-6,
            rtol=1e-5,
            msg=f"Sigma mismatch after set_timesteps with shift={shift}",
        )

    def test_set_timesteps_no_double_shift(self):
        """Calling set_timesteps(num_train_timesteps) should reproduce the same sigmas as __init__.

        set_timesteps appends a terminal zero, so we compare only the first N values.
        """
        self._assert_full_schedule_matches_init(shift=3.0)

    def test_set_timesteps_no_double_shift_various_shifts(self):
        """The no-double-shift property holds across several shift values, including the identity shift 1.0."""
        for shift in [1.0, 2.0, 3.0, 5.0]:
            self._assert_full_schedule_matches_init(shift=shift)

    def test_set_timesteps_fewer_steps(self):
        """set_timesteps with fewer steps should produce sigmas within the original range.

        The reduced schedule should have one sigma per inference step plus the
        appended terminal zero. It should be non-increasing, and its non-terminal
        values should stay inside the range of sigmas computed by ``__init__``.
        """
        num_inference_steps = 50
        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
        sigmas_init = scheduler.sigmas.clone()
        scheduler.set_timesteps(num_inference_steps)
        sigmas = scheduler.sigmas

        # One sigma per inference step, plus the terminal zero set_timesteps appends.
        assert sigmas.shape[0] == num_inference_steps + 1
        assert sigmas[-1].item() == 0.0

        # The schedule runs from high noise to low noise.
        assert torch.all(sigmas[:-1] >= sigmas[1:])

        # All sigmas should fall within [0, 1]
        assert sigmas.min() >= 0.0
        assert sigmas.max() <= 1.0 + 1e-6

        # Non-terminal sigmas stay inside the range spanned by the __init__ sigmas;
        # the small tolerance absorbs float32 rounding at the endpoints.
        tol = 1e-6
        non_terminal = sigmas[:-1]
        assert non_terminal.min() >= sigmas_init.min() - tol
        assert non_terminal.max() <= sigmas_init.max() + tol