The PyTorch Post-hoc EMA library improves neural network performance by applying Exponential Moving Average (EMA) techniques after training. Because the EMA profile can be adjusted post hoc, model weights can be stabilized without committing to decay parameters up front.
By implementing the post-hoc synthesized EMA method from Karras et al., the library offers flexibility in exploring EMA profiles' effects on training and sampling. It seamlessly integrates with PyTorch models, making it easy to enhance machine learning projects with post-hoc EMA adjustments.
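For intuition, the post-hoc synthesis idea can be sketched in plain Python: maintain a few EMAs with different decay rates during training, then approximate the averaging profile of any other decay rate as a least-squares combination of the stored ones. The sketch below uses classical `beta` EMA profiles for simplicity; the paper (and this library) use power-function profiles, and all names here are illustrative, not this library's API.

```python
# Toy sketch of post-hoc EMA synthesis (Karras et al.), scalar version.
# A classical EMA with decay `beta`, run for T steps from theta_0, equals a
# fixed weighted average of the iterates: its "profile".

def ema_profile(beta, T):
    # weight on theta_0, then weights on theta_1 .. theta_T
    return [beta ** T] + [(1 - beta) * beta ** (T - t) for t in range(1, T + 1)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

T = 200
p1 = ema_profile(0.90, T)      # profile of first maintained EMA
p2 = ema_profile(0.99, T)      # profile of second maintained EMA
p_star = ema_profile(0.95, T)  # profile we want to synthesize after the fact

# Least squares: find x1, x2 minimizing ||x1*p1 + x2*p2 - p_star||,
# solved here via the 2x2 normal equations.
a11, a12, a22 = dot(p1, p1), dot(p1, p2), dot(p2, p2)
b1, b2 = dot(p1, p_star), dot(p2, p_star)
det = a11 * a22 - a12 * a12
x1 = (b1 * a22 - a12 * b2) / det
x2 = (a11 * b2 - a12 * b1) / det

# The synthesized EMA of ANY training trajectory is then x1*ema1 + x2*ema2;
# no choice of decay rate was needed during training.
residual = [x1 * a + x2 * b - c for a, b, c in zip(p1, p2, p_star)]
rel_err = (dot(residual, residual) / dot(p_star, p_star)) ** 0.5
```

With more stored snapshots (as the library keeps), the fit becomes nearly exact, which is why the decay rate can be chosen after training.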
This library was adapted from [ema-pytorch](https://github.com/lucidrains/ema-pytorch) by lucidrains.
Choose your EMA decay rate after training. No need to decide upfront.
The library uses `sigma_rel` (relative standard deviation) to parameterize EMA decay rates; `sigma_rel` relates to the classical EMA decay rate `beta`, with lower `sigma_rel` corresponding to higher `beta` (see "Relationship between sigma_rel and beta" below). To load synthesized EMA weights for a chosen `sigma_rel`:

```python
with posthoc_ema.state_dict(sigma_rel=0.15) as state_dict:
    model.load_state_dict(state_dict, strict=False)
```
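For reference, the classical `beta` parameterization updates a running average as `ema = beta * ema + (1 - beta) * value`; a minimal pure-Python illustration (not this library's internals):

```python
# Classical EMA: each update moves the average a fraction (1 - beta)
# toward the new value; higher beta = slower, more stable average.
def ema_update(ema, value, beta):
    return beta * ema + (1 - beta) * value

ema = 0.0
for x in [1.0, 1.0, 1.0, 1.0]:
    ema = ema_update(ema, x, beta=0.9)

# After 4 updates toward 1.0, ema = 1 - 0.9**4 = 0.3439
```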
## Advanced Usage

### Switch EMA During Training

Set the model parameters to an EMA state during training:
```python
with posthoc_ema.state_dict(sigma_rel=0.15) as state_dict:
    model.load_state_dict(state_dict, strict=False)
```
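The "switch" refers to the Switch EMA technique (Li et al., cited below): periodically copying the EMA weights back into the live model during training. A toy scalar sketch of the idea (illustrative only, not this library's implementation):

```python
# Switch EMA on a toy problem: minimize f(w) = w**2 with SGD while
# tracking an EMA of w, and every `switch_every` steps continue
# training from the EMA value instead of the raw iterate.
def switch_ema_demo(steps=100, lr=0.1, beta=0.9, switch_every=10):
    w = 10.0      # "model parameter"
    ema = w
    for step in range(1, steps + 1):
        grad = 2 * w                       # gradient of w**2
        w -= lr * grad                     # SGD update
        ema = beta * ema + (1 - beta) * w  # EMA update
        if step % switch_every == 0:
            w = ema                        # the switch
    return w, ema

w, ema = switch_ema_demo()
```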
### Visualize Reconstruction Quality

You can visualize how well different EMA decay rates can be reconstructed from the stored checkpoints:
```python
posthoc_ema.reconstruction_error()
```
### Configuration

PostHocEMA provides several configuration options to customize its behavior:
```python
posthoc_ema = PostHocEMA.from_model(
    model,
    checkpoint_dir="path/to/checkpoints",
    max_checkpoints=20,  # Keep last 20 checkpoints per EMA model
    sigma_rels=(0.05, 0.28),  # Default relative standard deviations from paper
    update_every=10,  # Update EMA weights every 10 steps
    checkpoint_every=1000,  # Create checkpoints every 1000 steps
    checkpoint_dtype=torch.float16,  # Store checkpoints in half precision
)
```
The default values are chosen based on the original paper:

- `max_checkpoints=20`: the paper notes that "a few dozen snapshots is more than sufficient for a virtually perfect EMA reconstruction"
- `sigma_rels=(0.05, 0.28)`: relative standard deviations of the two maintained EMA profiles (in the paper, σ_rel values of 0.05 and 0.10 correspond to γ₁ ≈ 16.97 and γ₂ ≈ 6.94)
- `checkpoint_every=1000`: the paper used 4096 steps between checkpoints; we default to more frequent checkpoints for better granularity

### Relationship between sigma_rel and beta

The paper introduces `sigma_rel` as an alternative parameterization of the classical EMA decay rate `beta`. You can use either parameterization by specifying `betas` or `sigma_rels` when creating the EMA. The `sigma_rel` value is the relative standard deviation of the averaging profile over the EMA weights, while `beta` is the classical decay rate. Lower `sigma_rel` values (or higher `beta` values) result in slower decay and more stable averages.
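As a sketch of how `sigma_rel` pins down an averaging profile: in Karras et al., the power-function EMA profile with exponent γ has relative standard deviation σ_rel = √(γ+1) / ((γ+2)·√(γ+3)), which can be inverted numerically. A pure-Python illustration (not this library's code):

```python
from math import sqrt

def profile_std(gamma):
    # Relative std of the power-function EMA profile (Karras et al.)
    return sqrt(gamma + 1) / ((gamma + 2) * sqrt(gamma + 3))

def sigma_rel_to_gamma(sigma_rel, iters=200):
    # profile_std is strictly decreasing in gamma, so bisect.
    # Valid for 0 < sigma_rel <= profile_std(0) ~= 0.2886.
    lo, hi = 0.0, 1e9
    for _ in range(iters):
        mid = (lo + hi) / 2
        if profile_std(mid) > sigma_rel:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2

# The paper's sigma_rel = 0.05 and 0.10 correspond to gamma ~= 16.97 and 6.94
```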
## Citations
```bibtex
@article{Karras2023AnalyzingAI,
    title   = {Analyzing and Improving the Training Dynamics of Diffusion Models},
    author  = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2312.02696}
}

@article{Li2024SwitchEMA,
    title   = {Switch EMA: A Free Lunch for Better Flatness and Sharpness},
    author  = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2402.09240}
}
```