Commit 6128fe2

doc: easier to follow intro

1 parent dff4a55 commit 6128fe2

2 files changed: +28 -58 lines changed

README.md (27 additions & 57 deletions)
@@ -1,10 +1,6 @@
-# pytorch post-hoc ema
+# pytorch-posthoc-ema
 
-The PyTorch Post-hoc EMA library improves neural network performance by applying Exponential Moving Average (EMA) techniques after training. This approach allows for the adjustment of EMA profiles post-training, which is crucial for optimizing model weight stabilization without predefining decay parameters.
-
-By implementing the post-hoc synthesized EMA method from Karras et al., the library offers flexibility in exploring EMA profiles' effects on training and sampling. It seamlessly integrates with PyTorch models, making it easy to enhance machine learning projects with post-hoc EMA adjustments.
-
-This library was adapted from [ema-pytorch](https://github.com/lucidrains/ema-pytorch) by lucidrains.
+Choose your EMA decay rate after training. No need to decide upfront.
 
 The library uses `sigma_rel` (relative standard deviation) to parameterize EMA decay rates, which relates to the classical EMA decay rate `beta` as follows:
 
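The pitch above works because an EMA is a linear function of the parameter trajectory: snapshots of a few EMA profiles tracked during training can be linearly recombined into (approximately) any other profile afterwards. A runnable toy of that idea, following Karras et al., using numpy only; this illustrates the method and is not this library's implementation:

```python
import numpy as np

# Toy post-hoc EMA synthesis: fit a target averaging profile as a
# least-squares mixture of the profiles tracked during training.
steps = np.arange(1, 10_001, dtype=np.float64)

def power_profile(gamma: float) -> np.ndarray:
    """Normalized power-function averaging profile, weight(t) ~ t**gamma."""
    w = steps ** gamma
    return w / w.sum()

tracked = np.stack([power_profile(g) for g in (16.97, 6.94)])  # profiles kept during training
target = power_profile(10.0)  # a profile that was never tracked

# Mixture weights; the real method applies such weights to the stored
# parameter snapshots rather than to the profiles themselves.
mix, *_ = np.linalg.lstsq(tracked.T, target, rcond=None)
print(mix, np.abs(tracked.T @ mix - target).max())
```

In practice the library stores snapshots over the course of training rather than just the final profiles; per the paper, a few dozen snapshots is more than sufficient for a virtually perfect reconstruction.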
@@ -14,15 +10,17 @@ beta = 0.999 # Medium decay -> sigma_rel ≈ 0.15
 beta = 0.99 # Fast decay -> sigma_rel ≈ 0.28
 ```
 
+This library was adapted from [ema-pytorch](https://github.com/lucidrains/ema-pytorch) by lucidrains.
+
 New features and changes:
 
+- No extra VRAM usage by keeping EMA on cpu
+- No extra VRAM usage for EMA synthesis during evaluation
+- Low RAM usage for EMA synthesis
 - Simplified or more explicit usage
 - Opinionated defaults
 - Select number of checkpoints to keep
 - Allow "Switch EMA" with PostHocEMA
-- No extra VRAM usage by keeping EMA on cpu
-- No extra VRAM usage for synthesization during evaluation
-- Low RAM usage for synthesis
 - Visualization of EMA reconstruction error before training
 
 ## Install
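For reference, the `beta` to `sigma_rel` mapping quoted above comes from the power-function EMA profile of Karras et al. A minimal standalone sketch of the conversion; the helper names are illustrative and not part of this library's API:

```python
import numpy as np

def sigma_rel_to_gamma(sigma_rel: float) -> float:
    """Invert sigma_rel = sqrt((g + 1) / ((g + 2)**2 * (g + 3))) for the
    power-function EMA exponent gamma (Karras et al., 2023)."""
    t = sigma_rel ** -2
    # gamma is the largest real root of g^3 + 7g^2 + (16 - t)g + (12 - t) = 0.
    return float(np.roots([1.0, 7.0, 16.0 - t, 12.0 - t]).real.max())

def beta_at_step(gamma: float, step: int) -> float:
    """Equivalent per-step decay of the power-function EMA at `step`.
    Unlike a classical EMA, this decay drifts toward 1 as training runs."""
    return (1.0 - 1.0 / step) ** (gamma + 1.0)

gamma = sigma_rel_to_gamma(0.05)  # ~16.97, matching the paper
print(gamma, beta_at_step(gamma, step=100_000))
```

Lower `sigma_rel` (equivalently, `beta` closer to 1) means slower decay and a more stable average.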
@@ -31,7 +29,7 @@ New features and changes:
 poetry add pytorch-posthoc-ema
 ```
 
-## Usage
+## Basic Usage
 
 ```python
 import torch
@@ -42,7 +40,6 @@ model = torch.nn.Linear(512, 512)
 posthoc_ema = PostHocEMA.from_model(model, "posthoc-ema")
 
 for _ in range(1000):
-
     # mutate your network, normally with an optimizer
     with torch.no_grad():
         model.weight.copy_(torch.randn_like(model.weight))
@@ -56,75 +53,51 @@ predictions = model(data)
 # use the helper
 with posthoc_ema.model(model, sigma_rel=0.15) as ema_model:
     ema_predictions = ema_model(data)
-
-# or without magic
-model.cpu()
-
-with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
-    ema_model = deepcopy(model)
-    ema_model.load_state_dict(ema_state_dict)
-    ema_predictions = ema_model(data)
-    del ema_model
 ```
 
-Synthesize after training:
+### Load After Training
 
 ```python
+# With model
 posthoc_ema = PostHocEMA.from_path("posthoc-ema", model)
-
 with posthoc_ema.model(model, sigma_rel=0.15) as ema_model:
     ema_predictions = ema_model(data)
-```
-
-Or without model:
 
-```python
+# Without model
 posthoc_ema = PostHocEMA.from_path("posthoc-ema")
-
-with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
-    model.load_state_dict(ema_state_dict, strict=False)
+with posthoc_ema.state_dict(sigma_rel=0.15) as state_dict:
+    model.load_state_dict(state_dict, strict=False)
 ```
 
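Because EMA weights are synthesized on demand, a natural pattern is to sweep several decay rates on validation data after training and keep the best. A short sketch using the helper shown above; `evaluate` is a placeholder for your own validation function:

```python
# Sketch: compare several synthesized EMAs after training.
# `evaluate` is a placeholder, not part of the library.
for sigma_rel in (0.05, 0.10, 0.15, 0.20):
    with posthoc_ema.model(model, sigma_rel=sigma_rel) as ema_model:
        print(sigma_rel, evaluate(ema_model))
```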
-Set parameters to EMA state during training:
+## Advanced Usage
+
+### Switch EMA During Training
 
 ```python
-with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
-    result = model.load_state_dict(ema_state_dict, strict=False)
-    assert len(result.unexpected_keys) == 0
+with posthoc_ema.state_dict(sigma_rel=0.15) as state_dict:
+    model.load_state_dict(state_dict, strict=False)
 ```
 
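"Switch EMA" (Li et al., cited below) periodically resets the live weights to the EMA state during training. A sketch of where the snippet above would sit in a training loop; `train_step`, `total_steps`, and `switch_every` are placeholders, not library options:

```python
# Sketch of "Switch EMA": periodically load the synthesized EMA state
# back into the live model. `train_step`, `total_steps`, and
# `switch_every` are placeholders, not library options.
switch_every = 1000

for step in range(1, total_steps + 1):
    train_step(model)  # your usual forward/backward/optimizer step
    # ... EMA tracking as in Basic Usage ...
    if step % switch_every == 0:
        with posthoc_ema.state_dict(sigma_rel=0.15) as state_dict:
            model.load_state_dict(state_dict, strict=False)
```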
-You can visualize how well different EMA decay rates can be reconstructed from the stored checkpoints:
+### Visualize Reconstruction Quality
 
 ```python
 posthoc_ema.reconstruction_error()
 ```
 
-## Configuration
-
-PostHocEMA provides several configuration options to customize its behavior:
+### Configuration
 
 ```python
 posthoc_ema = PostHocEMA.from_model(
     model,
     checkpoint_dir="path/to/checkpoints",
-    max_checkpoints=20, # Keep last 20 checkpoints per EMA model (default=20)
+    max_checkpoints=20, # Keep last 20 checkpoints per EMA model
     sigma_rels=(0.05, 0.28), # Default relative standard deviations from paper
-    update_every=10, # Update EMA weights every 10 steps (default)
-    checkpoint_every=1000, # Create checkpoints every 1000 steps (default)
-    checkpoint_dtype=torch.float16, # Store checkpoints in half precision (default is no change)
+    update_every=10, # Update EMA weights every 10 steps
+    checkpoint_every=1000, # Create checkpoints every 1000 steps
+    checkpoint_dtype=torch.float16, # Store checkpoints in half precision
 )
 ```
 
-The default values are chosen based on the original paper:
-
-- `max_checkpoints=20`: The paper notes that "a few dozen snapshots is more than sufficient for a virtually perfect EMA reconstruction"
-- `sigma_rels=(0.05, 0.28)`: These correspond to γ₁=16.97 and γ₂=6.94 from the paper
-- `checkpoint_every=1000`: While the paper used 4096 steps between checkpoints, we default to more frequent checkpoints for better granularity
-
-### Relationship between sigma_rel and beta
-
-The paper introduces `sigma_rel` as an alternative parameterization to the classical EMA decay rate `beta`. You can use either parameterization by specifying `betas` or `sigma_rels` when creating the EMA. The `sigma_rel` value represents the relative standard deviation of the EMA weights, while `beta` is the classical decay rate. Lower `sigma_rel` values (or higher `beta` values) result in slower decay and more stable averages.
-
 ## Citations
 
 ```bibtex
@@ -133,8 +106,7 @@ The paper introduces `sigma_rel` as an alternative parameterization to the class
     author = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
     journal = {ArXiv},
     year = {2023},
-    volume = {abs/2312.02696},
-    url = {https://api.semanticscholar.org/CorpusID:265659032}
+    volume = {abs/2312.02696}
 }
 ```
 
@@ -144,8 +116,7 @@ The paper introduces `sigma_rel` as an alternative parameterization to the class
     author = {Hojoon Lee and Hyeonseo Cho and Hyunseung Kim and Donghu Kim and Dugki Min and Jaegul Choo and Clare Lyle},
     journal = {ArXiv},
     year = {2024},
-    volume = {abs/2406.02596},
-    url = {https://api.semanticscholar.org/CorpusID:270258586}
+    volume = {abs/2406.02596}
 }
 ```
 
@@ -155,7 +126,6 @@ The paper introduces `sigma_rel` as an alternative parameterization to the class
     author = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li},
     journal = {ArXiv},
     year = {2024},
-    volume = {abs/2402.09240},
-    url = {https://api.semanticscholar.org/CorpusID:267657558}
+    volume = {abs/2402.09240}
 }
 ```

pyproject.toml (1 addition & 1 deletion)
@@ -1,7 +1,7 @@
 [tool.poetry]
 name = "pytorch-posthoc-ema"
 version = "0.0.0"
-description = ""
+description = "Post-hoc EMA synthesis for PyTorch"
 authors = ["Phil Wang <[email protected]>, Richard Löwenström <[email protected]>"]
 license = "MIT"
 readme = "README.md"
