Skip to content

Commit 8fffb40

Browse files
committed
test: practical test of implementation and update sigma rel mapping to beta
1 parent f6c4f80 commit 8fffb40

File tree

4 files changed

+1278
-4
lines changed

4 files changed

+1278
-4
lines changed

README.md

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@ Choose your EMA decay rate after training. No need to decide upfront.
55
The library uses `sigma_rel` (relative standard deviation) to parameterize EMA decay rates, which relates to the classical EMA decay rate `beta` as follows:
66

77
```python
8-
beta = 0.9999 # Very slow decay -> sigma_rel ≈ 0.01
9-
beta = 0.9990 # Slow decay -> sigma_rel ≈ 0.03
10-
beta = 0.9900 # Medium decay -> sigma_rel ≈ 0.10
11-
beta = 0.9000 # Fast decay -> sigma_rel ≈ 0.27
8+
beta = 0.9000 # Fast decay -> sigma_rel ≈ 0.001
9+
beta = 0.9055 # Medium decay -> sigma_rel ≈ 0.01
10+
beta = 0.9680 # Medium decay -> sigma_rel ≈ 0.03
11+
beta = 0.9808 # Medium decay -> sigma_rel ≈ 0.05
12+
beta = 0.9911 # Slow decay -> sigma_rel ≈ 0.10
13+
beta = 0.9944 # Slow decay -> sigma_rel ≈ 0.15
14+
beta = 0.9962 # Slow decay -> sigma_rel ≈ 0.20
15+
beta = 0.9979 # Slow decay -> sigma_rel ≈ 0.27
16+
beta = 0.9999 # Very slow decay -> sigma_rel ≈ 0.40
1217
```
1318

