Skip to content

Commit f180e42

Browse files
committed
some boilerplate for continuous gaussian diffusion with sigmoid noise schedule formulated in log snr form
1 parent 74a1337 commit f180e42

File tree

4 files changed

+825
-1
lines changed

4 files changed

+825
-1
lines changed

README.md

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,74 @@ Implementation of <a href="https://arxiv.org/abs/2212.11972">Recurrent Interface
88

99
The big surprise is that the generations can reach this level of fidelity. Will need to verify this on my own machine
1010

11+
## Install
12+
13+
```bash
14+
$ pip install rin-pytorch
15+
```
16+
17+
## Usage
18+
19+
```python
20+
from rin_pytorch import RIN, Trainer, GaussianDiffusion
21+
22+
model = RIN(
23+
dim = 32,
24+
channels = 3,
25+
dim_mults = (1, 2, 4, 8),
26+
).cuda()
27+
28+
diffusion = GaussianDiffusion(
29+
model,
30+
image_size = 128,
31+
timesteps = 100,
32+
use_ddim = True # use ddim
33+
).cuda()
34+
35+
trainer = Trainer(
36+
diffusion,
37+
'/path/to/your/data', # path to your folder of images
38+
results_folder = './results', # where to save results
39+
num_samples = 16, # number of samples
40+
train_batch_size = 4, # training batch size
41+
gradient_accumulate_every = 4, # gradient accumulation
42+
train_lr = 1e-4, # learning rate
43+
save_and_sample_every = 1000, # how often to save and sample
44+
train_num_steps = 700000, # total training steps
45+
ema_decay = 0.995, # exponential moving average decay
46+
)
47+
48+
trainer.train()
49+
```
50+
51+
Results will be saved periodically to the `./results` folder
52+
53+
If you would like to experiment with the `RIN` and `GaussianDiffusion` class outside the `Trainer`
54+
55+
```python
56+
import torch
57+
from rin_pytorch import RIN, GaussianDiffusion
58+
59+
model = RIN(
60+
dim = 64,
61+
dim_mults = (1, 2, 4, 8)
62+
)
63+
64+
diffusion = GaussianDiffusion(
65+
model,
66+
image_size = 128,
67+
timesteps = 1000
68+
)
69+
70+
training_images = torch.randn(8, 3, 128, 128) # images are normalized from 0 to 1
71+
loss = diffusion(training_images)
72+
loss.backward()
73+
# after a lot of training
74+
75+
sampled_images = diffusion.sample(batch_size = 4)
76+
sampled_images.shape # (4, 3, 128, 128)
77+
```
78+
1179
## Citations
1280

1381
```bibtex

rin_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from rin_pytorch.rin_pytorch import GaussianDiffusion, RIN, Trainer

0 commit comments

Comments
 (0)