
Commit ae42f48

prepare so that unet can work with a channel of one, and also make it so image size is hard coded in diffusion class. preparing for training on protein distograms
1 parent 5989f4c commit ae42f48
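
For context, a minimal sketch of how the changed API could be used for single-channel data such as protein distograms. The `dim = 64` value, the explicit `out_dim = 1`, and the tensor shapes are illustrative assumptions, not part of this commit:

```python
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

# a single-channel Unet, e.g. for protein distograms;
# `channels` is the argument introduced by this commit
model = Unet(
    dim = 64,      # illustrative base dimension (assumption)
    out_dim = 1,   # match the single input channel (assumption)
    channels = 1
)

# image size is now fixed at construction time rather than
# passed to `sample`
diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,
    loss_type = 'l1'
)

# one training step on a batch of four 1-channel 128x128 maps
training_maps = torch.randn(4, 1, 128, 128)
loss = diffusion(training_maps)
loss.backward()
```

Note that `sample` in this commit still builds a `(batch_size, 3, image_size, image_size)` noise shape, so single-channel sampling would need that call adjusted separately.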

File tree

3 files changed, +42 -24 lines

README.md

Lines changed: 16 additions & 16 deletions
@@ -27,6 +27,7 @@ model = Unet(
 
 diffusion = GaussianDiffusion(
     model,
+    image_size = 128,
     timesteps = 1000,   # number of steps
     loss_type = 'l1'    # L1 or L2
 )
@@ -36,7 +37,7 @@ loss = diffusion(training_images)
 loss.backward()
 # after a lot of training
 
-sampled_images = diffusion.sample(128, batch_size = 4)
+sampled_images = diffusion.sample(batch_size = 4)
 sampled_images.shape # (4, 3, 128, 128)
 ```
 
@@ -52,14 +53,14 @@ model = Unet(
 
 diffusion = GaussianDiffusion(
     model,
+    image_size = 128,
     timesteps = 1000,   # number of steps
     loss_type = 'l1'    # L1 or L2
 ).cuda()
 
 trainer = Trainer(
     diffusion,
     'path/to/your/images',
-    image_size = 128,
     train_batch_size = 32,
     train_lr = 2e-5,
     train_num_steps = 700000,         # total training steps
@@ -77,23 +78,22 @@ Samples and model checkpoints will be logged to `./results` periodically
 
 ```bibtex
 @misc{ho2020denoising,
-    title={Denoising Diffusion Probabilistic Models},
-    author={Jonathan Ho and Ajay Jain and Pieter Abbeel},
-    year={2020},
-    eprint={2006.11239},
-    archivePrefix={arXiv},
-    primaryClass={cs.LG}
+    title = {Denoising Diffusion Probabilistic Models},
+    author = {Jonathan Ho and Ajay Jain and Pieter Abbeel},
+    year = {2020},
+    eprint = {2006.11239},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG}
 }
 ```
 
 ```bibtex
-@inproceedings{
-anonymous2021improved,
-title={Improved Denoising Diffusion Probabilistic Models},
-author={Anonymous},
-booktitle={Submitted to International Conference on Learning Representations},
-year={2021},
-url={https://openreview.net/forum?id=-NEXDKk8gZ},
-note={under review}
+@inproceedings{anonymous2021improved,
+    title = {Improved Denoising Diffusion Probabilistic Models},
+    author = {Anonymous},
+    booktitle = {Submitted to International Conference on Learning Representations},
+    year = {2021},
+    url = {https://openreview.net/forum?id=-NEXDKk8gZ},
+    note = {under review}
 }
 ```
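
A brief sketch of the `Trainer` change in the README above: `image_size` is no longer passed to the trainer, which now reads it from the diffusion model (the path and hyperparameters follow the README's example):

```python
trainer = Trainer(
    diffusion,                # the GaussianDiffusion instance from above
    'path/to/your/images',
    train_batch_size = 32,
    train_lr = 2e-5
)

# the trainer now takes its image size from the diffusion model
assert trainer.image_size == diffusion.image_size  # 128
```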

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 25 additions & 7 deletions
@@ -181,9 +181,16 @@ def forward(self, x):
 # model
 
 class Unet(nn.Module):
-    def __init__(self, dim, out_dim = None, dim_mults=(1, 2, 4, 8), groups = 8):
+    def __init__(
+        self,
+        dim,
+        out_dim = None,
+        dim_mults=(1, 2, 4, 8),
+        groups = 8,
+        channels = 3
+    ):
         super().__init__()
-        dims = [3, *map(lambda m: dim * m, dim_mults)]
+        dims = [channels, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
         self.time_pos_emb = SinusoidalPosEmb(dim)
@@ -279,8 +286,17 @@ def cosine_beta_schedule(timesteps, s = 0.008):
     return np.clip(betas, a_min = 0, a_max = 0.999)
 
 class GaussianDiffusion(nn.Module):
-    def __init__(self, denoise_fn, timesteps=1000, loss_type='l1', betas = None):
+    def __init__(
+        self,
+        denoise_fn,
+        *,
+        image_size,
+        timesteps = 1000,
+        loss_type = 'l1',
+        betas = None
+    ):
         super().__init__()
+        self.image_size = image_size
         self.denoise_fn = denoise_fn
 
         if exists(betas):
@@ -371,7 +387,8 @@ def p_sample_loop(self, shape):
         return img
 
     @torch.no_grad()
-    def sample(self, image_size, batch_size = 16):
+    def sample(self, batch_size = 16):
+        image_size = self.image_size
         return self.p_sample_loop((batch_size, 3, image_size, image_size))
 
     @torch.no_grad()
@@ -415,7 +432,8 @@ def p_losses(self, x_start, t, noise = None):
         return loss
 
     def forward(self, x, *args, **kwargs):
-        b, *_, device = *x.shape, x.device
+        b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
         t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
         return self.p_losses(x, t, *args, **kwargs)
 
@@ -467,7 +485,7 @@ def __init__(
         self.step_start_ema = step_start_ema
 
         self.batch_size = train_batch_size
-        self.image_size = image_size
+        self.image_size = diffusion_model.image_size
         self.gradient_accumulate_every = gradient_accumulate_every
         self.train_num_steps = train_num_steps
 
@@ -528,7 +546,7 @@ def train(self):
             if self.step != 0 and self.step % SAVE_AND_SAMPLE_EVERY == 0:
                 milestone = self.step // SAVE_AND_SAMPLE_EVERY
                 batches = num_to_groups(36, self.batch_size)
-                all_images_list = list(map(lambda n: self.ema_model.sample(self.image_size, batch_size=n), batches))
+                all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
                 all_images = torch.cat(all_images_list, dim=0)
                 utils.save_image(all_images, str(RESULTS_FOLDER / f'sample-{milestone}.png'), nrow=6)
                 self.save(milestone)
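
As a quick sanity check of the behavioral changes in this file, a hedged sketch (the `dim` value and tensor shapes are illustrative assumptions): `image_size` is now keyword-only, `forward` asserts the input's spatial size, and `sample` takes only a batch size.

```python
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(dim = 64)

# `image_size` is keyword-only after this commit; passing it
# positionally raises a TypeError
diffusion = GaussianDiffusion(model, image_size = 128)

# `forward` now validates height and width against `image_size`
wrong_size = torch.randn(2, 3, 64, 64)
try:
    diffusion(wrong_size)
except AssertionError as e:
    print(e)  # height and width of image must be 128

# `sample` no longer takes an image_size argument
images = diffusion.sample(batch_size = 4)
images.shape  # (4, 3, 128, 128)
```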

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'denoising-diffusion-pytorch',
   packages = find_packages(),
-  version = '0.5.2',
+  version = '0.6.0',
   license='MIT',
   description = 'Denoising Diffusion Probabilistic Models - Pytorch',
   author = 'Phil Wang',