1419
This library was adapted from [ema-pytorch](https://github.com/lucidrains/ema-pytorch) by lucidrains.
@@ -136,3 +141,15 @@ posthoc_ema = PostHocEMA.from_model(
136141
volume = {abs/2402.09240}
137142
}
138143
```
144+
145+
```python
146+
beta = 0.9000 # Fast decay -> sigma_rel ≈ 0.001
147+
beta = 0.9055 # Medium decay -> sigma_rel ≈ 0.01
148+
beta = 0.9680 # Medium decay -> sigma_rel ≈ 0.03
149+
beta = 0.9808 # Medium decay -> sigma_rel ≈ 0.05
150+
beta = 0.9911 # Slow decay -> sigma_rel ≈ 0.10
151+
beta = 0.9944 # Slow decay -> sigma_rel ≈ 0.15
152+
beta = 0.9962 # Slow decay -> sigma_rel ≈ 0.20
153+
beta = 0.9979 # Slow decay -> sigma_rel ≈ 0.27
154+
beta = 0.9999 # Very slow decay -> sigma_rel ≈ 0.40
155+
```

tests/test_beta_sigma_rel.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Test to verify the relationship between beta and sigma_rel values."""
2+
3+
import pytest
4+
from posthoc_ema.utils import beta_to_sigma_rel, sigma_rel_to_beta
5+
6+
7+
def test_beta_to_sigma_rel_relationship():
    """Check that ``beta_to_sigma_rel`` lands within 0.02 of each expected sigma_rel.

    Every (beta, expected sigma_rel) pair is converted, printed for manual
    inspection, and then compared to its expectation with an absolute
    tolerance.
    """
    # NOTE(review): these pairs match the beta/sigma_rel table the README
    # carried before this change; the updated README table lists different
    # pairs (e.g. beta = 0.9999 -> sigma_rel of about 0.40). Confirm which
    # mapping `beta_to_sigma_rel` is expected to implement.
    expectations = {
        0.9999: 0.01,  # Very slow decay
        0.9990: 0.03,  # Slow decay
        0.9900: 0.10,  # Medium decay
        0.9000: 0.27,  # Fast decay
    }
    tolerance = 0.02

    for beta, expected_sigma_rel in expectations.items():
        calculated_sigma_rel = beta_to_sigma_rel(beta)
        print(
            f"Beta: {beta:.4f} -> Calculated sigma_rel: {calculated_sigma_rel:.4f} (Expected: {expected_sigma_rel:.4f})"
        )

        # Absolute (not relative) deviation, so small sigma_rel targets are
        # held to the same 0.02 band as large ones.
        deviation = abs(calculated_sigma_rel - expected_sigma_rel)
        assert (
            deviation < tolerance
        ), f"Beta {beta} should give sigma_rel close to {expected_sigma_rel}, got {calculated_sigma_rel}"
29+
30+
31+
def test_sigma_rel_to_beta_relationship():
    """Print the beta, half-life and effective window size for several sigma_rel values.

    This is a diagnostic test: it converts each sigma_rel with
    ``sigma_rel_to_beta`` and prints derived quantities, but makes no
    assertions about the resulting values.
    """
    import math

    sigma_rels = [0.01, 0.03, 0.10, 0.27]

    for sigma_rel in sigma_rels:
        beta = sigma_rel_to_beta(sigma_rel)
        print(f"sigma_rel: {sigma_rel:.4f} -> beta: {beta:.6f}")

        # Half-life: the number of steps n after which a weight has decayed
        # to one half, i.e. beta ** n == 0.5  =>  n = log(0.5) / log(beta).
        # Both logs are negative for 0 < beta < 1, so the result is positive.
        # The quantity is undefined outside that open interval.
        if 0 < beta < 1:
            half_life = math.log(0.5) / math.log(beta)
            print(f" Half-life: {half_life:.1f} steps")

        # Effective window size is approximately 1 / (1 - beta); skip
        # beta == 1, where the expression would divide by zero.
        if beta != 1:
            window_size = 1 / (1 - beta)
            print(f" Effective window size: {window_size:.1f} steps")
51+
52+
53+
def test_sigma_rel_ordering():
    """Assert that ``sigma_rel_to_beta`` is strictly decreasing over the sampled sigma_rels.

    The sigma_rel samples are listed in increasing order; each converted beta
    must be strictly smaller than the one before it.
    """
    # NOTE(review): the README table in this repository lists beta
    # *increasing* with sigma_rel (0.9055 at 0.01 up to 0.9979 at 0.27),
    # which is the opposite direction to what is asserted here. Confirm which
    # one reflects `sigma_rel_to_beta`.
    sigma_rels = [0.01, 0.03, 0.10, 0.27]
    betas = list(map(sigma_rel_to_beta, sigma_rels))
    pairs = list(zip(sigma_rels, betas))

    print("\nRelationship between sigma_rel and beta:")
    for sr, beta in pairs:
        print(f"sigma_rel: {sr:.2f} -> beta: {beta:.6f}")

    # Walk adjacent (previous, current) pairs and require a strict decrease.
    for (prev_sr, prev_beta), (sr, beta) in zip(pairs, pairs[1:]):
        assert (
            beta < prev_beta
        ), f"Beta for sigma_rel={sr} should be less than beta for sigma_rel={prev_sr}"

    print(
        "\nThis confirms that smaller sigma_rel values correspond to higher beta values (slower decay)."
    )
    print(
        "In other words, as sigma_rel increases, the EMA adapts more quickly to recent values."
    )

tests/test_practical.py

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
"""
2+
Test to demonstrate how different sigma_rel values affect a model with a single parameter.
3+
4+
This file contains practical tests that show the behavior of PostHocEMA with different
5+
sigma_rel values. The tests demonstrate that:
6+
7+
1. Different sigma_rel values produce different EMA results
8+
2. In PostHocEMA, smaller sigma_rel values (e.g., 0.05) result in EMA values that are
9+
closer to recent model values, while larger sigma_rel values (e.g., 0.27) result in
10+
EMA values that are closer to older model values.
11+
3. The relationship between sigma_rel and EMA behavior in PostHocEMA is:
12+
- sigma_rel ≈ 0.05: More weight to recent values
13+
- sigma_rel ≈ 0.15: Balanced weighting
14+
- sigma_rel ≈ 0.27: More weight to older values
15+
16+
Note: This behavior is consistent with the relationship between sigma_rel and beta
17+
(EMA decay rate) listed in the README, where beta grows with sigma_rel:
18+
- Small sigma_rel (e.g., 0.01) corresponds to lower beta (e.g., 0.9055) = faster decay
19+
- Large sigma_rel (e.g., 0.27) corresponds to higher beta (e.g., 0.9979) = slower decay
20+
21+
Keep in mind that PostHocEMA synthesizes weights based on the entire history
22+
of checkpoints, rather than just applying a simple EMA formula.
23+
"""
24+
25+
import torch
26+
import pytest
27+
import shutil
28+
from pathlib import Path
29+
import matplotlib.pyplot as plt
30+
import numpy as np
31+
32+
from posthoc_ema import PostHocEMA
33+
34+
35+
class SingleParamModel(torch.nn.Module):
    """Minimal module holding one learnable scalar, used to observe EMA behavior.

    Args:
        initial_value: Starting value stored in the one-element ``param``.
    """

    def __init__(self, initial_value: float = 0.0):
        super().__init__()
        # A one-element float32 tensor wrapped as a Parameter so that it
        # is registered on the module under the name "param".
        starting_tensor = torch.tensor([initial_value], dtype=torch.float32)
        self.param = torch.nn.Parameter(starting_tensor)

    def forward(self, x):
        """Return ``x`` scaled by the single parameter."""
        scaled = x * self.param
        return scaled
46+
47+
48+
@pytest.fixture(autouse=True)
def cleanup_checkpoints():
    """Remove the test checkpoint directory both before and after every test."""
    checkpoint_dirs = ["./test-single-param-ema"]

    def _purge():
        # Delete each checkpoint directory that currently exists on disk.
        for directory in checkpoint_dirs:
            if Path(directory).exists():
                shutil.rmtree(directory)

    _purge()  # start from a clean slate

    yield

    _purge()  # leave nothing behind for the next test
62+
63+
64+
def test_single_parameter_ema_behavior():
    """
    Test that demonstrates how different sigma_rel values affect a model with a single parameter.

    This test:
    1. Creates a model with a single parameter initialized to 0
    2. Linearly moves the parameter from 0 to 1 over 5000 steps, updating the
       PostHocEMA instance after every step
    3. Checks that different sigma_rel values produce different EMA values
    4. Verifies that smaller sigma_rel values yield EMA values closer to the
       most recent model value (closer to 1 in this test), and that every
       EMA value stays below 1
    """
    # Single source of truth for the sigma_rel values: used both to build the
    # EMA instance and to query it afterwards, so the two cannot drift apart.
    sigma_rels = (0.05, 0.15, 0.27)

    model = SingleParamModel(initial_value=0.0)

    posthoc_ema = PostHocEMA.from_model(
        model,
        "test-single-param-ema",
        checkpoint_every=100,  # Save checkpoints more frequently
        sigma_rels=sigma_rels,
        update_after_step=0,  # Start immediately
    )

    # Number of steps to update from 0 to 1
    num_steps = 5000

    for step in range(num_steps):
        # Linear interpolation: 0 at the first step, exactly 1 at the last.
        target_value = step / (num_steps - 1)

        with torch.no_grad():
            model.param.copy_(torch.tensor([target_value], dtype=torch.float32))

        posthoc_ema.update_(model)

    # Print the final model parameter value (should be close to 1)
    print(f"\nFinal model parameter value: {model.param.item()}")

    # Synthesize the EMA value of the parameter for each sigma_rel.
    ema_values = {}
    for sigma_rel in sigma_rels:
        with posthoc_ema.state_dict(sigma_rel=sigma_rel) as state_dict:
            ema_values[sigma_rel] = state_dict["param"].item()
        print(f"EMA value with sigma_rel={sigma_rel}: {ema_values[sigma_rel]}")

    # Verify that different sigma_rel values produce different results
    assert (
        ema_values[0.05] != ema_values[0.15]
    ), "Different sigma_rel values should produce different results"
    assert (
        ema_values[0.15] != ema_values[0.27]
    ), "Different sigma_rel values should produce different results"

    # The parameter increases from 0 to 1, so "closer to recent values" means
    # "larger"; smaller sigma_rel is expected to give the larger EMA value.
    assert (
        ema_values[0.05] > ema_values[0.15] > ema_values[0.27]
    ), "In PostHocEMA, smaller sigma_rel values should result in EMA values closer to recent values"

    # An EMA is a weighted average of past values, which were all < 1 apart
    # from the very last one, so every EMA value must stay below 1.
    for sigma_rel, value in ema_values.items():
        assert (
            value < 1.0
        ), f"EMA value with sigma_rel={sigma_rel} should be less than 1.0"

    # The gap to the final value shows how strongly each sigma_rel setting
    # weights older versus newer values.
    for sigma_rel, value in ema_values.items():
        print(f"Difference from final value (sigma_rel={sigma_rel}): {1.0 - value}")
147+
148+
149+
def test_single_parameter_ema_visualization():
    """
    Test that records (and optionally plots) how sigma_rel affects a single-parameter model.

    This test:
    1. Creates a model with a single parameter initialized to 0
    2. Linearly moves the parameter from 0 to 1 over 5000 steps
    3. Records EMA values every 500 steps (and at the last step) for each
       sigma_rel, skipping steps before ``checkpoint_every``
    4. Asserts that, at the final recorded step, smaller sigma_rel values give
       EMA values closer to 1

    Plotting is opt-in: set the environment variable
    ``POSTHOC_EMA_SAVE_PLOTS`` to a non-empty value to write
    ``ema_parameter_comparison.png``. It is off by default so automated runs
    produce no files.
    """
    import os

    # Single source of truth for the sigma_rel values and checkpoint cadence.
    sigma_rels = (0.05, 0.15, 0.27)
    checkpoint_every = 100

    model = SingleParamModel(initial_value=0.0)

    posthoc_ema = PostHocEMA.from_model(
        model,
        "test-single-param-ema",
        checkpoint_every=checkpoint_every,  # Save checkpoints more frequently
        sigma_rels=sigma_rels,
        update_after_step=0,  # Start immediately
    )

    # Number of steps to update from 0 to 1
    num_steps = 5000

    # Only record every 500 steps (plus the last step) to reduce test time.
    # A set keeps the per-step membership test constant-time.
    record_steps = set(range(0, num_steps, 500)) | {num_steps - 1}

    # Recorded series, aligned index-by-index with each other.
    step_records = []
    param_records = []
    ema_records = {sigma_rel: [] for sigma_rel in sigma_rels}

    for step in range(num_steps):
        # Linear interpolation: 0 at the first step, exactly 1 at the last.
        target_value = step / (num_steps - 1)

        with torch.no_grad():
            model.param.copy_(torch.tensor([target_value], dtype=torch.float32))

        posthoc_ema.update_(model)

        # Only query EMA values once at least `checkpoint_every` steps have
        # run, i.e. after some checkpoints have been created.
        if step in record_steps and step >= checkpoint_every:
            step_records.append(step)
            param_records.append(model.param.item())

            for sigma_rel in ema_records:
                with posthoc_ema.state_dict(sigma_rel=sigma_rel) as state_dict:
                    ema_records[sigma_rel].append(state_dict["param"].item())

    # Print final values
    print(f"\nFinal model parameter value: {param_records[-1]}")
    for sigma_rel, values in ema_records.items():
        print(f"Final EMA value with sigma_rel={sigma_rel}: {values[-1]}")

    # Verify that smaller sigma_rel values result in faster adaptation
    assert (
        ema_records[0.05][-1] > ema_records[0.15][-1] > ema_records[0.27][-1]
    ), "Smaller sigma_rel values should result in faster adaptation (values closer to 1)"

    # Plotting is skipped unless explicitly requested, so automated runs do
    # not write image files.
    if os.environ.get("POSTHOC_EMA_SAVE_PLOTS"):
        import matplotlib.pyplot as plt

        plt.figure(figsize=(10, 6))
        plt.plot(step_records, param_records, "k-", label="Model Parameter")

        for sigma_rel, values in ema_records.items():
            plt.plot(step_records, values, "--", label=f"EMA (sigma_rel={sigma_rel})")

        plt.xlabel("Step")
        plt.ylabel("Parameter Value")
        plt.title("Effect of Different sigma_rel Values on EMA")
        plt.legend()
        plt.grid(True)
        plt.savefig("ema_parameter_comparison.png")
        plt.close()

0 commit comments

Comments
 (0)