Skip to content

Commit 2a65aab

Browse files
committed
doc: update beta for extreme sigma rel
1 parent 8fffb40 commit 2a65aab

File tree

2 files changed

+89
-13
lines changed

2 files changed

+89
-13
lines changed

README.md

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Choose your EMA decay rate after training. No need to decide upfront.
55
The library uses `sigma_rel` (relative standard deviation) to parameterize EMA decay rates, which relates to the classical EMA decay rate `beta` as follows:
66

77
```python
8-
beta = 0.9000 # Fast decay -> sigma_rel ≈ 0.001
8+
beta = 0.3700 # Fast decay -> sigma_rel ≈ 0.001
99
beta = 0.9055 # Medium decay -> sigma_rel ≈ 0.01
1010
beta = 0.9680 # Medium decay -> sigma_rel ≈ 0.03
1111
beta = 0.9808 # Medium decay -> sigma_rel ≈ 0.05
@@ -141,15 +141,3 @@ posthoc_ema = PostHocEMA.from_model(
141141
volume = {abs/2402.09240}
142142
}
143143
```
144-
145-
```python
146-
beta = 0.9000 # Fast decay -> sigma_rel ≈ 0.001
147-
beta = 0.9055 # Medium decay -> sigma_rel ≈ 0.01
148-
beta = 0.9680 # Medium decay -> sigma_rel ≈ 0.03
149-
beta = 0.9808 # Medium decay -> sigma_rel ≈ 0.05
150-
beta = 0.9911 # Slow decay -> sigma_rel ≈ 0.10
151-
beta = 0.9944 # Slow decay -> sigma_rel ≈ 0.15
152-
beta = 0.9962 # Slow decay -> sigma_rel ≈ 0.20
153-
beta = 0.9979 # Slow decay -> sigma_rel ≈ 0.27
154-
beta = 0.9999 # Very slow decay -> sigma_rel ≈ 0.40
155-
```

tests/test_reference_practical.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,3 +948,91 @@ def test_compare_large_sigma_rel_with_traditional_ema():
948948
"ema_values": ema_values,
949949
"best_beta": best_beta,
950950
}
951+
952+
953+
def test_compare_very_small_sigma_rel_with_precise_beta():
954+
"""
955+
Test to compare sigma_rel=0.001 with beta values around 0.4 for more precision.
956+
957+
This test:
958+
1. Runs PostHocEMA with sigma_rel=0.001
959+
2. Runs traditional EMA with a range of beta values around 0.4
960+
3. Prints the comparison to help identify the most precise match
961+
"""
962+
# Common parameters
963+
num_steps = 1000
964+
965+
# Create models
966+
model_posthoc = SingleParamModel(initial_value=0.0)
967+
968+
# Import our implementation
969+
from posthoc_ema import PostHocEMA
970+
971+
# Create our PostHocEMA instance
972+
our_ema = PostHocEMA.from_model(
973+
model_posthoc,
974+
"./test-our-ema",
975+
sigma_rels=(0.001,),
976+
update_every=1, # Update every step for more precision
977+
checkpoint_every=50,
978+
update_after_step=0, # Start immediately
979+
)
980+
981+
# Traditional EMA beta values to test, focusing around 0.4
982+
beta_values = [0.35, 0.36, 0.37, 0.38, 0.39, 0.40, 0.41, 0.42, 0.43, 0.44, 0.45]
983+
984+
# Initialize EMA values for each beta
985+
ema_values = {beta: 0.0 for beta in beta_values}
986+
987+
# Gradually update the model from 0 to 1
988+
for step in range(num_steps):
989+
# Linear interpolation from 0 to 1
990+
target_value = step / (num_steps - 1)
991+
992+
with torch.no_grad():
993+
model_posthoc.param.copy_(torch.tensor([target_value], dtype=torch.float32))
994+
995+
# Update traditional EMA values
996+
for beta in beta_values:
997+
ema_values[beta] = beta * ema_values[beta] + (1 - beta) * target_value
998+
999+
our_ema.update_(model_posthoc)
1000+
1001+
# Get PostHocEMA value
1002+
with our_ema.state_dict(sigma_rel=0.001) as state_dict:
1003+
posthoc_value = state_dict["param"].item()
1004+
1005+
print(f"\nPostHocEMA value for sigma_rel=0.001: {posthoc_value:.6f}")
1006+
1007+
# Print comparison with traditional EMA values
1008+
print("\nComparison with beta values around 0.4:")
1009+
for beta in sorted(beta_values):
1010+
diff = abs(ema_values[beta] - posthoc_value)
1011+
print(
1012+
f" Beta={beta:.2f}: {ema_values[beta]:.6f} (diff: {diff:.6f}, {diff/posthoc_value*100:.4f}%)"
1013+
)
1014+
1015+
# Find the closest match
1016+
best_beta = None
1017+
best_diff = float("inf")
1018+
for beta in beta_values:
1019+
diff = abs(ema_values[beta] - posthoc_value)
1020+
if diff < best_diff:
1021+
best_diff = diff
1022+
best_beta = beta
1023+
1024+
print(
1025+
f"\nClosest match: beta={best_beta:.2f} with difference {best_diff:.6f} ({best_diff/posthoc_value*100:.4f}%)"
1026+
)
1027+
1028+
# Determine decay speed category
1029+
decay_speed = "Very fast decay"
1030+
1031+
print(f"\nREADME mapping entry:")
1032+
print(f"beta = {best_beta:.2f} # {decay_speed} -> sigma_rel ≈ 0.001")
1033+
1034+
return {
1035+
"posthoc_value": posthoc_value,
1036+
"ema_values": ema_values,
1037+
"best_beta": best_beta,
1038+
}

0 commit comments

Comments
 (0)