@@ -8,6 +8,74 @@ Implementation of <a href="https://arxiv.org/abs/2212.11972">Recurrent Interface
88
99The big surprise is that the generations can reach this level of fidelity. Will need to verify this on my own machine
1010
11+ ## Install
12+
13+ ``` bash
14+ $ pip install rin-pytorch
15+ ```
16+
17+ ## Usage
18+
19+ ``` python
20+ from rin_pytorch import RIN , Trainer, GaussianDiffusion
21+
22+ model = RIN(
23+ dim = 32 ,
24+ channels = 3 ,
25+ dim_mults = (1 , 2 , 4 , 8 ),
26+ ).cuda()
27+
28+ diffusion = GaussianDiffusion(
29+ model,
30+ image_size = 128 ,
31+ timesteps = 100 ,
32+ use_ddim = True # use ddim
33+ ).cuda()
34+
35+ trainer = Trainer(
36+ diffusion,
37+ ' /path/to/your/data' , # path to your folder of images
38+ results_folder = ' ./results' , # where to save results
39+ num_samples = 16 , # number of samples
40+ train_batch_size = 4 , # training batch size
41+ gradient_accumulate_every = 4 , # gradient accumulation
42+ train_lr = 1e-4 , # learning rate
43+ save_and_sample_every = 1000 , # how often to save and sample
44+ train_num_steps = 700000 , # total training steps
45+ ema_decay = 0.995 , # exponential moving average decay
46+ )
47+
48+ trainer.train()
49+ ```
50+
51+ Results will be saved periodically to the ` ./results ` folder
52+
53+ If you would like to experiment with the ` RIN ` and ` GaussianDiffusion ` class outside the ` Trainer `
54+
55+ ``` python
56+ import torch
57+ from rin_pytorch import RIN , GaussianDiffusion
58+
59+ model = RIN(
60+ dim = 64 ,
61+ dim_mults = (1 , 2 , 4 , 8 )
62+ )
63+
64+ diffusion = GaussianDiffusion(
65+ model,
66+ image_size = 128 ,
67+ timesteps = 1000
68+ )
69+
70+ training_images = torch.randn(8 , 3 , 128 , 128 ) # images are normalized from 0 to 1
71+ loss = diffusion(training_images)
72+ loss.backward()
73+ # after a lot of training
74+
75+ sampled_images = diffusion.sample(batch_size = 4 )
76+ sampled_images.shape # (4, 3, 128, 128)
77+ ```
78+
1179## Citations
1280
1381``` bibtex
0 commit comments