Skip to content

Commit 5b066a6

Browse files
committed
the magnitude preserving unet works best with inverse square root decay learning schedule
1 parent fd5abb9 commit 5b066a6

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

denoising_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@
88

99
from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D, Trainer1D, Dataset1D
1010

11-
from denoising_diffusion_pytorch.karras_unet import KarrasUnet
11+
from denoising_diffusion_pytorch.karras_unet import KarrasUnet, InvSqrtDecayLRSched

denoising_diffusion_pytorch/karras_unet.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
from torch import nn, einsum
1111
from torch.nn import Module, ModuleList
12+
from torch.optim.lr_scheduler import LambdaLR
1213
import torch.nn.functional as F
1314

1415
from einops import rearrange, repeat, pack, unpack
@@ -680,6 +681,21 @@ def forward(self, x):
680681

681682
return x
682683

684+
# works best with inverse square root decay schedule
685+
686+
def InvSqrtDecayLRSched(
    optimizer,
    t_ref = 70000,
    sigma_ref = 0.01
):
    """
    build a LambdaLR scheduler with an inverse square root decay

    the factor produced for a given scheduler step is

        sigma_ref / sqrt(max(step / t_ref, 1))

    so it is held at `sigma_ref` while step <= t_ref and then decays
    proportionally to 1 / sqrt(step)

    note: LambdaLR multiplies each param group's initial lr by this factor,
    so the effective learning rate is (initial lr) * factor

    refer to equation 67 and Table 1
    (presumably of the paper the magnitude preserving unet comes from - confirm)

    args:
        optimizer: the wrapped optimizer, passed straight to LambdaLR
        t_ref: step count after which the decay begins
        sigma_ref: factor used up to and including `t_ref` steps

    returns:
        the LambdaLR instance driving `optimizer`
    """
    def inv_sqrt_decay_fn(step: int):
        # LambdaLR passes its step counter in as `step`; the ratio is clamped
        # to 1 so the factor never exceeds `sigma_ref` before `t_ref` is reached
        return sigma_ref / sqrt(max(step / t_ref, 1.))

    return LambdaLR(optimizer, lr_lambda = inv_sqrt_decay_fn)
698+
683699
# example
684700

685701
if __name__ == '__main__':
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.10.5'
1+
__version__ = '1.10.6'

0 commit comments

Comments
 (0)