Skip to content

Commit 5cd08b2

Browse files
committed
complete first pass
1 parent 16601b8 commit 5cd08b2

File tree

2 files changed

+263
-120
lines changed

2 files changed

+263
-120
lines changed

README.md

Lines changed: 24 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -19,30 +19,33 @@ $ pip install rin-pytorch
1919
## Usage
2020

2121
```python
22-
from rin_pytorch import RIN, Trainer, GaussianDiffusion
22+
from rin_pytorch import GaussianDiffusion, RIN, Trainer
2323

2424
model = RIN(
25-
dim = 32,
26-
channels = 3,
27-
dim_mults = (1, 2, 4, 8),
25+
dim = 256, # model dimensions
26+
image_size = 128, # image size
27+
patch_size = 8, # patch size
28+
num_latents = 128, # number of latents (the paper used 256)
29+
latent_self_attn_depth = 4, # number of latent self attention blocks per recurrent step, K in the paper
2830
).cuda()
2931

3032
diffusion = GaussianDiffusion(
3133
model,
3234
image_size = 128,
33-
timesteps = 100,
34-
use_ddim = True # use ddim
35+
use_ddim = False,
36+
timesteps = 400,
37+
train_prob_self_cond = 0.9 # how often to self condition on latents
3538
).cuda()
3639

3740
trainer = Trainer(
3841
diffusion,
39-
'/path/to/your/data', # path to your folder of images
40-
results_folder = './results', # where to save results
41-
num_samples = 16, # number of samples
42-
train_batch_size = 4, # training batch size
43-
gradient_accumulate_every = 4, # gradient accumulation
44-
train_lr = 1e-4, # learning rate
45-
save_and_sample_every = 1000, # how often to save and sample
42+
'/home/phil/dl/data/flowers', # path to your folder of images
43+
results_folder = './rin',
44+
num_samples = 16,
45+
train_batch_size = 4,
46+
gradient_accumulate_every = 4,
47+
train_lr = 1e-4,
48+
save_and_sample_every = 1000,
4649
train_num_steps = 700000, # total training steps
4750
ema_decay = 0.995, # exponential moving average decay
4851
)
@@ -59,14 +62,18 @@ import torch
5962
from rin_pytorch import RIN, GaussianDiffusion
6063

6164
model = RIN(
62-
dim = 64,
63-
dim_mults = (1, 2, 4, 8)
64-
)
65+
dim = 256, # model dimensions
66+
image_size = 128, # image size
67+
patch_size = 8, # patch size
68+
num_latents = 128, # number of latents (the paper used 256)
69+
latent_self_attn_depth = 4, # number of latent self attention blocks per recurrent step, K in the paper
70+
).cuda()
6571

6672
diffusion = GaussianDiffusion(
6773
model,
6874
image_size = 128,
69-
timesteps = 1000
75+
timesteps = 1000,
76+
train_prob_self_cond = 0.9
7077
)
7178

7279
training_images = torch.randn(8, 3, 128, 128) # placeholder batch; real images are expected to be normalized from 0 to 1

0 commit comments

Comments
 (0)