Commit 4bf2891

allow for mixed precision training with fp16 flag
1 parent 88f83d0 commit 4bf2891

3 files changed: 49 additions, 12 deletions


README.md

Lines changed: 2 additions & 1 deletion
@@ -66,7 +66,8 @@ trainer = Trainer(
     train_lr = 2e-5,
     train_num_steps = 100000,         # total training steps
     gradient_accumulate_every = 2,    # gradient accumulation steps
-    ema_decay = 0.995                 # exponential moving average decay
+    ema_decay = 0.995,                # exponential moving average decay
+    fp16 = True                       # turn on mixed precision training with apex
 )
 
 trainer.train()
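
As a hedged usage sketch, a training script that enables the new flag could look like the following. The diffusion object and the folder path are placeholders standing in for the Unet/GaussianDiffusion setup earlier in the README (not shown in this diff), and the apex availability check is only an illustrative precaution that mirrors the assert added to the Trainer in this commit.

# Hypothetical training script; `diffusion` stands in for the GaussianDiffusion
# instance built earlier in the README, and the folder path is a placeholder.
import importlib.util

from denoising_diffusion_pytorch import Trainer

# enable mixed precision only when apex is importable, mirroring the Trainer's assert
use_fp16 = importlib.util.find_spec('apex') is not None

trainer = Trainer(
    diffusion,                          # placeholder GaussianDiffusion model
    'path/to/your/images',              # placeholder image folder
    train_lr = 2e-5,
    train_num_steps = 100000,           # total training steps
    gradient_accumulate_every = 2,      # gradient accumulation steps
    ema_decay = 0.995,                  # exponential moving average decay
    fp16 = use_fp16                     # turn on mixed precision training with apex
)

trainer.train()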

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 46 additions & 10 deletions
@@ -16,6 +16,12 @@
 from tqdm import tqdm
 from einops import rearrange
 
+try:
+    from apex import amp
+    APEX_AVAILABLE = True
+except:
+    APEX_AVAILABLE = False
+
 # constants
 
 SAVE_AND_SAMPLE_EVERY = 1000
@@ -37,6 +43,13 @@ def cycle(dl):
     for data in dl:
         yield data
 
+def loss_backwards(fp16, loss, optimizer, **kwargs):
+    if fp16:
+        with amp.scale_loss(loss, optimizer) as scaled_loss:
+            scaled_loss.backward(**kwargs)
+    else:
+        loss.backward(**kwargs)
+
 # small helper modules
 
 class EMA():
@@ -107,7 +120,7 @@ def forward(self, x):
 # building block modules
 
 class Block(nn.Module):
-    def __init__(self, dim, dim_out, groups = 32):
+    def __init__(self, dim, dim_out, groups = 8):
         super().__init__()
         self.block = nn.Sequential(
             nn.Conv2d(dim, dim_out, 3, padding=1),
@@ -118,7 +131,7 @@ def forward(self, x):
         return self.block(x)
 
 class ResnetBlock(nn.Module):
-    def __init__(self, dim, dim_out, *, time_emb_dim, groups = 32):
+    def __init__(self, dim, dim_out, *, time_emb_dim, groups = 8):
         super().__init__()
         self.mlp = nn.Sequential(
             Mish(),
@@ -157,7 +170,7 @@ def forward(self, x):
 # model
 
 class Unet(nn.Module):
-    def __init__(self, dim, out_dim = None, dim_mults=(1, 2, 4, 8), groups = 32):
+    def __init__(self, dim, out_dim = None, dim_mults=(1, 2, 4, 8), groups = 8):
         super().__init__()
         dims = [3, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
@@ -178,6 +191,7 @@ def __init__(self, dim, out_dim = None, dim_mults=(1, 2, 4, 8), groups = 32):
 
             self.downs.append(nn.ModuleList([
                 ResnetBlock(dim_in, dim_out, time_emb_dim = dim),
+                ResnetBlock(dim_out, dim_out, time_emb_dim = dim),
                 Residual(Rezero(LinearAttention(dim_out))),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
@@ -192,6 +206,7 @@ def __init__(self, dim, out_dim = None, dim_mults=(1, 2, 4, 8), groups = 32):
 
             self.ups.append(nn.ModuleList([
                 ResnetBlock(dim_out * 2, dim_in, time_emb_dim = dim),
+                ResnetBlock(dim_in, dim_in, time_emb_dim = dim),
                 Residual(Rezero(LinearAttention(dim_in))),
                 Upsample(dim_in) if not is_last else nn.Identity()
             ]))
@@ -208,8 +223,9 @@ def forward(self, x, time):
 
         h = []
 
-        for resnet, attn, downsample in self.downs:
+        for resnet, resnet2, attn, downsample in self.downs:
             x = resnet(x, t)
+            x = resnet2(x, t)
             x = attn(x)
             h.append(x)
             x = downsample(x)
@@ -218,9 +234,10 @@ def forward(self, x, time):
         x = self.mid_attn(x)
         x = self.mid_block2(x, t)
 
-        for resnet, attn, upsample in self.ups:
+        for resnet, resnet2, attn, upsample in self.ups:
             x = torch.cat((x, h.pop()), dim=1)
             x = resnet(x, t)
+            x = resnet2(x, t)
             x = attn(x)
             x = upsample(x)
 
@@ -417,23 +434,40 @@ def __init__(
         train_lr = 2e-5,
         train_num_steps = 100000,
         gradient_accumulate_every = 2,
+        fp16 = False
     ):
         super().__init__()
        self.model = diffusion_model
+        self.ema = EMA(ema_decay)
+        self.ema_model = copy.deepcopy(self.model)
 
         self.image_size = image_size
         self.gradient_accumulate_every = gradient_accumulate_every
         self.train_num_steps = train_num_steps
 
-        self.ema = EMA(ema_decay)
-        self.ema_model = copy.deepcopy(self.model)
-
         self.ds = Dataset(folder, image_size)
         self.dl = cycle(data.DataLoader(self.ds, batch_size = train_batch_size, shuffle=True, pin_memory=True))
         self.opt = Adam(diffusion_model.parameters(), lr=train_lr)
 
         self.step = 0
 
+        assert not fp16 or fp16 and APEX_AVAILABLE, 'Apex must be installed in order for mixed precision training to be turned on'
+
+        self.fp16 = fp16
+        if fp16:
+            (self.model, self.ema_model), self.opt = amp.initialize([self.model, self.ema_model], self.opt, opt_level='O1')
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.ema_model.load_state_dict(self.model.state_dict())
+
+    def step_ema(self):
+        if self.step < 2000:
+            self.reset_parameters()
+            return
+        self.ema.update_model_average(self.ema_model, self.model)
+
     def save(self, milestone):
         data = {
             'step': self.step,
@@ -450,18 +484,20 @@ def load(self, milestone):
         self.ema_model.load_state_dict(data['ema'])
 
     def train(self):
+        backwards = partial(loss_backwards, self.fp16)
+
         while self.step < self.train_num_steps:
             for i in range(self.gradient_accumulate_every):
                 data = next(self.dl).cuda()
                 loss = self.model(data)
                 print(f'{self.step}: {loss.item()}')
-                (loss / self.gradient_accumulate_every).backward()
+                backwards(loss / self.gradient_accumulate_every, self.opt)
 
             self.opt.step()
             self.opt.zero_grad()
 
             if self.step % UPDATE_EMA_EVERY == 0:
-                self.ema.update_model_average(self.ema_model, self.model)
+                self.step_ema()
 
             if self.step % SAVE_AND_SAMPLE_EVERY == 0:
                 milestone = self.step // SAVE_AND_SAMPLE_EVERY
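
The pattern the new loss_backwards helper and Trainer constructor rely on is apex's two-call amp API. Below is a minimal, self-contained sketch of that pattern, not the repository's code: the toy model, data, and step counts are illustrative only.

# A minimal sketch of the apex amp pattern this commit wires in:
# amp.initialize patches the model and optimizer for O1 mixed precision,
# and amp.scale_loss wraps backward so fp16 gradients are loss-scaled.
import torch
from torch import nn
from torch.optim import Adam

try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False

device = 'cuda' if torch.cuda.is_available() else 'cpu'
fp16 = APEX_AVAILABLE and device == 'cuda'          # apex amp only applies on GPU

model = nn.Linear(10, 1).to(device)                 # stand-in for the diffusion model
opt = Adam(model.parameters(), lr = 2e-5)

if fp16:
    model, opt = amp.initialize(model, opt, opt_level = 'O1')

gradient_accumulate_every = 2
for step in range(4):
    for _ in range(gradient_accumulate_every):
        x = torch.randn(8, 10, device = device)
        loss = model(x).pow(2).mean() / gradient_accumulate_every
        if fp16:
            # loss scaling guards half precision gradients against underflow
            with amp.scale_loss(loss, opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
    opt.step()
    opt.zero_grad()

In the commit itself, the EMA copy is passed through amp.initialize alongside the online model so both are patched consistently, and loss_backwards simply selects between the scaled and plain backward paths based on the fp16 flag.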

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'denoising-diffusion-pytorch',
   packages = find_packages(),
-  version = '0.2.2',
+  version = '0.2.4',
   license='MIT',
   description = 'Denoising Diffusion Probabilistic Models - Pytorch',
   author = 'Phil Wang',
