@@ -19,30 +19,33 @@ $ pip install rin-pytorch
1919## Usage
2020
2121``` python
22- from rin_pytorch import RIN , Trainer, GaussianDiffusion
22+ from rin_pytorch import GaussianDiffusion, RIN , Trainer
2323
2424model = RIN(
25- dim = 32 ,
26- channels = 3 ,
27- dim_mults = (1 , 2 , 4 , 8 ),
25+ dim = 256 , # model dimensions
26+ image_size = 128 , # image size
27+ patch_size = 8 , # patch size
28+ num_latents = 128 , # number of latents. they used 256 in the paper
29+ latent_self_attn_depth = 4 , # number of latent self attention blocks per recurrent step, K in the paper
2830).cuda()
2931
3032diffusion = GaussianDiffusion(
3133 model,
3234 image_size = 128 ,
33- timesteps = 100 ,
34- use_ddim = True # use ddim
35+ use_ddim = False ,
36+ timesteps = 400 ,
37+ train_prob_self_cond = 0.9 # how often to self condition on latents
3538).cuda()
3639
3740trainer = Trainer(
3841 diffusion,
39- ' /path/to/your /data' , # path to your folder of images
40- results_folder = ' ./results ' , # where to save results
41- num_samples = 16 , # number of samples
42- train_batch_size = 4 , # training batch size
43- gradient_accumulate_every = 4 , # gradient accumulation
44- train_lr = 1e-4 , # learning rate
45- save_and_sample_every = 1000 , # how often to save and sample
42+ ' /home/phil/dl /data/flowers ' ,
43+ results_folder = ' ./rin ' ,
44+ num_samples = 16 ,
45+ train_batch_size = 4 ,
46+ gradient_accumulate_every = 4 ,
47+ train_lr = 1e-4 ,
48+ save_and_sample_every = 1000 ,
4649 train_num_steps = 700000 , # total training steps
4750 ema_decay = 0.995 , # exponential moving average decay
4851)
@@ -59,14 +62,18 @@ import torch
5962from rin_pytorch import RIN , GaussianDiffusion
6063
6164model = RIN(
62- dim = 64 ,
63- dim_mults = (1 , 2 , 4 , 8 )
64- )
65+ dim = 256 , # model dimensions
66+ image_size = 128 , # image size
67+ patch_size = 8 , # patch size
68+ num_latents = 128 , # number of latents. they used 256 in the paper
69+ latent_self_attn_depth = 4 , # number of latent self attention blocks per recurrent step, K in the paper
70+ ).cuda()
6571
6672diffusion = GaussianDiffusion(
6773 model,
6874 image_size = 128 ,
69- timesteps = 1000
75+ timesteps = 1000 ,
76+ train_prob_self_cond = 0.9
7077)
7178
7279training_images = torch.randn(8 , 3 , 128 , 128 ) # images are normalized from 0 to 1
0 commit comments